fusion-bench 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. fusion_bench/compat/method/base_algorithm.py +1 -1
  2. fusion_bench/dataset/clip_dataset.py +3 -0
  3. fusion_bench/dataset/fer2013.py +12 -0
  4. fusion_bench/dataset/llama/preference_700k.py +1 -1
  5. fusion_bench/method/__init__.py +2 -0
  6. fusion_bench/method/classification/clip_finetune.py +10 -13
  7. fusion_bench/method/surgery/__init__.py +1 -3
  8. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +1 -1
  9. fusion_bench/method/tall_mask/__init__.py +0 -0
  10. fusion_bench/method/tall_mask/utils.py +234 -0
  11. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  12. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  13. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  14. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  15. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  16. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  17. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  18. fusion_bench/mixins/clip_classification.py +6 -6
  19. fusion_bench/mixins/lightning_fabric.py +3 -1
  20. fusion_bench/modelpool/base_pool.py +0 -1
  21. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  22. fusion_bench/models/surgery/__init__.py +1 -0
  23. fusion_bench/models/surgery/surgerymodelwrapper.py +2 -1
  24. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  25. fusion_bench/models/wrappers/task_wise_fusion.py +1 -1
  26. fusion_bench/programs/fabric_fusion_program.py +7 -4
  27. fusion_bench/taskpool/llama/reward_model.py +1 -1
  28. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  29. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  30. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  31. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  32. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  33. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  34. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  35. fusion_bench/tasks/clip_classification/food101.py +105 -0
  36. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  37. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  38. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  39. fusion_bench/utils/parameters.py +12 -3
  40. fusion_bench/utils/type.py +10 -1
  41. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  42. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +195 -62
  43. fusion_bench_config/dataset/image_classification/README.md +6 -0
  44. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  45. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  47. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  48. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  49. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  50. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  51. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  52. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  53. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  54. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  55. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  56. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  57. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  58. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  59. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  60. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  61. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  62. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  63. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  64. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  65. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  66. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  67. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  68. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  69. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  70. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  71. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  72. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  73. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  74. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  75. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  76. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  77. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  78. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  79. fusion_bench_config/model/clip-vit/README.md +38 -0
  80. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  81. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  82. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  83. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  84. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  85. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  86. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  87. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  88. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  89. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  90. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  91. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  92. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  93. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  94. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  95. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  96. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  97. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  98. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  99. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  100. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  101. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  102. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  103. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  104. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  105. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  106. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  107. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  108. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  109. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  110. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  111. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  112. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  113. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  114. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  115. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  116. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  117. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  118. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  119. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  120. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  121. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  122. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  123. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  124. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  125. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  126. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  127. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  128. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  129. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  130. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  131. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  132. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  133. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  134. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  135. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  150. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  168. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  169. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  170. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  171. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  172. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  173. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  174. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  175. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  176. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  177. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  178. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  183. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  185. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  190. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  191. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  192. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  193. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  194. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  195. {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
fusion_bench/compat/method/base_algorithm.py
@@ -1,5 +1,5 @@
  from abc import ABC, abstractmethod
- from typing import Optional, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Optional

  from omegaconf import DictConfig

fusion_bench/dataset/clip_dataset.py
@@ -65,4 +65,7 @@ class CLIPDataset(torch.utils.data.Dataset):
          else:
              # if processor is None, return the raw image directly
              inputs = image
+         # convert boolean label to int, this is for the case when the label is a binary classification task
+         if isinstance(item["label"], bool):
+             item["label"] = 1 if item["label"] else 0
          return inputs, item["label"]
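
A minimal sketch of what this change does, assuming `CLIPDataset(dataset, processor)` indexes items with `"image"` and `"label"` keys as the diff suggests (the toy data below is illustrative):

```python
from fusion_bench.dataset.clip_dataset import CLIPDataset

# Toy dataset whose "label" column is boolean (e.g. a binary task such as PCAM).
raw = [{"image": None, "label": True}, {"image": None, "label": False}]

# With processor=None the raw image is passed through; boolean labels now map to 1/0.
ds = CLIPDataset(raw, None)
inputs, label = ds[0]
assert label == 1
```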
fusion_bench/dataset/fer2013.py
@@ -0,0 +1,12 @@
+ from datasets import load_dataset
+
+
+ def load_fer2013(path: str = "clip-benchmark/wds_fer2013", split: str = "train"):
+     dataset = load_dataset(path, split=split)
+     dataset = dataset.remove_columns(["__key__", "__url__"])
+     dataset = dataset.rename_columns({"jpg": "image", "cls": "label"})
+     return dataset
+
+ if __name__ == "__main__":
+     dataset = load_fer2013(split="test")
+     print(dataset)
fusion_bench/dataset/llama/preference_700k.py
@@ -1,3 +1,4 @@
+ import logging
  import os
  from copy import deepcopy
  from typing import TYPE_CHECKING, Optional
@@ -7,7 +8,6 @@ from lightning.fabric.utilities import rank_zero_only
  from tqdm.auto import tqdm

  from fusion_bench.utils import timeit_context
- import logging

  if TYPE_CHECKING:
      from transformers import PreTrainedTokenizer
fusion_bench/method/__init__.py
@@ -49,6 +49,7 @@ _import_structure = {
          "PWEMoExactParetoOptimalForCLIP",
      ],
      "ada_svd": ["AdaSVDMergingForCLIPVisionModel"],
+     "task_singular_vector": ["TaskSingularVectorMerging"],
      # plug-and-play model merging methods
      "concrete_subspace": [
          "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -153,6 +154,7 @@ if TYPE_CHECKING:
          SparseLoForLlama,
      )
      from .task_arithmetic import TaskArithmeticAlgorithm
+     from .task_singular_vector import TaskSingularVectorMerging
      from .ties_merging import TiesMergingAlgorithm
      from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
      from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
fusion_bench/method/classification/clip_finetune.py
@@ -41,11 +41,10 @@ from transformers.models.clip.modeling_clip import CLIPVisionTransformer
  from fusion_bench import print_parameters
  from fusion_bench.compat.method import ModelFusionAlgorithm
  from fusion_bench.compat.modelpool import to_modelpool
- from fusion_bench.compat.modelpool.huggingface_clip_vision import (
-     HuggingFaceClipVisionPool,
- )
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
  from fusion_bench.mixins import CLIPClassificationMixin
  from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.modelpool import CLIPVisionModelPool
  from fusion_bench.models.hf_clip import HFCLIPClassifier
  from fusion_bench.models.linearized.linearized_model_utils import LinearizedModelWraper
  from fusion_bench.utils.data import InfiniteDataLoader
@@ -92,12 +91,12 @@ class ImageClassificationFineTuningForCLIP(
      A class for fine-tuning CLIP models for image classification tasks.
      """

-     def run(self, modelpool: HuggingFaceClipVisionPool):
+     def run(self, modelpool: CLIPVisionModelPool):
          """
          Executes the fine-tuning process.

          Args:
-             modelpool (HuggingFaceClipVisionPool): The modelpool is responsible for loading the pre-trained model and training datasets.
+             modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.

          Returns:
              VisionModel: The fine-tuned vision model.
@@ -109,9 +108,7 @@ class ImageClassificationFineTuningForCLIP(

          L.seed_everything(config.seed)

-         task_names = [
-             dataset_config["name"] for dataset_config in modelpool.config.train_datasets
-         ]
+         task_names = modelpool.train_dataset_names
          with self.profile("setup model and optimizer"):
              processor, classifier, optimizer, lr_scheduler = self.setup_model()

@@ -133,7 +130,7 @@ class ImageClassificationFineTuningForCLIP(

          with self.profile("setup data"):
              train_datasets = [
-                 modelpool.get_train_dataset(task_name, processor)
+                 CLIPDataset(modelpool.load_train_dataset(task_name), processor)
                  for task_name in task_names
              ]
              train_dataloaders = [
@@ -157,6 +154,7 @@ class ImageClassificationFineTuningForCLIP(
              range(config.num_steps),
              desc=self.finetune_method,
              disable=not self.fabric.is_global_zero,
+             dynamic_ncols=True,
          ):
              optimizer.zero_grad()
              loss = 0
@@ -183,7 +181,7 @@ class ImageClassificationFineTuningForCLIP(
                  save_path = os.path.join(
                      self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
                  )
-                 self.save_model(classifier, save_path, trainable_only=True)
+                 self.save_model(classifier, save_path)

          if config.state_dict_save_path is not None:
              self.save_model(
@@ -232,9 +230,8 @@ class ImageClassificationFineTuningForCLIP(
          config = self.config
          modelpool = self.modelpool

-         pretrained_model_config = modelpool.get_model_config("_pretrained_")
-         clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_model_config.path)
-         processor = CLIPProcessor.from_pretrained(pretrained_model_config.path)
+         clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
+         processor = modelpool.load_processor()

          self.finetune_method = "full fine-tune"
          if config.use_lora or config.use_l_lora:
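
For reference, the modelpool-centric data setup that replaces the legacy HuggingFaceClipVisionPool helpers looks roughly like this. This is a sketch based only on the calls visible in this diff (`load_processor`, `load_train_dataset`, `train_dataset_names`, `CLIPDataset`); batch size and loader options are illustrative:

```python
from torch.utils.data import DataLoader

from fusion_bench.dataset.clip_dataset import CLIPDataset

# inside run(self, modelpool: CLIPVisionModelPool):
processor = modelpool.load_processor()
train_datasets = [
    CLIPDataset(modelpool.load_train_dataset(name), processor)
    for name in modelpool.train_dataset_names
]
train_dataloaders = [
    DataLoader(dataset, batch_size=32, shuffle=True) for dataset in train_datasets
]
```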
fusion_bench/method/surgery/__init__.py
@@ -1,3 +1 @@
- from .clip_layer_wise_adamerging_surgery import (
-     CLIPLayerWiseAdaMergingSurgeryAlgorithm,
- )
+ from .clip_layer_wise_adamerging_surgery import CLIPLayerWiseAdaMergingSurgeryAlgorithm
fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py
@@ -154,4 +154,4 @@ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
          self._program.evaluate_merged_model(self._program.taskpool, alpha_model)

          log.info("test the result of Adamerging")
-         return merged_model
+         return {"adamerging": merged_model, "surgery": alpha_model}
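
Note that the return type of `run()` changes here: callers that previously consumed the merged model directly now receive a dict. A hedged sketch of the adjustment (`algorithm` and `modelpool` are placeholder names):

```python
# fusion-bench 0.2.7: run() returned the merged model directly
# merged_model = algorithm.run(modelpool)

# fusion-bench 0.2.8: run() returns both the AdaMerging result and the surgery model
result = algorithm.run(modelpool)
merged_model = result["adamerging"]
surgery_model = result["surgery"]
```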
fusion_bench/method/tall_mask/__init__.py
File without changes
fusion_bench/method/tall_mask/utils.py
@@ -0,0 +1,234 @@
+ import copy
+ import os
+ from typing import List, Optional
+
+ import numpy as np
+ import torch
+
+ from fusion_bench.utils import state_dict_to_vector, vector_to_state_dict
+
+
+ def generate_task_masks(
+     tv_flat_checks: torch.Tensor,
+     flat_ft: torch.Tensor,
+     flat_ptm: torch.Tensor,
+     tv: Optional[torch.Tensor] = None,
+     tall_mask_lambda: float = 1.0,
+ ) -> torch.Tensor:
+     """
+     Generate task-specific TALL masks
+     TALL masks are generated as: mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda
+
+     Args:
+         tv_flat_checks: individual task vectors
+         flat_ft: individual theta_t (fine-tuned weights)
+         flat_ptm: theta_0 (pre-trained weight)
+         tv: multi-task vector
+         tall_mask_lambda: hyper-parameter lambda for generating TALL masks
+     Returns:
+         final_mask: generated TALL masks with the given lambda, in shape (n_task, n_parameter)
+     """
+
+     print(f"Generating TALL masks.")
+
+     if tv is None:
+         tv = tv_flat_checks.sum(0)
+
+     flat_multi = flat_ptm + tv
+
+     original_shape = flat_ft.shape
+
+     # generate masks by comparing the l1 distance between |theta_0 - theta_t| and |theta_mt - theta_t|
+     diff_pt_ft = (flat_ptm - flat_ft).abs()
+     diff_multi_ft = (flat_multi - flat_ft).abs()
+     # compare the l1 distance, scaled with hyper-parameter lambda
+     mask = diff_pt_ft > diff_multi_ft * tall_mask_lambda
+
+     final_mask = (
+         mask.squeeze() if original_shape == tv_flat_checks.squeeze().shape else mask
+     )
+
+     print(
+         f"Average sparsity for the mask with tall_mask_lambda of {tall_mask_lambda}: {final_mask.float().mean():.4f}"
+     )
+     return final_mask
+
+
+ def construct_tall_mask(
+     tv_flat_checks: torch.Tensor,
+     flat_ft: torch.Tensor,
+     flat_ptm: torch.Tensor,
+     merged_tv: torch.Tensor,
+     ptm_check: torch.Tensor,
+     remove_keys: List[str],
+     config,
+ ):
+     """
+     Construct TALL masks for all tasks for each lambda, and store in dictionary
+
+     Args:
+         tv_flat_checks: individual task vectors
+         flat_ft: individual theta_t (fine-tuned weights)
+         flat_ptm: theta_0 (pre-trained weight)
+         merged_tv: multi-task vector
+         ptm_check: pre-trained weight as state dictionary
+         remove_keys: the keys to be removed when converting between dictionary and vector
+     Returns:
+         tall_masks: constructed TALL masks in dictionary format of {lambda: {task: mask}}
+     """
+     tall_masks = {}
+     for tall_mask_lambda in [0.2, 0.3, 0.4, 0.5, 0.6]:
+         # generate tall masks for each lambda
+         masks_at_scale = generate_task_masks(
+             tv_flat_checks,
+             flat_ft,
+             flat_ptm,
+             tall_mask_lambda=tall_mask_lambda,
+             tv=merged_tv,
+         )
+         # convert vectors to dictionary
+         masks_at_scale = [
+             vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys)
+             for mask in masks_at_scale
+         ]
+         # store the masks with {dataset: mask}
+         tall_masks[tall_mask_lambda] = {
+             key: value for key, value in zip(config.DATASETS, masks_at_scale)
+         }
+     return tall_masks
+
+
+ def find_optimal_mask(val_metrics, eval_masks, args, save_masks=True):
+     """
+     Respectively finds the optimal mask for each data task based on the validation accuracy
+
+     Args:
+         val_metrics: validation metrics for each lambda
+         eval_masks: all generated masks
+
+     Returns:
+         best_masks_for_test: the best masks for each task, selected based on validation accuracy from each task
+         best_val_metrics: best validation metrics for each task
+     """
+     # transpose the dict from lambda-task to task-lambda
+     transposed_dict = {}
+     for key, inner_dict in val_metrics.items():
+         for inner_key, value in inner_dict.items():
+             if inner_key not in transposed_dict:
+                 transposed_dict[inner_key] = {}
+             transposed_dict[inner_key][key] = value
+
+     # for each task, find the best lambda
+     max_subkeys = {
+         key: max(inner_dict, key=inner_dict.get)
+         for key, inner_dict in transposed_dict.items()
+     }
+
+     # select the best mask for each task, which will be used for testing later
+     best_masks_for_test = {}
+     best_masks_for_test_vector = {}  # the selected masks as vectors
+     best_val_metrics = {}
+     # respectively for each task:
+     for ds in args.DATASETS:
+         # select the lambda which achieves the best validation accuracy
+         best_lambda = float(max_subkeys[ds + "Val:top1"])
+         # select the mask based on the selected lambda, save as dictionaries
+         best_masks_for_test[ds] = eval_masks[best_lambda][ds]
+         # select the mask based on the selected lambda, save as vectors
+         best_masks_for_test_vector[ds] = state_dict_to_vector(
+             eval_masks[best_lambda][ds], remove_keys=[]
+         )
+         print(f"Best lambda for {ds} is {best_lambda}")
+         # save the best validation metric based on the selected lambda
+         best_val_metrics[ds + "Val:top1"] = val_metrics[best_lambda][ds + "Val:top1"]
+
+     # save the best masks to disk
+     if save_masks and not args.method.load_mask:
+         # convert to numpy to save with np.packbits for saving storage
+         best_masks_for_test_vector = {
+             k: np.packbits(v) for k, v in best_masks_for_test_vector.items()
+         }
+         mask_save_dir = args.model_location.replace("checkpoints", "tall_masks")
+         mask_name = (
+             f"TALL_mask_{args.num_tasks}task.npy"
+             if not args.method.use_ties
+             else f"TALL_mask_{args.num_tasks}task_use_ties_{args.method.ties_agg}.npy"
+         )
+         np.save(
+             os.path.join(mask_save_dir, args.model, mask_name),
+             best_masks_for_test_vector,
+         )
+         del best_masks_for_test_vector
+
+     return best_masks_for_test, best_val_metrics
+
+
+ def load_tall_mask(remove_keys, ptm_check, config):
+     """Loads TALL masks from disk, unpack and transform to state dictionaries."""
+     mask_location = config.model_location.replace("checkpoints", "tall_masks")
+     try:
+         if config.method.use_ties:
+             print("==== Loading TALL Masks built with TIES ====")
+             tall_masks = torch.load(
+                 os.path.join(
+                     mask_location,
+                     config.model,
+                     f"TALL_mask_{config.num_tasks}task_use_ties.npy",
+                 )
+             )
+         else:
+             print("==== Loading TALL Masks built with Task Arithmetic ====")
+             tall_masks = torch.load(
+                 os.path.join(
+                     mask_location, config.model, f"TALL_mask_{config.num_tasks}task.npy"
+                 )
+             )
+     except:
+         raise Exception("TALL Masks are not constructed yet.")
+
+     # unpack masks and convert back to torch tensors
+     tall_masks = {k: torch.from_numpy(np.unpackbits(v)) for k, v in tall_masks.items()}
+
+     # convert vectors to dictionaries
+     tall_masks = {
+         dataset: vector_to_state_dict(mask, ptm_check, remove_keys=remove_keys)
+         for dataset, mask in tall_masks.items()
+     }
+
+     return tall_masks
+
+
+ def construct_consensus_mask(ptm_check, prun_thre_k, config, remove_keys=[]):
+     """
+     Generate consensus mask by filtering out least-used parameters
+
+     Args:
+         ptm_check: pretrained_checkpoint as state dictionary
+         prun_thre_k: weight-pruning threshold, stands for the least number of activated tasks for a parameter to be preserved from pruning
+             if prun_thre_k is set to 2: remove both catastrophic and selfish weights;
+             if prun_thre_k is set to 1: remove only catastrophic weights;
+             if prun_thre_k is set to 0: remove no weights -> reduce to TA or TIES
+             if prun_thre_k is set to > num_tasks: remove all weights -> reduce to zero-shot
+     Returns:
+         consensus_mask_vector: constructed consensus mask as vector (boolean in shape (n_parameter, ))
+     """
+
+     print("==== Generating Consensus Mask ====")
+     # load TALL masks (in shape (n_task, n_parameter))
+     tall_masks = load_tall_mask(remove_keys, ptm_check, config)
+     tall_masks = list(tall_masks.values())
+
+     # generate consensus masks
+     consensus_mask = copy.deepcopy(tall_masks[0])
+     for key, value in consensus_mask.items():
+         consensus_mask[key] = torch.zeros_like(value)
+         # count for each parameter, the tasks it has been activated for
+         for mask in tall_masks:
+             consensus_mask[key] = consensus_mask[key] + mask[key].float()
+         # filter out the least-activated parameters based on given threshold
+         consensus_mask[key] = consensus_mask[key].float() >= prun_thre_k
+     consensus_mask_vector = state_dict_to_vector(
+         consensus_mask, remove_keys=remove_keys
+     )
+
+     return consensus_mask_vector
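
The core rule implemented by generate_task_masks is mask_t = |theta_0 - theta_t| > lambda * |theta_mt - theta_t|, evaluated element-wise over the flattened parameters. A minimal sketch with random tensors (shapes and lambda are illustrative):

```python
import torch

from fusion_bench.method.tall_mask.utils import generate_task_masks

n_tasks, n_params = 3, 10_000
flat_ptm = torch.randn(n_params)                        # theta_0, flattened
tv_flat_checks = 0.01 * torch.randn(n_tasks, n_params)  # per-task task vectors
flat_ft = flat_ptm + tv_flat_checks                     # per-task fine-tuned weights

masks = generate_task_masks(
    tv_flat_checks, flat_ft, flat_ptm, tall_mask_lambda=0.4
)
print(masks.shape)           # (n_tasks, n_params), boolean
print(masks.float().mean())  # fraction of parameters kept per task, on average
```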
fusion_bench/method/task_singular_vector/TSVC.py
@@ -0,0 +1,16 @@
+ import torch
+ from torch import Tensor, nn
+
+ from fusion_bench import BaseAlgorithm
+
+ from .utils import TSVC_utils, check_parameterNamesMatch
+
+
+ class TaskSingularVectorCompression(BaseAlgorithm):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def run(self, modelpool):
+         raise NotImplementedError(
+             "Task Singular Vector Compression is not implemented yet."
+         )
fusion_bench/method/task_singular_vector/TSVM.py
@@ -0,0 +1,63 @@
+ """
+ Example:
+
+ ```bash
+ fusion_bench \
+     method=task_singular_vector/TaskSingularVectorMerging \
+     modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only \
+     taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TALL20
+ ```
+ """
+
+ from typing import List, Optional
+
+ import torch
+ from torch import Tensor, nn
+
+ from fusion_bench import BaseAlgorithm
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.utils import timeit_context
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub, state_dict_add
+ from fusion_bench.utils.type import StateDictType
+
+ from .utils import (
+     TSVM_utils,
+     check_parameterNamesMatch,
+     check_state_dicts_equal,
+     state_dict_to_vector,
+     vector_to_state_dict,
+ )
+
+
+ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
+
+     def __init__(
+         self,
+         remove_keys: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         self.remove_keys = remove_keys if remove_keys is not None else []
+         super().__init__(**kwargs)
+
+     def run(self, modelpool):
+         # Load the pre-trained model and the fine-tuned models
+         pretrained_model = modelpool.load_pretrained_model()
+         finetuned_models = list(modelpool.models())
+
+         ptm_check = pretrained_model.state_dict()
+         ft_checks = [model.state_dict() for model in finetuned_models]
+         check_parameterNamesMatch(ft_checks + [ptm_check])
+
+         with timeit_context("Flattening out Checkpoints"):
+             task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
+
+         new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
+             task_vectors,
+             exclude_keys=self.remove_keys,
+             accelerator=self.fabric.device,
+         )
+
+         pretrained_model.load_state_dict(
+             state_dict_add(new_merged_tv, pretrained_model.state_dict())
+         )
+         return pretrained_model
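
TSVM_utils.compute_and_sum_svd_mem_reduction itself lives in TSVM_utils.py (+642 lines, not shown in this section). As a rough conceptual sketch only, not the actual implementation: each 2-D layer of every task vector is compressed to its leading singular components before the per-task contributions are combined, and per the TSV paper the real routine additionally decorrelates the stacked singular vectors across tasks to reduce interference. Illustrative code under those assumptions:

```python
import torch

def low_rank_sum(task_deltas, k):
    """Simplified stand-in (NOT the TSVM_utils implementation): keep the top-k
    singular components of each per-task delta for one layer, then sum the
    low-rank reconstructions."""
    merged = torch.zeros_like(task_deltas[0])
    for delta in task_deltas:
        u, s, v = torch.linalg.svd(delta, full_matrices=False)
        merged += u[:, :k] @ torch.diag(s[:k]) @ v[:k, :]
    return merged

# e.g. three task deltas for a 768x768 layer, keeping rank 768 // 3 per task
deltas = [0.01 * torch.randn(768, 768) for _ in range(3)]
merged_delta = low_rank_sum(deltas, k=768 // 3)
```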
fusion_bench/method/task_singular_vector/__init__.py
@@ -0,0 +1,9 @@
+ """
+ This module is modified from the original code of the paper:
+
+ - Gargiulo et al. Task Singular Vectors: Reducing Task Interference in Model Merging
+ - http://arxiv.org/abs/2412.00081
+ - https://github.com/AntoAndGar/task_singular_vectors/
+ """
+
+ from .TSVM import TaskSingularVectorMerging
fusion_bench/method/task_singular_vector/utils/TSVC_utils.py
@@ -0,0 +1,50 @@
+ import torch
+
+
+ def compute_svd_and_compress(key, matrix, sv_reduction):
+     """
+     Computes the Singular Value Decomposition (SVD) of a given matrix and compresses it by reducing the number of singular values.
+
+     Args:
+         key (Any): An identifier for the matrix.
+         matrix (torch.Tensor): The input matrix to decompose.
+         sv_reduction (float): The fraction of singular values to retain (0 < sv_reduction <= 1).
+
+     Returns:
+         tuple: A tuple containing:
+             - key (Any): The original identifier for the matrix.
+             - u (torch.Tensor): The left singular vectors of the reduced SVD.
+             - s (torch.Tensor): The reduced singular values.
+             - v (torch.Tensor): The right singular vectors of the reduced SVD.
+     """
+     u, s, v = torch.linalg.svd(matrix, full_matrices=False)
+     reduced_index_s = int(s.shape[0] * sv_reduction)
+     return key, u[:, :reduced_index_s], s[:reduced_index_s], v[:reduced_index_s, :]
+
+
+ def compress_tv(task_vectors, sv_reduction):
+     """
+     Compress task vectors using Singular Value Decomposition (SVD).
+
+     Args:
+         task_vectors (dict): A dictionary where keys are dataset names and values are task vectors.
+             Each task vector is expected to have a 'vector' attribute which is a dictionary
+             with keys as layer names and values as layer matrices.
+         sv_reduction (float): The fraction of singular values to keep for compression.
+
+     Returns:
+         dict: A dictionary with the same structure as `task_vectors`, but with each layer matrix
+             replaced by its compressed SVD components (u, s, v) if the layer is 2-dimensional.
+             If the layer is not 2-dimensional, it is stored as is under the key "dim1".
+     """
+     with torch.no_grad():
+         svd_dict = {}
+         for dataset, task_vector in task_vectors.items():
+             svd_dict[dataset] = {}
+             for key, layer in task_vector.vector.items():
+                 if len(layer.shape) == 2:  # and "text_projection" not in key:
+                     _, u, s, v = compute_svd_and_compress(key, layer, sv_reduction)
+                     svd_dict[dataset][key] = {"u": u, "s": s, "v": v}
+                 else:
+                     svd_dict[dataset][key] = {"dim1": layer}
+     return svd_dict
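
A quick usage sketch of compute_svd_and_compress on a random layer delta, showing the shapes of the retained factors and the quality of the low-rank reconstruction (matrix size and sv_reduction are illustrative):

```python
import torch

from fusion_bench.method.task_singular_vector.utils.TSVC_utils import (
    compute_svd_and_compress,
)

w = torch.randn(768, 768)  # a 2-D layer of a task vector
key, u, s, v = compute_svd_and_compress("layer.weight", w, sv_reduction=0.1)
print(u.shape, s.shape, v.shape)  # (768, 76), (76,), (76, 768)

approx = u @ torch.diag(s) @ v    # rank-76 reconstruction
rel_err = torch.linalg.matrix_norm(w - approx) / torch.linalg.matrix_norm(w)
print(f"relative Frobenius error: {rel_err:.3f}")
```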