fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the changes between publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (190)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  10. fusion_bench/method/gossip/__init__.py +3 -0
  11. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  12. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  13. fusion_bench/method/gossip/entropy_loss.py +25 -0
  14. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  15. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  16. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  17. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  18. fusion_bench/method/gossip/utils.py +74 -0
  19. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  20. fusion_bench/method/opcm/opcm.py +16 -7
  21. fusion_bench/method/pwe_moe/module.py +1 -1
  22. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  23. fusion_bench/method/regmean/regmean.py +25 -17
  24. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  25. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  27. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  28. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  29. fusion_bench/method/we_moe/we_moe.py +14 -15
  30. fusion_bench/mixins/__init__.py +6 -3
  31. fusion_bench/mixins/hydra_config.py +49 -0
  32. fusion_bench/mixins/openclip_classification.py +11 -0
  33. fusion_bench/mixins/simple_profiler.py +4 -2
  34. fusion_bench/modelpool/__init__.py +3 -1
  35. fusion_bench/modelpool/base_pool.py +2 -2
  36. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  37. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  38. fusion_bench/models/open_clip/__init__.py +6 -0
  39. fusion_bench/models/open_clip/modeling.py +176 -0
  40. fusion_bench/models/open_clip/utils.py +311 -0
  41. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  42. fusion_bench/models/parameter_dict.py +54 -13
  43. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  44. fusion_bench/taskpool/__init__.py +5 -3
  45. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  46. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  47. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  48. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  49. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  50. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  51. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  52. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  53. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  54. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  55. fusion_bench/utils/data.py +12 -0
  56. fusion_bench/utils/devices.py +14 -0
  57. fusion_bench/utils/instantiate.py +12 -0
  58. fusion_bench/utils/misc.py +9 -2
  59. fusion_bench/utils/packages.py +14 -0
  60. fusion_bench/utils/parameters.py +1 -1
  61. fusion_bench/utils/tensorboard.py +1 -1
  62. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
  63. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
  64. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  66. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  67. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  68. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  69. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  70. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  71. fusion_bench_config/fabric/auto.yaml +0 -1
  72. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  73. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  74. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  75. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  76. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  77. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  78. fusion_bench_config/llama_full_finetune.yaml +0 -2
  79. fusion_bench_config/llama_model_fusion.yaml +0 -2
  80. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  81. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  82. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  83. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  84. fusion_bench_config/method/adamerging.yaml +2 -2
  85. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  86. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  87. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  88. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  89. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  90. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  91. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  92. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  93. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  94. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  95. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  96. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  97. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  98. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  99. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  100. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  101. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  102. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  103. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  104. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  105. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  106. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  107. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  108. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  109. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  110. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  111. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  112. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  113. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  114. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  115. fusion_bench_config/method/model_recombination.yaml +0 -1
  116. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  117. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  118. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  119. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  120. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  121. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  122. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  123. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  124. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  125. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  126. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  127. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  128. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  129. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  130. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  131. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  132. fusion_bench_config/method/ties_merging.yaml +1 -1
  133. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  134. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  135. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  154. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  155. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  156. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  157. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  158. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  159. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  160. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  161. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  162. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  163. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  164. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  165. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  166. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  167. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  169. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  170. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  171. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  172. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  173. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  174. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  175. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  176. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  177. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  178. fusion_bench_config/nyuv2_config.yaml +0 -2
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  183. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  184. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  185. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  186. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  187. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  188. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  189. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
  190. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0

fusion_bench/compat/method/__init__.py
@@ -23,6 +23,8 @@ class AlgorithmFactory:
         "clip_layer_wise_adamerging_doge_ta": ".doge_ta.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
         "singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
         "clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
+        "clip_task_wise_gossip": ".gossip.clip_task_wise_gossip.CLIPTaskWiseGossipAlgorithm",
+        "clip_layer_wise_gossip": ".gossip.clip_layer_wise_gossip.CLIPLayerWiseGossipAlgorithm",
         # plug-and-play model merging methods
         "clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
         "clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",

fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py
@@ -148,12 +148,13 @@ class FlanT5GLUETextGenerationTaskPool(LightningFabricMixin, TaskPool):
         else:
             raise ValueError(f"Unknown task {task_config.name}")

-    def evaluate(self, model: T5ForConditionalGeneration):
+    def evaluate(self, model: T5ForConditionalGeneration, name: str = None):
         """
         Evaluate the model on the FlanT5 GLUE text generation tasks.

         Args:
             model (T5ForConditionalGeneration): The model to evaluate.
+            name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.

         Returns:
             dict: A dictionary containing the evaluation results for each task.

@@ -169,6 +170,8 @@ class FlanT5GLUETextGenerationTaskPool(LightningFabricMixin, TaskPool):
             "all_params": all_params,
             "trainable_percentage": training_params / all_params,
         }
+        if name is not None:
+            report["model_info"]["name"] = name
         model = self.fabric.setup(model)
         report.update(super().evaluate(model))
         log.info(f"evaluation report: {report}")

fusion_bench/constants/clip_vision.py (new file)
@@ -0,0 +1,22 @@
+# Constants for CLIP Vision Model Merging
+TASK_NAMES_TA8 = [
+    "sun397",
+    "stanford-cars",
+    "resisc45",
+    "eurosat",
+    "svhn",
+    "gtsrb",
+    "mnist",
+    "dtd",
+]
+
+TASK_NAMES_TA8_CAP = [
+    "SUN397",
+    "Cars",
+    "RESISC45",
+    "EuroSAT",
+    "SVHN",
+    "GTSRB",
+    "MNIST",
+    "DTD",
+]
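
The two lists above enumerate the standard eight-task ("TA8") CLIP benchmark in two naming conventions, and the entries correspond pairwise. A small illustrative sketch of pairing them follows; the lists are copied from the new module, while the mapping itself is not something the package defines.

```python
# Names copied from fusion_bench/constants/clip_vision.py (added above).
TASK_NAMES_TA8 = ["sun397", "stanford-cars", "resisc45", "eurosat",
                  "svhn", "gtsrb", "mnist", "dtd"]
TASK_NAMES_TA8_CAP = ["SUN397", "Cars", "RESISC45", "EuroSAT",
                      "SVHN", "GTSRB", "MNIST", "DTD"]

# Pairwise mapping between the lowercase and capitalized naming conventions;
# this dict is only an illustration.
LOWER_TO_CAP = dict(zip(TASK_NAMES_TA8, TASK_NAMES_TA8_CAP))
print(LOWER_TO_CAP["stanford-cars"])  # -> "Cars"
```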

fusion_bench/dataset/clip_dataset.py
@@ -2,11 +2,13 @@
 This module provides a class to convert a dataset whose object is a list of dictionaries with keys "image" and "label" to a dataset whose object is a tuple of tensors (inputs, label) for CLIP models.
 """

-from typing import Optional
+from typing import Optional, Tuple

 import torch
 from transformers import CLIPProcessor, ProcessorMixin

+__all__ = ["CLIPDataset"]
+

 class CLIPDataset(torch.utils.data.Dataset):
     """

@@ -34,7 +36,7 @@ class CLIPDataset(torch.utils.data.Dataset):
         """Returns the number of items in the dataset."""
         return len(self.dataset)

-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
         """
         Retrieves and processes an item from the dataset.

@@ -62,6 +64,12 @@ class CLIPDataset(torch.utils.data.Dataset):
                 inputs = self.processor(images=[image], return_tensors="pt")[
                     "pixel_values"
                 ][0]
+            elif callable(self.processor):
+                inputs = self.processor(image)
+            else:
+                raise ValueError(
+                    "The processor should be a CLIPProcessor or a callable function"
+                )
         else:
             # if processor is None, return the raw image directly
             inputs = image
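
With this change, `CLIPDataset` accepts any callable as the processor (for example a torchvision transform) in addition to a `CLIPProcessor`. A minimal sketch, assuming the positional `(dataset, processor)` construction used elsewhere in this diff (`CLIPDataset(dataset, self._clip_processor)`) and that torchvision is installed:

```python
from PIL import Image
from torchvision import transforms

from fusion_bench.dataset import CLIPDataset

# A tiny in-memory dataset in the format the module docstring describes:
# a sequence of dicts with "image" and "label" keys.
raw_items = [{"image": Image.new("RGB", (224, 224)), "label": 0}]

# Any callable now works as the processor; it receives the PIL image and
# should return the model-ready tensor.
preprocess = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])
dataset = CLIPDataset(raw_items, preprocess)

pixel_values, label = dataset[0]
print(pixel_values.shape, label)  # e.g. torch.Size([3, 224, 224]) 0
```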

fusion_bench/dataset/fer2013.py
@@ -7,6 +7,7 @@ def load_fer2013(path: str = "clip-benchmark/wds_fer2013", split: str = "train")
     dataset = dataset.rename_columns({"jpg": "image", "cls": "label"})
     return dataset

+
 if __name__ == "__main__":
     dataset = load_fer2013(split="test")
     print(dataset)

fusion_bench/dataset/gsm8k.py
@@ -6,7 +6,7 @@ from datasets import load_dataset


 def load_gsm8k_question_label_data(
-    dataset_name: Literal["train", "test", "train_socratic", "test_socratic"]
+    dataset_name: Literal["train", "test", "train_socratic", "test_socratic"],
 ):
     R"""
     Load the GSM8K dataset and extract questions and labels.

@@ -45,7 +45,7 @@ def load_gsm8k_question_label_data(


 def load_gsm8k_question_label_dataset(
-    dataset_name: Literal["train", "test", "train_socratic", "test_socratic"]
+    dataset_name: Literal["train", "test", "train_socratic", "test_socratic"],
 ):
     """
     Load the GSM8K dataset and return it as a Hugging Face Dataset object.

fusion_bench/method/__init__.py
@@ -62,6 +62,11 @@ _import_structure = {
         "IsotropicMergingInCommonSubspace",
     ],
     "opcm": ["OPCMForCLIP"],
+    "gossip": [
+        "CLIPLayerWiseGossipAlgorithm",
+        "CLIPTaskWiseGossipAlgorithm",
+        "FlanT5LayerWiseGossipAlgorithm",
+    ],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",

@@ -136,6 +141,11 @@ if TYPE_CHECKING:
         WeightedEnsembleAlgorithm,
     )
     from .fisher_merging import FisherMergingForCLIPVisionModel
+    from .gossip import (
+        CLIPLayerWiseGossipAlgorithm,
+        CLIPTaskWiseGossipAlgorithm,
+        FlanT5LayerWiseGossipAlgorithm,
+    )
     from .isotropic_merging import (
         ISO_C_Merge,
         ISO_CTS_Merge,
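
These two hunks register the gossip algorithms in the package's lazy-import machinery, so after the upgrade they can be imported directly from `fusion_bench.method`. A short sketch using only the names added above:

```python
# The names below are exactly those added to _import_structure above; the
# classes themselves are resolved lazily on first access.
from fusion_bench.method import (
    CLIPLayerWiseGossipAlgorithm,
    CLIPTaskWiseGossipAlgorithm,
    FlanT5LayerWiseGossipAlgorithm,
)

print(CLIPTaskWiseGossipAlgorithm.__module__)
# expected: fusion_bench.method.gossip.clip_task_wise_gossip
```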

fusion_bench/method/adamerging/clip_task_wise_adamerging.py
@@ -13,41 +13,13 @@ from fusion_bench.modelpool import CLIPVisionModelPool
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils import timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader

 from .task_wise_adamerging import TaskWiseAdaMergingAlgorithm

 log = logging.getLogger(__name__)


-class InfiniteDataLoader:
-    """
-    A wrapper class for DataLoader to create an infinite data loader.
-    This is useful in case we are only interested in the number of steps and not the number of epochs.
-
-    This class wraps a DataLoader and provides an iterator that resets
-    when the end of the dataset is reached, creating an infinite loop.
-
-    Attributes:
-        data_loader (DataLoader): The DataLoader to wrap.
-        data_iter (iterator): An iterator over the DataLoader.
-    """
-
-    def __init__(self, data_loader):
-        self.data_loader = data_loader
-        self.data_iter = iter(data_loader)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            data = next(self.data_iter)
-        except StopIteration:
-            self.data_iter = iter(self.data_loader)  # Reset the data loader
-            data = next(self.data_iter)
-        return data
-
-
 class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
     """
     A class for task-wise adaptive merging of CLIP models.
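
The `InfiniteDataLoader` helper moves out of the AdaMerging module into `fusion_bench.utils.data` (the import added above; the class is also reproduced verbatim in the new gossip module later in this diff). A minimal sketch of what the wrapper does, using a toy dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from fusion_bench.utils.data import InfiniteDataLoader  # new location, per the import above

# 10 samples with batch_size=4 would normally exhaust after 3 batches; the
# wrapper silently restarts the underlying iterator, so step-based training
# loops can draw as many batches as they need.
dataset = TensorDataset(torch.arange(10))
loader = iter(InfiniteDataLoader(DataLoader(dataset, batch_size=4)))

for step in range(7):  # more steps than one epoch contains
    (batch,) = next(loader)
    print(step, batch.tolist())
```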

fusion_bench/method/fisher_merging/fisher_merging.py
@@ -12,6 +12,7 @@ from torch import Tensor, nn
 from tqdm.autonotebook import tqdm

 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool

 log = logging.getLogger(__name__)

@@ -352,7 +353,7 @@ def filter_state_dict(
     return filtered_state_dict


-class FisherMergingAlgorithm(BaseAlgorithm):
+class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     """
     Implements the Fisher Merging Algorithm.

@@ -432,25 +433,36 @@ class FisherMergingAlgorithm(BaseAlgorithm):
             for param_name in param_names_to_merge:
                 models_to_merge_param_dict[param_name].append(param_dict[param_name])

-            model_to_merge_fisher_weights = self.get_fisher_weights(
-                model_name=name,
-                model=model,
-                train_dataset=modelpool.load_train_dataset(name),
-                param_names_to_merge=param_names_to_merge,
-            )
+            with (
+                self.profile("merging models"),
+                self.profile("computing fisher weights"),
+            ):
+                model_to_merge_fisher_weights = self.get_fisher_weights(
+                    model_name=name,
+                    model=model,
+                    train_dataset=modelpool.load_train_dataset(name),
+                    param_names_to_merge=param_names_to_merge,
+                )

-            models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)
+                models_to_merge_fisher_weights_list.append(
+                    model_to_merge_fisher_weights
+                )

-        merged_params = merging_with_fisher_weights(
-            models_to_merge_param_dict=models_to_merge_param_dict,
-            models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
-            fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
-            normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
-            minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
-        )
+        with self.profile("merging models"):
+            merged_params = merging_with_fisher_weights(
+                models_to_merge_param_dict=models_to_merge_param_dict,
+                models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
+                fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
+                normalize_fisher_weight=self.config.get(
+                    "normalize_fisher_weight", True
+                ),
+                minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
+            )
+
+            merged_model = modelpool.load_model("_pretrained_")
+            merged_model.load_state_dict(merged_params, strict=False)

-        merged_model = modelpool.load_model("_pretrained_")
-        merged_model.load_state_dict(merged_params, strict=False)
+        self.print_profile_summary()
         return merged_model

     def get_fisher_weights(
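
For context on what `merging_with_fisher_weights` computes: Fisher merging averages each parameter across models, weighting every element by that model's (diagonal) Fisher information. Below is a standalone sketch of that weighted average; it is not the library's implementation, which additionally handles scaling coefficients, optional normalization, and the `minimal_fisher_weight` floor shown above.

```python
import torch


def fisher_weighted_average(params, fishers, eps=1e-6):
    """Elementwise weighted average: theta = sum_i F_i * theta_i / sum_i F_i."""
    fishers = [f + eps for f in fishers]  # floor, analogous to minimal_fisher_weight
    total = torch.stack(fishers).sum(dim=0)
    return sum(f * p for f, p in zip(fishers, params)) / total


# Two toy "models", each with a single 2-element parameter tensor, and their
# Fisher diagonals: each model dominates the coordinate it is confident about.
params = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
fishers = [torch.tensor([0.9, 0.1]), torch.tensor([0.1, 0.9])]
print(fisher_weighted_average(params, fishers))  # ~tensor([1.2000, 3.8000])
```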

fusion_bench/method/gossip/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .clip_layer_wise_gossip import CLIPLayerWiseGossipAlgorithm
+from .clip_task_wise_gossip import CLIPTaskWiseGossipAlgorithm
+from .flan_t5_layer_wise_gossip import FlanT5LayerWiseGossipAlgorithm

fusion_bench/method/gossip/clip_layer_wise_gossip.py (new file)
@@ -0,0 +1,43 @@
+"""
+Example Usage:
+
+```bash
+fusion_bench \
+    method=adamerging \
+    method.name=clip_layer_wise_adamerging \
+    method.save_merging_weights=merging_weights.pt \
+    modelpool=clip-vit-base-patch32_TA8 \
+    taskpool=clip-vit-classification_TA8 \
+    fabric_logger.root_dir=outputs/logs/ViT-B-32 \
+    fabric_logger.name=clip_layer_wise_adamerging_adam
+```
+"""
+
+import functools
+import logging
+
+from fusion_bench.mixins import CLIPClassificationMixin
+
+from .layer_wise_gossip import LayerWiseGossipAlgorithm
+
+log = logging.getLogger(__name__)
+
+
+class CLIPLayerWiseGossipAlgorithm(
+    CLIPClassificationMixin,
+    LayerWiseGossipAlgorithm,
+):
+    def on_test_time_adaptation_start(self):
+        """
+        Here we load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        if self.whether_setup_zero_shot_classification_head == False:
+            self.setup_zero_shot_classification_head()
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str):
+        return super().get_shuffled_test_loader_iter(
+            task,
+            batch_size=self.config.batch_size,
+            num_workers=self.config.num_workers,
+        )

fusion_bench/method/gossip/clip_task_wise_gossip.py (new file)
@@ -0,0 +1,190 @@
+import functools
+import logging
+import os
+
+import torch
+from omegaconf import DictConfig
+from torch import Tensor
+from torch.utils.data import DataLoader
+from transformers import CLIPModel, CLIPProcessor
+
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.models.hf_clip import HFCLIPClassifier
+from fusion_bench.tasks.clip_classification import get_classnames_and_templates
+from fusion_bench.utils import timeit_context
+
+from .task_wise_gossip import TaskWiseGossipAlgorithm
+
+log = logging.getLogger(__name__)
+
+
+class InfiniteDataLoader:
+    """
+    A wrapper class for DataLoader to create an infinite data loader.
+    This is useful in case we are only interested in the number of steps and not the number of epochs.
+
+    This class wraps a DataLoader and provides an iterator that resets
+    when the end of the dataset is reached, creating an infinite loop.
+
+    Attributes:
+        data_loader (DataLoader): The DataLoader to wrap.
+        data_iter (iterator): An iterator over the DataLoader.
+    """
+
+    def __init__(self, data_loader):
+        self.data_loader = data_loader
+        self.data_iter = iter(data_loader)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            data = next(self.data_iter)
+        except StopIteration:
+            self.data_iter = iter(self.data_loader)  # Reset the data loader
+            data = next(self.data_iter)
+        return data
+
+
+class CLIPTaskWiseGossipAlgorithm(TaskWiseGossipAlgorithm):
+    """
+    A class for task-wise adaptive merging of CLIP models.
+
+    This class extends the TaskWiseGossipAlgorithm to provide specific
+    functionality for CLIP models, including loading datasets, constructing
+    zero-shot classification heads, and computing logits.
+
+    Attributes:
+        modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
+        _clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
+        zeroshot_weights (dict): A dictionary to store zero-shot weights for each task.
+    """
+
+    modelpool: CLIPVisionModelPool = None
+    _clip_processor: CLIPProcessor = None
+    zeroshot_weights = {}
+
+    def __init__(self, algorithm_config: DictConfig):
+        super().__init__(algorithm_config)
+
+    @functools.cache
+    def get_test_dataset(self, task: str):
+        """
+        Load the test dataset for the task.
+        This method is cached, so the dataset is loaded only once.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            CLIPDataset: The test dataset for the task.
+        """
+        log.info(f"Loading test dataset: {task}")
+        dataset = self.modelpool.load_test_dataset(task)
+        dataset = CLIPDataset(dataset, self._clip_processor)
+        return dataset
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str):
+        """
+        Get an iterator over the shuffled test DataLoader for the task.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            iterator: An iterator over the shuffled test DataLoader.
+        """
+        loader = DataLoader(
+            self.get_test_dataset(task),
+            batch_size=self.config.batch_size,
+            shuffle=True,
+            num_workers=self.config.num_workers,
+            pin_memory=True,
+        )
+        if self._fabric is not None:
+            loader = self._fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def on_test_time_adaptation_start(self):
+        """
+        Prepare for test-time adaptation.
+
+        This method loads the CLIP processor and constructs the zero-shot
+        classification head for each task.
+        """
+        if self._clip_processor is not None and self.zeroshot_weights is not None:
+            return  # this can be reused in Gossip
+
+        clip_model_config = self.modelpool.get_model_config("_pretrained_")
+        pretrained_path = (
+            clip_model_config.pretrained_model_name_or_path
+            if hasattr(clip_model_config, "pretrained_model_name_or_path")
+            else clip_model_config.path
+        )
+
+        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
+            self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
+            clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)
+
+        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
+        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
+        self.logit_scale_exp = clip_model.logit_scale.exp()
+        if self._fabric is not None:
+            self.visual_projection = self._fabric.to_device(self.visual_projection)
+            self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)
+
+        for task in self.modelpool.model_names:
+            cache_file = os.path.join(
+                self.config.cache_dir,
+                f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
+            )
+            if os.path.exists(cache_file):
+                log.info(f"Loading cached zeroshot weights for task: {task}")
+                zeroshot_weights = torch.load(cache_file, map_location="cpu")
+            else:
+                log.info(f"Construct zero shot classification head for task: {task}")
+                classnames, templates = get_classnames_and_templates(task)
+                clip_classifier.set_classification_task(classnames, templates)
+                zeroshot_weights = clip_classifier.zeroshot_weights
+                log.info(f"save zeroshot weights to {cache_file}")
+                torch.save(zeroshot_weights, cache_file)
+            self.zeroshot_weights[task] = zeroshot_weights
+            if self._fabric is not None:
+                self.zeroshot_weights[task] = self._fabric.to_device(
+                    self.zeroshot_weights[task]
+                )
+
+    def compute_logits(self, module, batch, task: str) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        This method computes the image embeddings, normalizes them, and calculates
+        the cosine similarity with the text embeddings to produce classification logits.
+
+        Args:
+            module (nn.Module): The model module.
+            batch (tuple): A batch of input data.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The classification logits for the batch.
+        """
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image

fusion_bench/method/gossip/entropy_loss.py (new file)
@@ -0,0 +1,25 @@
+import torch
+from torch import Tensor
+
+
+def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+        eps (float): A small value to avoid log(0). Default is 1e-8.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    # Ensure the logits tensor has 2 dimensions
+    assert (
+        logits.dim() == 2
+    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
+
+    # Compute the softmax probabilities
+    probs = torch.softmax(logits, dim=-1)
+
+    # Compute the entropy loss
+    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
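
A quick sanity check of the new loss: identical logits (a uniform softmax) give the maximum entropy log(C), while a confidently peaked prediction gives a value near zero, which is how AdaMerging-style test-time adaptation (and the gossip methods built on it) typically uses an entropy objective on unlabeled test batches.

```python
import math

import torch

from fusion_bench.method.gossip.entropy_loss import entropy_loss

uniform = torch.zeros(4, 10)           # identical logits -> uniform softmax over 10 classes
confident = torch.full((4, 10), -10.0)
confident[:, 0] = 10.0                 # one dominant class per sample

print(entropy_loss(uniform))           # ~log(10) ≈ 2.3026
print(entropy_loss(confident))         # ~0
print(math.log(10))                    # 2.302585...
```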