fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (190)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  10. fusion_bench/method/gossip/__init__.py +3 -0
  11. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  12. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  13. fusion_bench/method/gossip/entropy_loss.py +25 -0
  14. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  15. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  16. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  17. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  18. fusion_bench/method/gossip/utils.py +74 -0
  19. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  20. fusion_bench/method/opcm/opcm.py +16 -7
  21. fusion_bench/method/pwe_moe/module.py +1 -1
  22. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  23. fusion_bench/method/regmean/regmean.py +25 -17
  24. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  25. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  27. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  28. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  29. fusion_bench/method/we_moe/we_moe.py +14 -15
  30. fusion_bench/mixins/__init__.py +6 -3
  31. fusion_bench/mixins/hydra_config.py +49 -0
  32. fusion_bench/mixins/openclip_classification.py +11 -0
  33. fusion_bench/mixins/simple_profiler.py +4 -2
  34. fusion_bench/modelpool/__init__.py +3 -1
  35. fusion_bench/modelpool/base_pool.py +2 -2
  36. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  37. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  38. fusion_bench/models/open_clip/__init__.py +6 -0
  39. fusion_bench/models/open_clip/modeling.py +176 -0
  40. fusion_bench/models/open_clip/utils.py +311 -0
  41. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  42. fusion_bench/models/parameter_dict.py +54 -13
  43. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  44. fusion_bench/taskpool/__init__.py +5 -3
  45. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  46. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  47. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  48. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  49. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  50. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  51. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  52. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  53. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  54. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  55. fusion_bench/utils/data.py +12 -0
  56. fusion_bench/utils/devices.py +14 -0
  57. fusion_bench/utils/instantiate.py +12 -0
  58. fusion_bench/utils/misc.py +9 -2
  59. fusion_bench/utils/packages.py +14 -0
  60. fusion_bench/utils/parameters.py +1 -1
  61. fusion_bench/utils/tensorboard.py +1 -1
  62. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
  63. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
  64. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  66. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  67. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  68. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  69. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  70. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  71. fusion_bench_config/fabric/auto.yaml +0 -1
  72. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  73. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  74. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  75. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  76. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  77. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  78. fusion_bench_config/llama_full_finetune.yaml +0 -2
  79. fusion_bench_config/llama_model_fusion.yaml +0 -2
  80. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  81. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  82. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  83. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  84. fusion_bench_config/method/adamerging.yaml +2 -2
  85. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  86. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  87. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  88. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  89. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  90. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  91. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  92. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  93. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  94. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  95. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  96. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  97. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  98. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  99. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  100. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  101. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  102. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  103. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  104. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  105. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  106. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  107. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  108. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  109. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  110. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  111. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  112. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  113. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  114. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  115. fusion_bench_config/method/model_recombination.yaml +0 -1
  116. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  117. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  118. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  119. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  120. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  121. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  122. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  123. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  124. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  125. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  126. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  127. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  128. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  129. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  130. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  131. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  132. fusion_bench_config/method/ties_merging.yaml +1 -1
  133. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  134. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  135. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  154. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  155. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  156. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  157. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  158. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  159. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  160. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  161. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  162. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  163. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  164. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  165. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  166. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  167. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  169. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  170. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  171. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  172. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  173. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  174. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  175. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  176. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  177. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  178. fusion_bench_config/nyuv2_config.yaml +0 -2
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  183. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  184. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  185. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  186. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  187. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  188. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  189. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
  190. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/models/open_clip/modeling.py
@@ -0,0 +1,176 @@
+ from typing import Callable, List
+
+ import open_clip
+ import torch
+ from torch import Tensor
+
+ from . import utils
+ from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR
+
+
+ class ImageEncoder(torch.nn.Module):
+     R"""
+     Examples:
+
+     load the image encoder for a given model name
+
+     >>> from fusion_bench.models.open_clip import ImageEncoder
+     >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
+     """
+
+     def __init__(self, model_name: str, keep_lang=False):
+         super().__init__()
+         assert (
+             model_name in MODELS
+         ), f"Invalid model name: {model_name}. Valid models are: {MODELS}"
+
+         if "__pretrained__" in model_name:
+             name, pretrained = model_name.split("__pretrained__")
+         elif "__init__" in model_name:
+             print("Using random initialization.")
+             name, pretrained = model_name.split("__init__")[0], None
+         else:
+             name = model_name
+             pretrained = "openai"
+         (
+             self.model,
+             self.train_preprocess,
+             self.val_preprocess,
+         ) = open_clip.create_model_and_transforms(
+             name, pretrained=pretrained, cache_dir=OPENCLIP_CACHEDIR
+         )
+
+         self.cache_dir = CACHEDIR
+
+         if not keep_lang and hasattr(self.model, "transformer"):
+             delattr(self.model, "transformer")
+
+     def forward(self, images):
+         assert self.model is not None
+         return self.model.encode_image(images)
+
+     def __call__(self, inputs):
+         return self.forward(inputs)
+
+     def save(self, filename):
+         print(f"Saving image encoder to {filename}")
+         utils.torch_save(self, filename)
+
+     @classmethod
+     def load(cls, model_name, filename):
+         print(f"Loading image encoder from {filename}")
+
+         state_dict = torch.load(filename, map_location="cpu")
+
+         model = cls(model_name)
+         model.load_state_dict(state_dict)
+         return model
+
+
+ class ClassificationHead(torch.nn.Linear):
+     def __init__(
+         self,
+         normalize: bool,
+         weights: Tensor,
+         biases: Tensor = None,
+     ):
+         output_size, input_size = weights.shape
+         super().__init__(input_size, output_size)
+         self.normalize = normalize
+         if weights is not None:
+             self.weight = torch.nn.Parameter(weights.clone())
+         if biases is not None:
+             self.bias = torch.nn.Parameter(biases.clone())
+         else:
+             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
+
+     def forward(self, inputs: Tensor):
+         if self.normalize:
+             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
+         return super().forward(inputs)
+
+     def __call__(self, inputs: Tensor):
+         return self.forward(inputs)
+
+     def save(self, filename):
+         print(f"Saving classification head to {filename}")
+         utils.torch_save(self, filename, save_state_dict=False)
+
+     @classmethod
+     def load(cls, filename):
+         # print(f"Loading classification head from {filename}")
+         return utils.torch_load(filename)
+
+
+ class ImageClassifier(torch.nn.Module):
+     train_preprocess: Callable
+     val_preprocess: Callable
+
+     def __init__(
+         self,
+         image_encoder: ImageEncoder,
+         classification_head: ClassificationHead,
+     ):
+         super().__init__()
+         self.image_encoder = image_encoder
+         self.classification_head = classification_head
+         if self.image_encoder is not None:
+             self.train_preprocess = self.image_encoder.train_preprocess
+             self.val_preprocess = self.image_encoder.val_preprocess
+
+     def freeze_head(self):
+         self.classification_head.weight.requires_grad_(False)
+         self.classification_head.bias.requires_grad_(False)
+
+     def forward(self, inputs: Tensor):
+         features = self.image_encoder(inputs)
+         outputs = self.classification_head(features)
+         return outputs
+
+     def __call__(self, inputs):
+         return self.forward(inputs)
+
+     def save(self, filename):
+         print(f"Saving image classifier to {filename}")
+         utils.torch_save(self, filename)
+
+     @classmethod
+     def load(cls, filename):
+         print(f"Loading image classifier from {filename}")
+         return utils.torch_load(filename)
+
+
+ class MultiHeadImageClassifier(torch.nn.Module):
+     def __init__(
+         self,
+         image_encoder: ImageEncoder,
+         classification_heads: List[ClassificationHead],
+     ):
+         super().__init__()
+         self.image_encoder = image_encoder
+         self.classification_heads = torch.nn.ModuleList(classification_heads)
+         if self.image_encoder is not None:
+             self.train_preprocess = self.image_encoder.train_preprocess
+             self.val_preprocess = self.image_encoder.val_preprocess
+
+     def freeze_head(self):
+         for idx in range(len(self.classification_heads)):
+             self.classification_heads[idx].weight.requires_grad_(False)
+             self.classification_heads[idx].bias.requires_grad_(False)
+
+     def forward(self, inputs, head_idx):
+         features = self.image_encoder(inputs)
+         outputs = self.classification_heads[head_idx](features)
+         return outputs
+
+     def __call__(self, inputs, head_idx):
+         return self.forward(inputs, head_idx)
+
+     def save(self, filename):
+         print(f"Saving image classifier to {filename}")
+         utils.torch_save(self, filename)
+
+     @classmethod
+     def load(cls, filename):
+         print(f"Loading image classifier from {filename}")
+         return utils.torch_load(filename)
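Note: taken together, ImageEncoder wraps the open_clip visual tower while ClassificationHead and ImageClassifier attach a linear probe on top of its image features. A minimal usage sketch based only on the classes above; the 10-class zero-initialized head and the random batch are placeholders, and it assumes the package's open_clip __init__ re-exports all three classes (as the docstring already does for ImageEncoder):

import torch

from fusion_bench.models.open_clip import (
    ClassificationHead,
    ImageClassifier,
    ImageEncoder,
)

encoder = ImageEncoder(model_name="ViT-B-32")  # loads the "openai" pretrained checkpoint by default

# Placeholder head: 10 classes on top of the 512-dim ViT-B-32 image embedding.
head = ClassificationHead(normalize=True, weights=torch.zeros(10, 512))

classifier = ImageClassifier(encoder, head)
classifier.freeze_head()  # keep the head fixed, e.g. while fine-tuning the encoder

images = torch.randn(2, 3, 224, 224)  # dummy batch; normally produced by val_preprocess
logits = classifier(images)  # shape: (2, 10)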
fusion_bench/models/open_clip/utils.py
@@ -0,0 +1,311 @@
+ import copy
+ import os
+ import pickle
+ from collections import OrderedDict
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+
+ def compute_l1_norm(
+     model1: nn.Module, model2: nn.Module
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
+     """
+     Computes the L1 norm between the parameters of two models.
+
+     Args:
+         model1 (nn.Module): The first model.
+         model2 (nn.Module): The second model.
+
+     Returns:
+         Tuple[torch.Tensor, Dict[str, float]]: A tuple containing the total L1 norm and a dictionary
+             with the L1 norm for each layer.
+
+     """
+     norms = dict()
+     l1_norm = 0.0
+     for (n, p1), p2 in zip(model1.named_parameters(), model2.parameters()):
+         layer_l1_norm = torch.norm(p1 - p2, 1)
+         l1_norm += layer_l1_norm
+         norms[n] = layer_l1_norm.item()
+
+     return l1_norm, norms
+
+
+ def assign_learning_rate(param_group, new_lr):
+     param_group["lr"] = new_lr
+
+
+ def _warmup_lr(base_lr, warmup_length, step):
+     return base_lr * (step + 1) / warmup_length
+
+
+ def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+     if not isinstance(base_lrs, list):
+         base_lrs = [base_lrs for _ in optimizer.param_groups]
+     assert len(base_lrs) == len(optimizer.param_groups)
+
+     def _lr_adjuster(step):
+         for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
+             if step < warmup_length:
+                 lr = _warmup_lr(base_lr, warmup_length, step)
+             else:
+                 e = step - warmup_length
+                 es = steps - warmup_length
+                 lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+             assign_learning_rate(param_group, lr)
+
+     return _lr_adjuster
+
+
+ def accuracy(output: torch.Tensor, target: torch.Tensor, topk: List[int] = (1,)):
+     pred = output.topk(max(topk), 1, True, True)[1].t()
+     correct = pred.eq(target.view(1, -1).expand_as(pred))
+     return [
+         float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
+         for k in topk
+     ]
+
+
+ def torch_load_old(save_path: str, device=None):
+     with open(save_path, "rb") as f:
+         classifier = pickle.load(f)
+     if device is not None:
+         classifier = classifier.to(device)
+     return classifier
+
+
+ def torch_save(model, save_path, save_state_dict=True):
+     # TODO: hacky way to save state dict
+     if save_state_dict and isinstance(model, torch.nn.Module):
+         model = model.state_dict()
+     if os.path.dirname(save_path) != "":
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     torch.save(model, save_path)
+
+
+ def torch_load(save_path, device=None):
+     model = torch.load(save_path, map_location="cpu")
+     if device is not None:
+         model = model.to(device)
+     return model
+
+
+ def get_logits(inputs, classifier):
+     assert callable(classifier)
+     if hasattr(classifier, "to"):
+         classifier = classifier.to(inputs.device)
+     return classifier(inputs)
+
+
+ def get_probs(inputs, classifier):
+     if hasattr(classifier, "predict_proba"):
+         probs = classifier.predict_proba(inputs.detach().cpu().numpy())
+         return torch.from_numpy(probs)
+     logits = get_logits(inputs, classifier)
+     return logits.softmax(dim=1)
+
+
+ class LabelSmoothing(torch.nn.Module):
+     def __init__(self, smoothing=0.0):
+         super(LabelSmoothing, self).__init__()
+         self.confidence = 1.0 - smoothing
+         self.smoothing = smoothing
+
+     def forward(self, x, target):
+         logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+
+         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+         nll_loss = nll_loss.squeeze(1)
+         smooth_loss = -logprobs.mean(dim=-1)
+         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+         return loss.mean()
+
+
+ class DotDict(dict):
+     """dot.notation access to dictionary attributes"""
+
+     __getattr__ = dict.get
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
+
+
+ def find_optimal_coef(
+     results: Dict[str, Any],
+     metric: str = "avg_normalized_top1",
+     minimize: bool = False,
+     control_metric: Optional[str] = None,
+     control_metric_threshold: float = 0.0,
+ ) -> float:
+     """
+     Finds the optimal coefficient based on the given results and metric.
+
+     Args:
+         results (Dict[str, Any]): A dictionary containing the results for different scaling coefficients.
+         metric (str, optional): The metric to optimize. Defaults to "avg_normalized_top1".
+         minimize (bool, optional): Whether to minimize the metric. Defaults to False.
+         control_metric (str, optional): The control metric to check against. Defaults to None.
+         control_metric_threshold (float, optional): The threshold value for the control metric. Defaults to 0.0.
+
+     Returns:
+         The optimal coefficient based on the given results and metric.
+     """
+     best_coef = None
+     if minimize:
+         best_metric = 1
+     else:
+         best_metric = 0
+     for scaling_coef in results.keys():
+         if control_metric is not None:
+             if results[scaling_coef][control_metric] < control_metric_threshold:
+                 print(f"Control metric fell below {control_metric_threshold} threshold")
+                 continue
+         if minimize:
+             if results[scaling_coef][metric] < best_metric:
+                 best_metric = results[scaling_coef][metric]
+                 best_coef = scaling_coef
+         else:
+             if results[scaling_coef][metric] > best_metric:
+                 best_metric = results[scaling_coef][metric]
+                 best_coef = scaling_coef
+     return best_coef
+
+
+ def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
+     """Computes the normalized non-linear advantage of a finetuned model.
+
+     The nonlinear_advantage is defined as:
+         error_rate(linear_model) - error_rate(nonlinear_model) / (1 - 1 / num_classes)
+     and takes values between [-1, 1]. A value of 0 indicates that the nonlinear
+     model is no better than the linear one. Meanwhile, a value of 1 indicates
+     that the nonlinear model is perfect and the linear trivial, and a value of
+     -1 indicates the opposite.
+     """
+     return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes)
+
+
+ def to_cuda(input_dict):
+     cuda_dict = {}
+     for key, value in input_dict.items():
+         cuda_dict[key] = value.to("cuda")
+     return cuda_dict
+
+
+ def state_dict_to_vector(state_dict, remove_keys=[]):
+     shared_state_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in shared_state_dict:
+             del shared_state_dict[key]
+     sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+     return torch.nn.utils.parameters_to_vector(
+         [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+     )
+
+
+ def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+     # create a reference dict to define the order of the vector
+     reference_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in reference_dict:
+             del reference_dict[key]
+     sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+
+     # create a shared state dict using the reference dict
+     torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+
+     # add back the encoder and decoder embedding weights.
+     if "transformer.shared.weight" in sorted_reference_dict:
+         for key in remove_keys:
+             sorted_reference_dict[key] = sorted_reference_dict[
+                 "transformer.shared.weight"
+             ]
+     return sorted_reference_dict
+
+
+ def add_ptm_to_tv(tv_dict, ptm_dict):
+     assert set(tv_dict.keys()) == set(
+         ptm_dict.keys()
+     ), "Differing parameter names in models."
+     final_dict = copy.deepcopy(tv_dict)
+     for k, v in ptm_dict.items():
+         final_dict[k] = tv_dict[k] + v
+     return final_dict
+
+
+ def check_parameterNamesMatch(checkpoints):
+     parameter_names = set(checkpoints[0].keys())
+
+     if len(checkpoints) >= 2:
+         # raise ValueError("Number of models is less than 2.")
+         for checkpoint in checkpoints[1:]:
+             current_parameterNames = set(checkpoint.keys())
+             if current_parameterNames != parameter_names:
+                 raise ValueError(
+                     "Differing parameter names in models. "
+                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                 )
+
+
+ def check_state_dicts_equal(state_dict1, state_dict2):
+     if set(state_dict1.keys()) != set(state_dict2.keys()):
+         return False
+
+     for key in state_dict1.keys():
+         if not torch.equal(state_dict1[key], state_dict2[key]):
+             return False
+
+     return True
+
+
+ def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False):
+     if K == 100:
+         # print("Not applying mask")
+         if return_mask:
+             return M, torch.ones_like(M), None
+         else:
+             return M, torch.ones_like(M)
+
+     if K >= 1:
+         K /= 100
+
+     original_shape = M.shape
+     if M.dim() == 1:
+         M = M.unsqueeze(0)
+
+     n, d = M.shape
+     k = int(d * K)
+     k = d - k  # Keep top k elements instead of bottom k elements
+
+     # Find the k-th smallest element by magnitude for each row
+     kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
+     # Create a mask tensor with True for the top k elements in each row
+     mask = M.abs() >= kth_values
+     final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
+
+     if reshape_mask:
+         final_mask = final_mask.reshape(M.shape)
+
+     if return_mask:
+         return M * final_mask, final_mask.float().mean(dim=1), final_mask
+     else:
+         return M * final_mask, final_mask.float().mean(dim=1)
+
+
+ def cleanup_linear(state_dict):
+     # The linear model also has keys for the reference point $\theta_0$ in the state dict with the prefix `params0`.
+     state_dict = {k: v for k, v in state_dict.items() if "params." in k}
+     return state_dict
+
+
+ def get_ptm_linear(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     # rename keys so that they match afterwards
+     state_dict_new = {
+         k.replace("params0", "params"): v
+         for k, v in state_dict.items()
+         if "params0." in k
+     }
+     state_dict_remaining = {k: v for k, v in state_dict.items() if "params." not in k}
+
+     return state_dict_new, state_dict_remaining
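Note: besides the training helpers (cosine_lr, LabelSmoothing, accuracy), this module ships the flat-vector utilities that task-vector style merging relies on. A small sketch of how they compose; the two Linear models and the 70% trim ratio are illustrative only:

import torch
import torch.nn as nn

from fusion_bench.models.open_clip.utils import (
    state_dict_to_vector,
    topk_values_mask,
    vector_to_state_dict,
)

base, finetuned = nn.Linear(4, 2), nn.Linear(4, 2)

# Flatten both checkpoints into vectors; keys are sorted internally so layouts match.
vec_base = state_dict_to_vector(base.state_dict())
vec_ft = state_dict_to_vector(finetuned.state_dict())

# Work on the task vector in flat space: zero out its smallest-magnitude entries.
delta = (vec_ft - vec_base).unsqueeze(0)
trimmed, kept_fraction = topk_values_mask(delta, K=0.7)

# Map the edited vector back into a state dict and load it.
merged = vector_to_state_dict(vec_base + trimmed.squeeze(0), base.state_dict())
base.load_state_dict(merged)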
fusion_bench/models/open_clip/variables_and_paths.py
@@ -0,0 +1,56 @@
+ from pathlib import Path
+ from typing import Literal
+
+ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}"
+ MODELS = ["ViT-B-32", "ViT-B-16", "ViT-L-14"]
+ OPENCLIP_CACHEDIR = Path(Path.home(), "openclip-cachedir", "open_clip").as_posix()
+ CACHEDIR = None
+
+ ALL_DATASETS = [
+     "Cars",
+     "DTD",
+     "EuroSAT",
+     "GTSRB",
+     "MNIST",
+     "RESISC45",
+     "SVHN",
+     "SUN397",
+     "STL10",
+     "OxfordIIITPet",
+     "Flowers102",
+     "CIFAR100",
+     "PCAM",
+     "FER2013",
+     "CIFAR10",
+     "Food101",
+     "FashionMNIST",
+     "RenderedSST2",
+     "EMNIST",
+     "KMNIST",
+ ]
+
+ DATASETS_8 = ALL_DATASETS[:8]
+ DATASETS_14 = ALL_DATASETS[:14]
+ DATASETS_20 = ALL_DATASETS[:20]
+
+
+ def cleanup_dataset_name(dataset_name: str):
+     return dataset_name.replace("Val", "") + "Val"
+
+
+ def get_zeroshot_path(root, dataset, model):
+     return Path(
+         root, model, cleanup_dataset_name(dataset), f"nonlinear_zeroshot.pt"
+     ).as_posix()
+
+
+ def get_finetuned_path(root, dataset, model):
+     return Path(
+         root, model, cleanup_dataset_name(dataset), f"nonlinear_finetuned.pt"
+     ).as_posix()
+
+
+ def get_single_task_accuracies_path(model):
+     return Path(
+         "results/single_task", model, f"nonlinear_ft_accuracies.json"
+     ).as_posix()
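Note: these constants fix the canonical OpenCLIP model and dataset names, and the path helpers define where checkpoints are looked up on disk. A quick illustration of the naming convention (the "checkpoints" root is hypothetical):

from fusion_bench.models.open_clip.variables_and_paths import (
    DATASETS_8,
    get_finetuned_path,
    get_zeroshot_path,
)

print(DATASETS_8)
# ['Cars', 'DTD', 'EuroSAT', 'GTSRB', 'MNIST', 'RESISC45', 'SVHN', 'SUN397']

# Dataset names are normalized to their "*Val" variants when building paths.
print(get_zeroshot_path("checkpoints", "Cars", "ViT-B-32"))
# checkpoints/ViT-B-32/CarsVal/nonlinear_zeroshot.pt
print(get_finetuned_path("checkpoints", "DTDVal", "ViT-B-32"))
# checkpoints/ViT-B-32/DTDVal/nonlinear_finetuned.pt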
fusion_bench/models/parameter_dict.py
@@ -1,4 +1,4 @@
- from typing import List, Mapping
+ from typing import List, Mapping, Optional, Tuple

  import torch
  from torch import nn
@@ -6,7 +6,13 @@ from torch import nn
  __all__ = "ParamterDictModel"


- def set_attr(obj, names: List[str], val, check_parent: bool = False):
+ def _set_attr(
+     obj,
+     names: List[str],
+     val,
+     check_parent: bool = False,
+     parent_builder=nn.Module,
+ ):
      """
      Sets an attribute of an object recursively.

@@ -20,8 +26,14 @@ def set_attr(obj, names: List[str], val, check_parent: bool = False):
          setattr(obj, names[0], val)
      else:
          if check_parent and not hasattr(obj, names[0]):
-             setattr(obj, names[0], nn.Module())
-         set_attr(getattr(obj, names[0]), names[1:], val, check_parent=check_parent)
+             setattr(obj, names[0], parent_builder())
+         _set_attr(
+             getattr(obj, names[0]),
+             names[1:],
+             val,
+             check_parent=check_parent,
+             parent_builder=parent_builder,
+         )


  def has_attr(obj, names: List[str]):
@@ -49,17 +61,19 @@ class ParameterDictModel(nn.Module):

      def __init__(
          self,
-         parameters: Mapping[str, nn.Parameter],
+         parameters: Optional[Mapping[str, nn.Parameter]] = None,
      ):
          super().__init__()
-         for name, param in parameters.items():
-             assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
-             set_attr(
-                 self,
-                 name.split("."),
-                 param,
-                 check_parent=True,
-             )
+         if parameters is not None:
+             for name, param in parameters.items():
+                 assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
+                 _set_attr(
+                     self,
+                     name.split("."),
+                     param,
+                     check_parent=True,
+                     parent_builder=self.__class__,
+                 )

      def __repr__(self):
          """
@@ -73,3 +87,30 @@
              param_repr = f"{name}: {param.size()}"
              param_reprs.append(param_repr)
          return f"{self.__class__.__name__}({', '.join(param_reprs)})"
+
+     def __getitem__(self, key: str):
+         if not has_attr(self, key.split(".")):
+             raise KeyError(f"Key {key} not found in {self}")
+         key = key.split(".")
+         obj = self
+         for k in key:
+             obj = getattr(obj, k)
+         return obj
+
+     def __setitem__(self, key: str, value: nn.Parameter):
+         if not has_attr(self, key.split(".")):
+             _set_attr(self, key.split("."), value, check_parent=True)
+         else:
+             _set_attr(self, key.split("."), value, check_parent=False)
+
+     def __contains__(self, key: str):
+         return has_attr(self, key.split("."))
+
+     def keys(self):
+         return [name for name, _ in self.named_parameters()]
+
+     def items(self) -> List[Tuple[str, nn.Parameter]]:
+         return [(name, self[name]) for name in self.keys()]
+
+     def values(self) -> List[nn.Parameter]:
+         return [self[name] for name in self.keys()]
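Note: with the now-optional constructor argument and the added __getitem__/__setitem__/__contains__/keys/items/values methods, ParameterDictModel can be built empty and treated as a flat mapping over dotted parameter names. A minimal sketch of the new interface (the parameter names and shapes are arbitrary):

import torch
from torch import nn

from fusion_bench.models.parameter_dict import ParameterDictModel

params = ParameterDictModel()  # no longer requires an initial mapping
params["encoder.layer0.weight"] = nn.Parameter(torch.zeros(3, 3))
params["encoder.layer0.bias"] = nn.Parameter(torch.zeros(3))

print("encoder.layer0.weight" in params)  # True
print(params.keys())  # ['encoder.layer0.weight', 'encoder.layer0.bias']
for name, param in params.items():
    print(name, tuple(param.shape))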
fusion_bench/scripts/nyuv2_mtl_train.py
@@ -1,5 +1,5 @@
  R"""
- This script is used to train a multi-task learning (MTL) model on the NYUv2 dataset. 
+ This script is used to train a multi-task learning (MTL) model on the NYUv2 dataset.
  """

  import importlib
fusion_bench/taskpool/__init__.py
@@ -10,12 +10,13 @@ _import_structure = {
      "clip_vision": [
          "CLIPVisionModelTaskPool",
          "SparseWEMoECLIPVisionModelTaskPool",
-         "RankoneWEMoECLIPVisionModelTaskPool",
+         "RankoneMoECLIPVisionModelTaskPool",
      ],
      "dummy": ["DummyTaskPool"],
      "gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
-     "nyuv2_taskpool": ["NYUv2TaskPool"],
      "llama": ["LlamaTestGenerationTaskPool"],
+     "nyuv2_taskpool": ["NYUv2TaskPool"],
+     "openclip_vision": ["OpenCLIPVisionModelTaskPool"],
  }


@@ -23,13 +24,14 @@ if TYPE_CHECKING:
      from .base_pool import BaseTaskPool
      from .clip_vision import (
          CLIPVisionModelTaskPool,
-         RankoneWEMoECLIPVisionModelTaskPool,
+         RankoneMoECLIPVisionModelTaskPool,
          SparseWEMoECLIPVisionModelTaskPool,
      )
      from .dummy import DummyTaskPool
      from .gpt2_text_classification import GPT2TextClassificationTaskPool
      from .llama import LlamaTestGenerationTaskPool
      from .nyuv2_taskpool import NYUv2TaskPool
+     from .openclip_vision import OpenCLIPVisionModelTaskPool

  else:
      sys.modules[__name__] = LazyImporter(
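Note: registering "openclip_vision" in both the lazy-import table and the TYPE_CHECKING block exposes the new task pool at the package root, so it can be imported without eagerly loading every submodule (its constructor arguments are not shown in this diff):

from fusion_bench.taskpool import OpenCLIPVisionModelTaskPool  # resolved lazily on first access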
fusion_bench/taskpool/clip_vision/__init__.py
@@ -1,4 +1,5 @@
  # flake8: noqa F401
  from .clip_rankone_moe_taskpool import RankoneMoECLIPVisionModelTaskPool
+ from .clip_smile_taskpool import SmileCLIPVisionModelTaskPool
  from .clip_sparse_wemoe_taskpool import SparseWEMoECLIPVisionModelTaskPool
  from .taskpool import CLIPVisionModelTaskPool