fusion-bench 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. fusion_bench/compat/method/base_algorithm.py +1 -1
  2. fusion_bench/dataset/clip_dataset.py +3 -0
  3. fusion_bench/dataset/fer2013.py +12 -0
  4. fusion_bench/dataset/llama/preference_700k.py +1 -1
  5. fusion_bench/method/__init__.py +2 -0
  6. fusion_bench/method/classification/clip_finetune.py +10 -13
  7. fusion_bench/method/surgery/__init__.py +1 -3
  8. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +1 -1
  9. fusion_bench/method/tall_mask/__init__.py +0 -0
  10. fusion_bench/method/tall_mask/utils.py +234 -0
  11. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  12. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  13. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  14. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  15. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  16. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  17. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  18. fusion_bench/mixins/clip_classification.py +6 -6
  19. fusion_bench/mixins/lightning_fabric.py +3 -1
  20. fusion_bench/modelpool/base_pool.py +0 -1
  21. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  22. fusion_bench/models/surgery/__init__.py +1 -0
  23. fusion_bench/models/surgery/surgerymodelwrapper.py +2 -1
  24. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  25. fusion_bench/models/wrappers/task_wise_fusion.py +1 -1
  26. fusion_bench/programs/fabric_fusion_program.py +7 -4
  27. fusion_bench/taskpool/llama/reward_model.py +1 -1
  28. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  29. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  30. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  31. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  32. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  33. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  34. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  35. fusion_bench/tasks/clip_classification/food101.py +105 -0
  36. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  37. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  38. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  39. fusion_bench/utils/parameters.py +12 -3
  40. fusion_bench/utils/type.py +10 -1
  41. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  42. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +195 -62
  43. fusion_bench_config/dataset/image_classification/README.md +6 -0
  44. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  45. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  47. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  48. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  49. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  50. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  51. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  52. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  53. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  54. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  55. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  56. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  57. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  58. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  59. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  60. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  61. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  62. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  63. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  64. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  65. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  66. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  67. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  68. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  69. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  70. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  71. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  72. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  73. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  74. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  75. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  76. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  77. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  78. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  79. fusion_bench_config/model/clip-vit/README.md +38 -0
  80. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  81. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  82. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  83. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  84. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  85. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  86. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  87. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  88. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  89. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  90. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  91. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  92. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  93. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  94. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  95. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  96. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  97. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  98. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  99. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  100. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  101. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  102. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  103. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  104. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  105. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  106. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  107. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  108. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  109. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  110. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  111. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  112. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  113. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  114. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  115. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  116. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  117. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  118. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  119. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  120. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  121. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  122. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  123. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  124. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  125. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  126. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  127. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  128. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  129. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  130. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  131. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  132. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  133. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  134. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  135. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  150. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  168. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  169. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  170. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  171. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  172. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  173. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  174. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  175. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  176. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  177. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  178. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  183. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  185. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  190. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  191. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  192. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  193. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  194. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  195. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/clip_classification.py
@@ -132,13 +132,13 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
         # get cache directory
         if self.modelpool.has_pretrained:
-            model_name = self.modelpool.get_model_config(
-                "_pretrained_"
-            ).pretrained_model_name_or_path
+            model_name = self.modelpool.get_model_config("_pretrained_")
+            if not isinstance(model_name, str):
+                model_name = model_name.pretrained_model_name_or_path
         else:
-            model_name = self.modelpool.get_model_config(
-                self.modelpool.model_names[0]
-            ).pretrained_model_name_or_path
+            model_name = self.modelpool.get_model_config(self.modelpool.model_names[0])
+            if not isinstance(model_name, str):
+                model_name = model_name.pretrained_model_name_or_path
 
         cache_dir = os.path.join(
             self.zeroshot_weights_cache_dir,
             os.path.normpath(model_name.split("/")[-1]),
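
With this change a model entry may be either a plain string (a Hugging Face checkpoint id) or a structured config carrying `pretrained_model_name_or_path`. A minimal sketch of the two shapes the mixin now normalizes; the `_target_` value here is an assumption, mirroring the `from_pretrained` pattern used elsewhere in this release:

```python
from omegaconf import DictConfig

# Either shape now resolves to the same checkpoint id.
cfg_as_str = "openai/clip-vit-base-patch32"
cfg_as_dict = DictConfig(
    {
        "_target_": "transformers.CLIPVisionModel.from_pretrained",  # assumed target
        "pretrained_model_name_or_path": "openai/clip-vit-base-patch32",
    }
)

for cfg in (cfg_as_str, cfg_as_dict):
    model_name = cfg if isinstance(cfg, str) else cfg.pretrained_model_name_or_path
    assert model_name == "openai/clip-vit-base-patch32"
```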

fusion_bench/mixins/lightning_fabric.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
 
 import lightning as L
 import torch
+from lightning.fabric.connector import _is_using_cli
 from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
@@ -79,7 +80,8 @@ class LightningFabricMixin:
             self._fabric_instance = L.Fabric()
         else:
             self._fabric_instance = instantiate(config.fabric)
-        self._fabric_instance.launch()
+        if not _is_using_cli():  # if not using cli, launch the fabric
+            self._fabric_instance.launch()
 
         # Set the log directory in config if it is not already set
         if (
             self.log_dir is not None
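
`_is_using_cli` is a private Lightning helper that reports whether the process was started through the Fabric CLI, which performs the launch itself; the guard avoids launching twice. A minimal sketch of the same pattern outside the mixin:

```python
import lightning as L
from lightning.fabric.connector import _is_using_cli  # private API, as imported above

fabric = L.Fabric(accelerator="cpu", devices=1)
if not _is_using_cli():
    # Not launched via the Fabric CLI, so initialize processes here.
    fabric.launch()
```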

fusion_bench/modelpool/base_pool.py
@@ -147,7 +147,6 @@ class BaseModelPool(BaseYAMLSerializableModel):
             DictConfig: The configuration for the specified model.
         """
         model_config = self._models[model_name]
-        assert isinstance(model_config, DictConfig), "Model config must be a DictConfig"
         if return_copy:
             model_config = deepcopy(model_config)
         return model_config

fusion_bench/modelpool/clip_vision/modelpool.py
@@ -1,8 +1,11 @@
 import logging
 from copy import deepcopy
-from typing import Optional
+from typing import Optional, Union
 
+from datasets import load_dataset
 from omegaconf import DictConfig, open_dict
+from torch import nn
+from torch.utils.data import Dataset
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from typing_extensions import override
 
@@ -36,17 +39,29 @@ class CLIPVisionModelPool(BaseModelPool):
 
     def load_processor(self, *args, **kwargs) -> CLIPProcessor:
         assert self._processor is not None, "Processor is not defined in the config"
-        processor = instantiate(self._processor, *args, **kwargs)
+        if isinstance(self._processor, str):
+            log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
+            processor = CLIPProcessor.from_pretrained(self._processor)
+        else:
+            processor = instantiate(self._processor, *args, **kwargs)
         return processor
 
     def load_clip_model(self, model_name: str, *args, **kwargs) -> CLIPModel:
         model_config = self._models[model_name]
-        assert isinstance(model_config, DictConfig), "Model config must be a DictConfig"
-        model_config = deepcopy(model_config)
-        with open_dict(model_config):
-            model_config._target_ = "transformers.CLIPModel.from_pretrained"
-        clip_model = instantiate(model_config, *args, **kwargs)
-        return clip_model
+
+        if isinstance(model_config, str):
+            log.info(f"Loading `transformers.CLIPModel`: {model_config}")
+            clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
+            return clip_model
+        else:
+            assert isinstance(
+                model_config, DictConfig
+            ), "Model config must be a DictConfig"
+            model_config = deepcopy(model_config)
+            with open_dict(model_config):
+                model_config._target_ = "transformers.CLIPModel.from_pretrained"
+            clip_model = instantiate(model_config, *args, **kwargs)
+            return clip_model
 
     @override
     def save_model(self, model: CLIPVisionModel, path: str):
@@ -59,3 +74,72 @@ class CLIPVisionModelPool(BaseModelPool):
         """
         with timeit_context(f'Saving clip vision model to "{path}"'):
             model.save_pretrained(path)
+
+    def load_model(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> CLIPVisionModel:
+        """
+        This method is used to load a CLIPVisionModel from the model pool.
+
+        Example configuration could be:
+
+        ```yaml
+        models:
+          cifar10: tanganke/clip-vit-base-patch32_cifar10
+          sun397: tanganke/clip-vit-base-patch32_sun397
+          stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
+        ```
+
+        Args:
+            model_name_or_config (Union[str, DictConfig]): The name of the model or the model configuration.
+
+        Returns:
+            CLIPVisionModel: The loaded CLIPVisionModel.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model = self._models[model_name_or_config]
+            if isinstance(model, str):
+                log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
+                return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
+            if isinstance(model, nn.Module):
+                log.info(f"Returning existing model: {model}")
+                return model
+
+        # If the model is not a string, we use the default load_model method
+        return super().load_model(model_name_or_config, *args, **kwargs)
+
+    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._train_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="train")
+        else:
+            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._val_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="validation")
+        else:
+            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._test_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="test")
+        else:
+            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
+        return dataset
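
Taken together, these additions let a pool be declared with bare identifiers: a string model entry goes through `CLIPVisionModel.from_pretrained` and a string dataset entry through `datasets.load_dataset`. A minimal sketch, assuming the pool can be constructed directly from `models`/`train_datasets` mappings (as its YAML configs suggest) and that `CLIPVisionModelPool` is re-exported from `fusion_bench.modelpool`; the model repo id is the one from the docstring above:

```python
from omegaconf import DictConfig
from fusion_bench.modelpool import CLIPVisionModelPool  # assumed re-export path

pool = CLIPVisionModelPool(
    models=DictConfig({"cifar10": "tanganke/clip-vit-base-patch32_cifar10"}),
    train_datasets=DictConfig({"cifar10": "cifar10"}),  # plain HF dataset id
)
model = pool.load_model("cifar10")              # CLIPVisionModel.from_pretrained(...)
train_set = pool.load_train_dataset("cifar10")  # load_dataset("cifar10", split="train")
```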

fusion_bench/models/surgery/__init__.py (new file)
@@ -0,0 +1 @@
+from .surgerymodelwrapper import SurgeryModelWrapper

fusion_bench/models/surgery/surgerymodelwrapper.py
@@ -1,5 +1,5 @@
 import math
-from typing import TYPE_CHECKING, List, Union, Callable, Generic
+from typing import TYPE_CHECKING, Callable, Generic, List, Union
 
 import torch
 from torch import nn
@@ -7,6 +7,7 @@ from transformers.models.clip.modeling_clip import (
     CLIPVisionModel,
     CLIPVisionTransformer,
 )
+
 from fusion_bench.utils.type import TorchModelType
 
 

fusion_bench/models/wrappers/layer_wise_fusion.py
@@ -16,7 +16,7 @@ import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
-from fusion_bench.utils.type import TorchModelType, StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType
 
 __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
 

fusion_bench/models/wrappers/task_wise_fusion.py
@@ -22,7 +22,7 @@ import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
-from fusion_bench.utils.type import TorchModelType, StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType
 
 log = logging.getLogger(__name__)
 

fusion_bench/programs/fabric_fusion_program.py
@@ -185,10 +185,13 @@ class FabricModelFusionProgram(
             report = taskpool.evaluate(merged_model)
             return report
         elif isinstance(merged_model, Dict):
-            model = merged_model.pop("model")
-            report: dict = taskpool.evaluate(model)
-            report.update(merged_model)
-            print(report)
+            report = {}
+            for key, item in merged_model.items():
+                if isinstance(item, nn.Module):
+                    report[key] = taskpool.evaluate(item)
+                else:
+                    # metadata
+                    report[key] = item
             return report
         elif isinstance(merged_model, Iterable):
             return [
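
The program now accepts a dict that mixes models with metadata: every `nn.Module` value is evaluated on the task pool, and everything else is copied into the report unchanged. A self-contained sketch of that convention, with a toy `evaluate` callable standing in for `taskpool.evaluate`:

```python
from torch import nn

def evaluate_dict(merged_model: dict, evaluate) -> dict:
    # Mirrors the updated logic: evaluate nn.Module entries, pass metadata through.
    return {
        key: evaluate(item) if isinstance(item, nn.Module) else item
        for key, item in merged_model.items()
    }

report = evaluate_dict(
    {"merged": nn.Linear(4, 4), "scaling_factor": 0.3},
    evaluate=lambda m: {"num_params": sum(p.numel() for p in m.parameters())},
)
# {"merged": {"num_params": 20}, "scaling_factor": 0.3}
```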

fusion_bench/taskpool/llama/reward_model.py
@@ -11,10 +11,10 @@ import functools
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 
 import lightning as L
+import numpy as np
 import torch
 from omegaconf import DictConfig
 from torch.utils.data import Subset
-import numpy as np
 from tqdm.auto import tqdm
 
 from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate

fusion_bench/tasks/clip_classification/__init__.py
@@ -58,11 +58,24 @@ class CLIPTemplateFactory:
             "templates": "templates",
         },
         "nateraw/rendered-sst2": ".rendered_sst2",
+        "rendered-sst2": ".rendered_sst2",
         "tanganke/stl10": ".stl10",
+        "stl10": ".stl10",
         "dpdl-benchmark/oxford_flowers102": ".flower102",
+        "oxford_flowers102": ".flower102",
         "timm/oxford-iiit-pet": ".oxford_iiit_pet",
+        "oxford-iiit-pet": ".oxford_iiit_pet",
         "imagenet": ".imagenet",
         "tiny-imagenet": ".tiny_imagenet",
+        "pcam": ".pcam",
+        "fer2013": ".fer2013",
+        "emnist_mnist": ".emnist_mnist",
+        "emnist_letters": ".emnist_letters",
+        "kmnist": ".kmnist",
+        "food101": ".food101",
+        "fashion_mnist": ".fashion_mnist",
+        "cub-200-2011": ".cub_200_2011",
+        "mango-leaf-disease": ".mango_leaf_disease",
     }
 
     @staticmethod
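
With these aliases registered, the bare dataset name and its fully qualified Hugging Face id resolve to the same classnames/templates module:

```python
from fusion_bench.tasks.clip_classification import get_classnames_and_templates

classnames, templates = get_classnames_and_templates("stl10")
assert classnames == get_classnames_and_templates("tanganke/stl10")[0]
prompts = [template(classnames[0]) for template in templates]
```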

fusion_bench/tasks/clip_classification/__init__.py
@@ -168,48 +181,3 @@ class CLIPTemplateFactory:
 
 def get_classnames_and_templates(dataset_name: str):
     return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
-
-
-def _load_hf_dataset(dataset_name: str):
-    """
-    Load a dataset from the Hugging Face datasets library based on the specified dataset name.
-
-    This function handles specific preprocessing steps for certain datasets to ensure consistency in dataset format.
-    For example, it renames columns, removes unnecessary columns, and specifies subsets for certain datasets.
-
-    Expected dataset format:
-    - The dataset should have an "image" column containing the image data.
-    - The dataset should have a "label" column containing the class labels.
-
-    Args:
-        dataset_name (str): The name of the dataset to load. Can be one of "svhn", "cifar10", "cifar100", "timm/oxford-iiit-pet", or any other dataset name supported by the Hugging Face datasets library. By default, the datasets have two columns: "image" and "label".
-
-    Returns:
-        A dataset object loaded from the Hugging Face datasets library, with any necessary preprocessing applied.
-    """
-    if dataset_name == "svhn":
-        return load_dataset(dataset_name, "cropped_digits")
-    elif dataset_name == "cifar10":
-        dataset = load_dataset(dataset_name)
-        dataset = dataset.rename_columns({"img": "image"})
-        return dataset
-    elif dataset_name == "cifar100":
-        dataset = load_dataset(dataset_name)
-        dataset = dataset.remove_columns(["coarse_label"]).rename_columns(
-            {"img": "image", "fine_label": "label"}
-        )
-        return dataset
-    elif dataset_name == "timm/oxford-iiit-pet":
-        dataset = load_dataset(dataset_name)
-        dataset = dataset.remove_columns(["image_id", "label_cat_dog"])
-        return dataset
-    else:
-        return load_dataset(dataset_name)
-
-
-def load_clip_dataset(dataset: str, processor):
-    hf_dataset = _load_hf_dataset(dataset)
-    return (
-        CLIPDataset(hf_dataset["train"], processor),
-        CLIPDataset(hf_dataset["test"], processor),
-    )

fusion_bench/tasks/clip_classification/clip_dataset.py
@@ -1,16 +1 @@
-import torch
-
-
-class CLIPDataset(torch.utils.data.Dataset):
-    def __init__(self, dataset, processor):
-        self.dataset = dataset
-        self.processor = processor
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        item = self.dataset[idx]
-        image = item["image"]
-        inputs = self.processor(images=[image], return_tensors="pt")["pixel_values"][0]
-        return inputs, item["label"]
+from fusion_bench.dataset.clip_dataset import CLIPDataset
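
The class itself is unchanged, only relocated. A usage sketch with the new import path; it expects a Hugging Face dataset with "image" and "label" columns, and `tanganke/stl10` (listed in the template factory above) is one such dataset:

```python
from datasets import load_dataset
from transformers import CLIPProcessor
from fusion_bench.dataset.clip_dataset import CLIPDataset

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
train_split = load_dataset("tanganke/stl10", split="train")
dataset = CLIPDataset(train_split, processor)
pixel_values, label = dataset[0]  # preprocessed image tensor and its class label
```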

fusion_bench/tasks/clip_classification/cub_200_2011.py (new file)
@@ -0,0 +1,208 @@
+classname_mapping = {
+    "0": "Black_footed_Albatross",
+    "1": "Laysan_Albatross",
+    "2": "Sooty_Albatross",
+    "3": "Groove_billed_Ani",
+    "4": "Crested_Auklet",
+    "5": "Least_Auklet",
+    "6": "Parakeet_Auklet",
+    "7": "Rhinoceros_Auklet",
+    "8": "Brewer_Blackbird",
+    "9": "Red_winged_Blackbird",
+    "10": "Rusty_Blackbird",
+    "11": "Yellow_headed_Blackbird",
+    "12": "Bobolink",
+    "13": "Indigo_Bunting",
+    "14": "Lazuli_Bunting",
+    "15": "Painted_Bunting",
+    "16": "Cardinal",
+    "17": "Spotted_Catbird",
+    "18": "Gray_Catbird",
+    "19": "Yellow_breasted_Chat",
+    "20": "Eastern_Towhee",
+    "21": "Chuck_will_Widow",
+    "22": "Brandt_Cormorant",
+    "23": "Red_faced_Cormorant",
+    "24": "Pelagic_Cormorant",
+    "25": "Bronzed_Cowbird",
+    "26": "Shiny_Cowbird",
+    "27": "Brown_Creeper",
+    "28": "American_Crow",
+    "29": "Fish_Crow",
+    "30": "Black_billed_Cuckoo",
+    "31": "Mangrove_Cuckoo",
+    "32": "Yellow_billed_Cuckoo",
+    "33": "Gray_crowned_Rosy_Finch",
+    "34": "Purple_Finch",
+    "35": "Northern_Flicker",
+    "36": "Acadian_Flycatcher",
+    "37": "Great_Crested_Flycatcher",
+    "38": "Least_Flycatcher",
+    "39": "Olive_sided_Flycatcher",
+    "40": "Scissor_tailed_Flycatcher",
+    "41": "Vermilion_Flycatcher",
+    "42": "Yellow_bellied_Flycatcher",
+    "43": "Frigatebird",
+    "44": "Northern_Fulmar",
+    "45": "Gadwall",
+    "46": "American_Goldfinch",
+    "47": "European_Goldfinch",
+    "48": "Boat_tailed_Grackle",
+    "49": "Eared_Grebe",
+    "50": "Horned_Grebe",
+    "51": "Pied_billed_Grebe",
+    "52": "Western_Grebe",
+    "53": "Blue_Grosbeak",
+    "54": "Evening_Grosbeak",
+    "55": "Pine_Grosbeak",
+    "56": "Rose_breasted_Grosbeak",
+    "57": "Pigeon_Guillemot",
+    "58": "California_Gull",
+    "59": "Glaucous_winged_Gull",
+    "60": "Heermann_Gull",
+    "61": "Herring_Gull",
+    "62": "Ivory_Gull",
+    "63": "Ring_billed_Gull",
+    "64": "Slaty_backed_Gull",
+    "65": "Western_Gull",
+    "66": "Anna_Hummingbird",
+    "67": "Ruby_throated_Hummingbird",
+    "68": "Rufous_Hummingbird",
+    "69": "Green_Violetear",
+    "70": "Long_tailed_Jaeger",
+    "71": "Pomarine_Jaeger",
+    "72": "Blue_Jay",
+    "73": "Florida_Jay",
+    "74": "Green_Jay",
+    "75": "Dark_eyed_Junco",
+    "76": "Tropical_Kingbird",
+    "77": "Gray_Kingbird",
+    "78": "Belted_Kingfisher",
+    "79": "Green_Kingfisher",
+    "80": "Pied_Kingfisher",
+    "81": "Ringed_Kingfisher",
+    "82": "White_breasted_Kingfisher",
+    "83": "Red_legged_Kittiwake",
+    "84": "Horned_Lark",
+    "85": "Pacific_Loon",
+    "86": "Mallard",
+    "87": "Western_Meadowlark",
+    "88": "Hooded_Merganser",
+    "89": "Red_breasted_Merganser",
+    "90": "Mockingbird",
+    "91": "Nighthawk",
+    "92": "Clark_Nutcracker",
+    "93": "White_breasted_Nuthatch",
+    "94": "Baltimore_Oriole",
+    "95": "Hooded_Oriole",
+    "96": "Orchard_Oriole",
+    "97": "Scott_Oriole",
+    "98": "Ovenbird",
+    "99": "Brown_Pelican",
+    "100": "White_Pelican",
+    "101": "Western_Wood_Pewee",
+    "102": "Sayornis",
+    "103": "American_Pipit",
+    "104": "Whip_poor_Will",
+    "105": "Horned_Puffin",
+    "106": "Common_Raven",
+    "107": "White_necked_Raven",
+    "108": "American_Redstart",
+    "109": "Geococcyx",
+    "110": "Loggerhead_Shrike",
+    "111": "Great_Grey_Shrike",
+    "112": "Baird_Sparrow",
+    "113": "Black_throated_Sparrow",
+    "114": "Brewer_Sparrow",
+    "115": "Chipping_Sparrow",
+    "116": "Clay_colored_Sparrow",
+    "117": "House_Sparrow",
+    "118": "Field_Sparrow",
+    "119": "Fox_Sparrow",
+    "120": "Grasshopper_Sparrow",
+    "121": "Harris_Sparrow",
+    "122": "Henslow_Sparrow",
+    "123": "Le_Conte_Sparrow",
+    "124": "Lincoln_Sparrow",
+    "125": "Nelson_Sharp_tailed_Sparrow",
+    "126": "Savannah_Sparrow",
+    "127": "Seaside_Sparrow",
+    "128": "Song_Sparrow",
+    "129": "Tree_Sparrow",
+    "130": "Vesper_Sparrow",
+    "131": "White_crowned_Sparrow",
+    "132": "White_throated_Sparrow",
+    "133": "Cape_Glossy_Starling",
+    "134": "Bank_Swallow",
+    "135": "Barn_Swallow",
+    "136": "Cliff_Swallow",
+    "137": "Tree_Swallow",
+    "138": "Scarlet_Tanager",
+    "139": "Summer_Tanager",
+    "140": "Artic_Tern",
+    "141": "Black_Tern",
+    "142": "Caspian_Tern",
+    "143": "Common_Tern",
+    "144": "Elegant_Tern",
+    "145": "Forsters_Tern",
+    "146": "Least_Tern",
+    "147": "Green_tailed_Towhee",
+    "148": "Brown_Thrasher",
+    "149": "Sage_Thrasher",
+    "150": "Black_capped_Vireo",
+    "151": "Blue_headed_Vireo",
+    "152": "Philadelphia_Vireo",
+    "153": "Red_eyed_Vireo",
+    "154": "Warbling_Vireo",
+    "155": "White_eyed_Vireo",
+    "156": "Yellow_throated_Vireo",
+    "157": "Bay_breasted_Warbler",
+    "158": "Black_and_white_Warbler",
+    "159": "Black_throated_Blue_Warbler",
+    "160": "Blue_winged_Warbler",
+    "161": "Canada_Warbler",
+    "162": "Cape_May_Warbler",
+    "163": "Cerulean_Warbler",
+    "164": "Chestnut_sided_Warbler",
+    "165": "Golden_winged_Warbler",
+    "166": "Hooded_Warbler",
+    "167": "Kentucky_Warbler",
+    "168": "Magnolia_Warbler",
+    "169": "Mourning_Warbler",
+    "170": "Myrtle_Warbler",
+    "171": "Nashville_Warbler",
+    "172": "Orange_crowned_Warbler",
+    "173": "Palm_Warbler",
+    "174": "Pine_Warbler",
+    "175": "Prairie_Warbler",
+    "176": "Prothonotary_Warbler",
+    "177": "Swainson_Warbler",
+    "178": "Tennessee_Warbler",
+    "179": "Wilson_Warbler",
+    "180": "Worm_eating_Warbler",
+    "181": "Yellow_Warbler",
+    "182": "Northern_Waterthrush",
+    "183": "Louisiana_Waterthrush",
+    "184": "Bohemian_Waxwing",
+    "185": "Cedar_Waxwing",
+    "186": "American_Three_toed_Woodpecker",
+    "187": "Pileated_Woodpecker",
+    "188": "Red_bellied_Woodpecker",
+    "189": "Red_cockaded_Woodpecker",
+    "190": "Red_headed_Woodpecker",
+    "191": "Downy_Woodpecker",
+    "192": "Bewick_Wren",
+    "193": "Cactus_Wren",
+    "194": "Carolina_Wren",
+    "195": "House_Wren",
+    "196": "Marsh_Wren",
+    "197": "Rock_Wren",
+    "198": "Winter_Wren",
+    "199": "Common_Yellowthroat",
+}
+
+classnames = [classname_mapping[str(i)] for i in range(200)]
+templates = [
+    lambda c: f"a photo of a {c}.",
+    lambda c: f"a photo of the {c}.",
+]

fusion_bench/tasks/clip_classification/emnist_letters.py (new file)
@@ -0,0 +1,31 @@
+classnames_mapping = {
+    "0": "A",
+    "1": "B",
+    "2": "C",
+    "3": "D",
+    "4": "E",
+    "5": "F",
+    "6": "G",
+    "7": "H",
+    "8": "I",
+    "9": "J",
+    "10": "K",
+    "11": "L",
+    "12": "M",
+    "13": "N",
+    "14": "O",
+    "15": "P",
+    "16": "Q",
+    "17": "R",
+    "18": "S",
+    "19": "T",
+    "20": "U",
+    "21": "V",
+    "22": "W",
+    "23": "X",
+    "24": "Y",
+    "25": "Z",
+}
+
+classnames = [classnames_mapping[str(i)] for i in range(26)]
+templates = [lambda c: f'a photo of the digit character: "{c}".']

fusion_bench/tasks/clip_classification/emnist_mnist.py (new file)
@@ -0,0 +1,5 @@
+# https://huggingface.co/datasets/tanganke/emnist_mnist
+classnames = [str(i) for i in range(10)]
+templates = [
+    lambda c: f'a photo of the number: "{c}".',
+]

fusion_bench/tasks/clip_classification/fashion_mnist.py (new file)
@@ -0,0 +1,18 @@
+classname_mapping = {
+    "0": "T - shirt / top",
+    "1": "Trouser",
+    "2": "Pullover",
+    "3": "Dress",
+    "4": "Coat",
+    "5": "Sandal",
+    "6": "Shirt",
+    "7": "Sneaker",
+    "8": "Bag",
+    "9": "Ankle boot",
+}
+classnames = [classname_mapping[str(i)] for i in range(10)]
+
+templates = [
+    lambda c: f"a photo of a {c}.",
+    lambda c: f"a photo of the {c}.",
+]

fusion_bench/tasks/clip_classification/fer2013.py (new file)
@@ -0,0 +1,18 @@
+classnames = [
+    "angry",
+    "disgusted",
+    "fearful",
+    "happy",
+    "neutral",
+    "sad",
+    "surprised",
+]
+
+templates = [
+    lambda c: f"a photo of a {c} looking face.",
+    lambda c: f"a photo of a face showing the emotion: {c}.",
+    lambda c: f"a photo of a face looking {c}.",
+    lambda c: f"a face that looks {c}.",
+    lambda c: f"they look {c}.",
+    lambda c: f"look at how {c} they are.",
+]
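
These classname/template modules feed the zero-shot classification heads that the CLIP mixin caches (see the `zeroshot_weights_cache_dir` hunk above). A minimal sketch of the standard recipe for turning such a list into per-class weights; illustrative, not code from this package:

```python
import torch
from transformers import CLIPModel, CLIPProcessor

def build_zeroshot_head(classnames, templates, model: CLIPModel, processor: CLIPProcessor):
    # One weight vector per class: encode every template-filled prompt,
    # L2-normalize, and average over the templates.
    weights = []
    with torch.no_grad():
        for name in classnames:
            prompts = [template(name) for template in templates]
            inputs = processor(text=prompts, return_tensors="pt", padding=True)
            embeddings = model.get_text_features(**inputs)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            weights.append(embeddings.mean(dim=0))
    return torch.stack(weights)  # shape: [num_classes, embed_dim]
```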