fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. fusion_bench/compat/method/__init__.py +3 -1
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/gsm8k.py +2 -2
  6. fusion_bench/method/__init__.py +12 -2
  7. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/doge_ta/__init__.py +2 -0
  10. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  11. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  12. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  13. fusion_bench/method/gossip/__init__.py +3 -0
  14. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  15. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  16. fusion_bench/method/gossip/entropy_loss.py +25 -0
  17. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  18. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  19. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  20. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  21. fusion_bench/method/gossip/utils.py +74 -0
  22. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  23. fusion_bench/method/opcm/opcm.py +102 -84
  24. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  25. fusion_bench/method/opcm/ties_merging.py +71 -52
  26. fusion_bench/method/pwe_moe/module.py +1 -1
  27. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  28. fusion_bench/method/regmean/regmean.py +25 -17
  29. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  30. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  31. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  32. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  33. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  34. fusion_bench/method/we_moe/we_moe.py +14 -15
  35. fusion_bench/mixins/__init__.py +6 -3
  36. fusion_bench/mixins/hydra_config.py +49 -0
  37. fusion_bench/mixins/openclip_classification.py +11 -0
  38. fusion_bench/mixins/simple_profiler.py +4 -2
  39. fusion_bench/modelpool/__init__.py +3 -1
  40. fusion_bench/modelpool/base_pool.py +2 -2
  41. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  42. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  43. fusion_bench/models/open_clip/__init__.py +6 -0
  44. fusion_bench/models/open_clip/modeling.py +176 -0
  45. fusion_bench/models/open_clip/utils.py +311 -0
  46. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  47. fusion_bench/models/parameter_dict.py +54 -13
  48. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  49. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  50. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  51. fusion_bench/taskpool/__init__.py +5 -3
  52. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  53. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  54. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  55. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  56. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  57. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  58. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  59. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  60. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  61. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  62. fusion_bench/utils/data.py +12 -0
  63. fusion_bench/utils/devices.py +14 -0
  64. fusion_bench/utils/instantiate.py +12 -0
  65. fusion_bench/utils/misc.py +9 -2
  66. fusion_bench/utils/packages.py +14 -0
  67. fusion_bench/utils/parameters.py +1 -1
  68. fusion_bench/utils/tensorboard.py +1 -1
  69. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
  70. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
  71. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  72. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  73. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  74. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  75. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  76. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  77. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  78. fusion_bench_config/fabric/auto.yaml +0 -1
  79. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  80. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  81. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  84. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  85. fusion_bench_config/llama_full_finetune.yaml +0 -2
  86. fusion_bench_config/llama_model_fusion.yaml +0 -2
  87. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  88. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  89. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  90. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  91. fusion_bench_config/method/adamerging.yaml +2 -2
  92. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  93. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  94. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  95. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  96. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  97. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  98. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  99. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  100. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  101. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  102. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  103. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  104. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  105. fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
  106. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  107. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  108. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  109. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  110. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  111. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  112. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  113. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  114. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  115. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  116. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  117. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  118. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  119. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  120. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  121. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  122. fusion_bench_config/method/model_recombination.yaml +0 -1
  123. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  124. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  125. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  126. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  127. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  128. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  129. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  130. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  131. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  132. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  133. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  134. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  135. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  136. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  137. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  139. fusion_bench_config/method/ties_merging.yaml +1 -1
  140. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  141. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  150. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  151. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  152. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  161. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  162. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  163. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  164. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  165. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  166. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  167. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  169. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  170. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  171. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  172. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  173. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  174. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  175. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  176. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  177. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  178. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
  179. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
  180. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  181. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  182. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  183. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  184. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  185. fusion_bench_config/nyuv2_config.yaml +0 -2
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  190. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  191. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  192. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  193. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  194. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  195. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  196. fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  197. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  198. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
  199. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ fusion_bench/modelpool/openclip_vision/modelpool.py
@@ -0,0 +1,255 @@
+import logging
+import pickle
+import sys
+from typing import Callable, Optional, Union, cast
+
+import torch
+from datasets import load_dataset
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.expr import is_expr_match
+from fusion_bench.utils.packages import _get_package_version, compare_versions
+
+log = logging.getLogger(__name__)
+
+# Add flag to track if warning has been shown
+_openclip_version_warning_shown = False
+
+
+def _check_and_redirect_open_clip_modeling():
+    global _openclip_version_warning_shown
+    if compare_versions(_get_package_version("open-clip-torch").__str__(), "2.0.2") > 0:
+        if not _openclip_version_warning_shown:
+            log.warning(
+                "OpenCLIP version is greater than 2.0.2. This may cause issues with the modelpool."
+            )
+            _openclip_version_warning_shown = True
+        import open_clip.model
+        import open_clip.transformer
+
+        if not hasattr(open_clip.model, "VisualTransformer"):
+            open_clip.model.VisualTransformer = open_clip.model.VisionTransformer
+        if not hasattr(open_clip.model, "Transformer"):
+            open_clip.model.Transformer = open_clip.transformer.Transformer
+        if not hasattr(open_clip.model, "ResidualAttentionBlock"):
+            open_clip.model.ResidualAttentionBlock = (
+                open_clip.transformer.ResidualAttentionBlock
+            )
+
+    try:
+        import src
+        import src.modeling
+    except ImportError:
+        if "src" not in sys.modules:
+            # redirect the import of `src` to `fusion_bench.models.open_clip`
+            import fusion_bench.models.open_clip as open_clip
+
+            sys.modules["src"] = open_clip
+            log.warning(
+                "`src` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip`"
+            )
+        if "src.modeling" not in sys.modules:
+            # redirect the import of `src.modeling` to `fusion_bench.models.open_clip.modeling`
+            import fusion_bench.models.open_clip.modeling as open_clip_modeling
+
+            sys.modules["src.modeling"] = open_clip_modeling
+            log.warning(
+                "`src.modeling` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip.modeling`"
+            )
+
+
+def load_classifier_head(model_config: Union[str, DictConfig], *args, **kwargs):
+    if isinstance(model_config, str):
+        _check_and_redirect_open_clip_modeling()
+        log.info(f"Loading `ClassificationHead` from {model_config}")
+        weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+        head = torch.load(model_config, weights_only=weights_only, *args, **kwargs)
+    elif isinstance(model_config, nn.Module):
+        log.info(f"Returning existing model: {model_config}")
+        head = model_config
+    else:
+        head = instantiate(model_config, *args, **kwargs)
+    head = cast(ClassificationHead, head)
+    return head
+
+
+class OpenCLIPVisionModelPool(BaseModelPool):
+    """
+    A model pool for managing OpenCLIP Vision models (models from task vector paper).
+    """
+
+    _train_processor = None
+    _test_processor = None
+
+    def __init__(
+        self,
+        models: DictConfig,
+        classification_heads: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        self._classification_heads = classification_heads
+
+    @property
+    def train_processor(self):
+        if self._train_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            self._train_processor = encoder.train_preprocess
+            if self._test_processor is None:
+                self._test_processor = encoder.val_preprocess
+        return self._train_processor
+
+    @property
+    def test_processor(self):
+        if self._test_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            if self._train_processor is None:
+                self._train_processor = encoder.train_preprocess
+            self._test_processor = encoder.val_preprocess
+        return self._test_processor
+
+    def load_model(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ImageEncoder:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
+        - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_config = self._models[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+        if isinstance(model_config, DictConfig):
+            model_config = OmegaConf.to_container(model_config, resolve=True)
+
+        if isinstance(model_config, str):
+            # the model config is a string, which is the path to the model checkpoint in pickle format
+            # load the model using `torch.load`
+            # this is the original usage in the task arithmetic codebase
+            _check_and_redirect_open_clip_modeling()
+            log.info(f"loading ImageEncoder from {model_config}")
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config, weights_only=weights_only, *args, **kwargs
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config, "rb"))
+        elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - pickle_path: str, the path to the binary file (pickle format)
+            # load the model from the binary file (pickle format)
+            # this is useful when you use a newer version of torchvision
+            _check_and_redirect_open_clip_modeling()
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
+            )
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config["pickle_path"],
+                    weights_only=weights_only,
+                    *args,
+                    **kwargs,
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config["pickle_path"], "rb"))
+            _encoder = ImageEncoder(model_config["model_name"])
+            _encoder.load_state_dict(encoder.state_dict())
+            encoder = _encoder
+        elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - state_dict_path: str, the path to the state dict file
+            # load the model from the state dict file
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
+            )
+            encoder = ImageEncoder(model_config["model_name"])
+            encoder.load_state_dict(
+                torch.load(
+                    model_config["state_dict_path"], weights_only=True, *args, **kwargs
+                )
+            )
+        elif isinstance(model_config, nn.Module):
+            # the model config is an existing model
+            log.info(f"Returning existing model: {model_config}")
+            encoder = model_config
+        else:
+            encoder = super().load_model(model_name_or_config, *args, **kwargs)
+        encoder = cast(ImageEncoder, encoder)
+
+        # setup the train and test processors
+        if self._train_processor is None and hasattr(encoder, "train_preprocess"):
+            self._train_processor = encoder.train_preprocess
+        if self._test_processor is None and hasattr(encoder, "val_preprocess"):
+            self._test_processor = encoder.val_preprocess
+
+        return encoder
+
+    def load_classification_head(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ClassificationHead:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._classification_heads
+        ):
+            model_config = self._classification_heads[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+
+        head = load_classifier_head(model_config, *args, **kwargs)
+        return head
+
+    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._train_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="train")
+        else:
+            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._val_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="validation")
+        else:
+            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._test_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="test")
+        else:
+            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
+        return dataset
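The `load_model` docstring above lists the accepted model config forms. The following is a minimal, illustrative sketch (not part of the wheel) of how those forms could be passed to `OpenCLIPVisionModelPool`; the checkpoint paths and entry names are hypothetical placeholders, and only the dictionary keys (`model_name`, `pickle_path`, `state_dict_path`) come from the docstring itself.

from omegaconf import OmegaConf

from fusion_bench.modelpool.openclip_vision.modelpool import OpenCLIPVisionModelPool

# Hypothetical model entries covering the documented config forms.
models = OmegaConf.create(
    {
        # 1. plain string: path to a pickled ImageEncoder, loaded via `torch.load`
        "sun397": "checkpoints/ViT-B-32/sun397/finetuned.pt",
        # 2. rebuild `ImageEncoder(model_name)` and copy weights from a pickled model
        "cars": {
            "model_name": "ViT-B-32",
            "pickle_path": "checkpoints/ViT-B-32/cars/finetuned.pt",
        },
        # 3. rebuild `ImageEncoder(model_name)` and load a plain state dict
        "dtd": {
            "model_name": "ViT-B-32",
            "state_dict_path": "checkpoints/ViT-B-32/dtd/state_dict.pt",
        },
    }
)

pool = OpenCLIPVisionModelPool(models=models)
encoder = pool.load_model("sun397")  # returns an `ImageEncoder`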
--- /dev/null
+++ fusion_bench/models/open_clip/__init__.py
@@ -0,0 +1,6 @@
+"""
+This module contains the support for the open_clip model.
+Modified from https://github.com/nik-dim/tall_masks/
+"""
+
+from .modeling import ClassificationHead, ImageClassifier, ImageEncoder
--- /dev/null
+++ fusion_bench/models/open_clip/modeling.py
@@ -0,0 +1,176 @@
+from typing import Callable, List
+
+import open_clip
+import torch
+from torch import Tensor
+
+from . import utils
+from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR
+
+
+class ImageEncoder(torch.nn.Module):
+    R"""
+    Examples:
+
+    load the image encoder for a given model name
+
+    >>> from fusion_bench.models.open_clip import ImageEncoder
+    >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
+    """
+
+    def __init__(self, model_name: str, keep_lang=False):
+        super().__init__()
+        assert (
+            model_name in MODELS
+        ), f"Invalid model name: {model_name}. Valid models are: {MODELS}"
+
+        if "__pretrained__" in model_name:
+            name, pretrained = model_name.split("__pretrained__")
+        elif "__init__" in model_name:
+            print("Using random initialization.")
+            name, pretrained = model_name.split("__init__")[0], None
+        else:
+            name = model_name
+            pretrained = "openai"
+        (
+            self.model,
+            self.train_preprocess,
+            self.val_preprocess,
+        ) = open_clip.create_model_and_transforms(
+            name, pretrained=pretrained, cache_dir=OPENCLIP_CACHEDIR
+        )
+
+        self.cache_dir = CACHEDIR
+
+        if not keep_lang and hasattr(self.model, "transformer"):
+            delattr(self.model, "transformer")
+
+    def forward(self, images):
+        assert self.model is not None
+        return self.model.encode_image(images)
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving image encoder to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, model_name, filename):
+        print(f"Loading image encoder from {filename}")
+
+        state_dict = torch.load(filename, map_location="cpu")
+
+        model = cls(model_name)
+        model.load_state_dict(state_dict)
+        return model
+
+
+class ClassificationHead(torch.nn.Linear):
+    def __init__(
+        self,
+        normalize: bool,
+        weights: Tensor,
+        biases: Tensor = None,
+    ):
+        output_size, input_size = weights.shape
+        super().__init__(input_size, output_size)
+        self.normalize = normalize
+        if weights is not None:
+            self.weight = torch.nn.Parameter(weights.clone())
+        if biases is not None:
+            self.bias = torch.nn.Parameter(biases.clone())
+        else:
+            self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
+
+    def forward(self, inputs: Tensor):
+        if self.normalize:
+            inputs = inputs / inputs.norm(dim=-1, keepdim=True)
+        return super().forward(inputs)
+
+    def __call__(self, inputs: Tensor):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving classification head to {filename}")
+        utils.torch_save(self, filename, save_state_dict=False)
+
+    @classmethod
+    def load(cls, filename):
+        # print(f"Loading classification head from {filename}")
+        return utils.torch_load(filename)
+
+
+class ImageClassifier(torch.nn.Module):
+    train_preprocess: Callable
+    val_preprocess: Callable
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoder,
+        classification_head: ClassificationHead,
+    ):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_head = classification_head
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        self.classification_head.weight.requires_grad_(False)
+        self.classification_head.bias.requires_grad_(False)
+
+    def forward(self, inputs: Tensor):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_head(features)
+        return outputs
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+        return utils.torch_load(filename)
+
+
+class MultiHeadImageClassifier(torch.nn.Module):
+    def __init__(
+        self,
+        image_encoder: ImageEncoder,
+        classification_heads: List[ClassificationHead],
+    ):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_heads = torch.nn.ModuleList(classification_heads)
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        for idx in range(len(self.classification_heads)):
+            self.classification_heads[idx].weight.requires_grad_(False)
+            self.classification_heads[idx].bias.requires_grad_(False)
+
+    def forward(self, inputs, head_idx):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_heads[head_idx](features)
+        return outputs
+
+    def __call__(self, inputs, head_idx):
+        return self.forward(inputs, head_idx)
+
+    def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+        return utils.torch_load(filename)
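As a quick orientation to how the classes in this new module fit together, here is a minimal sketch (not part of the wheel): an `ImageEncoder` produces image features and a `ClassificationHead` maps them to logits, with `ImageClassifier` wrapping the pair. The random head weights and dummy batch are placeholders for illustration only; a real head would typically be restored with `ClassificationHead.load(...)`.

import torch

from fusion_bench.models.open_clip import (
    ClassificationHead,
    ImageClassifier,
    ImageEncoder,
)

encoder = ImageEncoder(model_name="ViT-B-32")  # OpenAI-pretrained weights by default
head = ClassificationHead(
    normalize=True,                  # L2-normalize features before the linear projection
    weights=torch.randn(397, 512),   # placeholder (num_classes, feature_dim) weights
)
classifier = ImageClassifier(encoder, head)
classifier.freeze_head()  # keep the classification head parameters fixed

images = torch.randn(2, 3, 224, 224)  # dummy batch; real images go through `classifier.val_preprocess`
logits = classifier(images)           # shape: (2, 397)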