fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (209)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/ada_svd/clip_vision.py +4 -1
  9. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  10. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  11. fusion_bench/method/gossip/__init__.py +3 -0
  12. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  13. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  14. fusion_bench/method/gossip/entropy_loss.py +25 -0
  15. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  16. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  17. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  18. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  19. fusion_bench/method/gossip/utils.py +74 -0
  20. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  21. fusion_bench/method/opcm/opcm.py +16 -7
  22. fusion_bench/method/pwe_moe/module.py +1 -1
  23. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  24. fusion_bench/method/regmean/regmean.py +25 -17
  25. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  26. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
  27. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
  28. fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  30. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  31. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  32. fusion_bench/method/we_moe/we_moe.py +14 -15
  33. fusion_bench/mixins/__init__.py +6 -3
  34. fusion_bench/mixins/hydra_config.py +49 -0
  35. fusion_bench/mixins/openclip_classification.py +11 -0
  36. fusion_bench/mixins/simple_profiler.py +4 -2
  37. fusion_bench/modelpool/__init__.py +3 -1
  38. fusion_bench/modelpool/base_pool.py +2 -2
  39. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  40. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  41. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
  42. fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
  43. fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
  44. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
  45. fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
  46. fusion_bench/models/open_clip/__init__.py +6 -0
  47. fusion_bench/models/open_clip/modeling.py +176 -0
  48. fusion_bench/models/open_clip/utils.py +311 -0
  49. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  50. fusion_bench/models/parameter_dict.py +54 -13
  51. fusion_bench/models/rankone_moe.py +2 -88
  52. fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
  53. fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
  54. fusion_bench/models/smile_moe/utils/__init__.py +24 -0
  55. fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
  56. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  57. fusion_bench/taskpool/__init__.py +7 -3
  58. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  59. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  60. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  61. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  62. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  63. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  64. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  65. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  66. fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
  67. fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
  68. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  69. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  70. fusion_bench/utils/data.py +12 -0
  71. fusion_bench/utils/devices.py +14 -0
  72. fusion_bench/utils/instantiate.py +12 -0
  73. fusion_bench/utils/misc.py +9 -2
  74. fusion_bench/utils/packages.py +14 -0
  75. fusion_bench/utils/parameters.py +1 -1
  76. fusion_bench/utils/tensorboard.py +1 -1
  77. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
  78. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
  79. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
  80. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  81. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  82. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  83. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  84. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  85. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  86. fusion_bench_config/fabric/auto.yaml +0 -1
  87. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  88. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  89. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  90. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  91. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  92. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  93. fusion_bench_config/llama_full_finetune.yaml +0 -2
  94. fusion_bench_config/llama_model_fusion.yaml +0 -2
  95. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  96. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  97. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  98. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  99. fusion_bench_config/method/adamerging.yaml +2 -2
  100. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  101. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  102. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  103. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  104. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  105. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  106. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  107. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  108. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  109. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  110. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  111. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  112. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  113. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  114. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  115. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  116. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  117. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  118. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  119. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  120. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  121. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  122. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  123. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  124. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  125. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  126. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  127. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  128. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  129. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  130. fusion_bench_config/method/model_recombination.yaml +0 -1
  131. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  132. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  133. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  134. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  135. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  136. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  137. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  138. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  139. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  140. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  141. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
  142. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
  143. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  144. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  145. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  146. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  147. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  148. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  149. fusion_bench_config/method/ties_merging.yaml +1 -1
  150. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  151. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  153. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  154. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  155. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  156. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  157. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  158. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  159. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  160. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  161. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  162. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  171. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  172. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  173. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  174. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  175. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  178. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  179. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  180. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  181. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  182. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  183. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  184. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  185. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  186. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  187. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  188. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  189. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  190. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  191. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  192. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  193. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  194. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  195. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  196. fusion_bench_config/nyuv2_config.yaml +0 -2
  197. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  198. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  199. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  200. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  201. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
  202. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  203. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  204. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  205. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  206. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  207. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
  208. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
  209. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/hydra_config.py
@@ -0,0 +1,49 @@
+ import logging
+ import os
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import Dict, List, Optional, Union
+
+ import hydra.core.global_hydra
+ from hydra import compose, initialize
+ from omegaconf import DictConfig, OmegaConf
+
+ from fusion_bench.utils import import_object, instantiate
+ from fusion_bench.utils.instantiate import set_print_function_call
+
+ log = logging.getLogger(__name__)
+
+
+ class HydraConfigMixin:
+     """
+     A mixin for classes that need to be instantiated from a config file.
+     """
+
+     @classmethod
+     def from_config(
+         cls,
+         config_name: Union[str, Path],
+         overrides: Optional[List[str]] = None,
+     ):
+         if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
+             raise RuntimeError("Hydra is not initialized.")
+         else:
+             cfg = compose(config_name=config_name, overrides=overrides)
+
+         config_groups = config_name.split("/")[:-1]
+         for config_group in config_groups:
+             cfg = cfg[config_group]
+
+         if "_target_" in cfg:
+             # if the config has a _target_ key, check if it is equal to the class name
+             target_cls = import_object(cfg["_target_"])
+             if target_cls != cls:
+                 log.warning(
+                     f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
+                 )
+             with set_print_function_call(False):
+                 obj = instantiate(cfg)
+         else:
+             obj = cls(**cfg)
+
+         return obj
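
The new `HydraConfigMixin.from_config` builds an object straight from a Hydra config name: it walks the config groups in the name and then instantiates via `_target_` when present. A minimal usage sketch, assuming Hydra has already been initialized and that the config name below exists (the config path and name are placeholders, not files verified in this release):

    from hydra import initialize

    from fusion_bench.modelpool import CLIPVisionModelPool  # BaseModelPool now mixes in HydraConfigMixin

    with initialize(version_base=None, config_path="config"):
        # hypothetical config group/name
        modelpool = CLIPVisionModelPool.from_config(
            "modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8"
        )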
fusion_bench/mixins/openclip_classification.py
@@ -0,0 +1,11 @@
+ import logging
+
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
+
+ log = logging.getLogger(__name__)
+
+
+ class OpenCLIPClassificationMixin(LightningFabricMixin):
+     _train_processor = None
+     _test_processor = None
fusion_bench/mixins/simple_profiler.py
@@ -1,5 +1,5 @@
  from contextlib import contextmanager
- from typing import Generator
+ from typing import Generator, Optional

  from lightning.fabric.utilities.rank_zero import rank_zero_only
  from lightning.pytorch.profilers import SimpleProfiler
@@ -70,7 +70,9 @@ class SimpleProfilerMixin:
          self.profiler.stop(action_name)

      @rank_zero_only
-     def print_profile_summary(self):
+     def print_profile_summary(self, title: Optional[str] = None):
+         if title is not None:
+             print(title)
          print(self.profiler.summary())

      def __del__(self):
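
`print_profile_summary` now takes an optional `title`, printed before the profiler summary. A hedged usage sketch (it assumes the mixin's `profile` context manager used elsewhere in fusion_bench; the class below is hypothetical):

    from fusion_bench.mixins import SimpleProfilerMixin

    class MyAlgorithm(SimpleProfilerMixin):
        def run(self):
            with self.profile("merge weights"):
                ...  # timed work
            self.print_profile_summary(title="MyAlgorithm profile")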
fusion_bench/modelpool/__init__.py
@@ -6,12 +6,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter

  _import_structure = {
      "base_pool": ["BaseModelPool"],
+     "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
      "clip_vision": ["CLIPVisionModelPool"],
      "nyuv2_modelpool": ["NYUv2ModelPool"],
      "huggingface_automodel": ["AutoModelPool"],
-     "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
      "seq2seq_lm": ["Seq2SeqLMPool"],
      "PeftModelForSeq2SeqLM": ["PeftModelForSeq2SeqLMPool"],
+     "openclip_vision": ["OpenCLIPVisionModelPool"],
      "huggingface_gpt2_classification": [
          "HuggingFaceGPT2ClassificationPool",
          "GPT2ForSequenceClassificationPool",
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
          HuggingFaceGPT2ClassificationPool,
      )
      from .nyuv2_modelpool import NYUv2ModelPool
+     from .openclip_vision import OpenCLIPVisionModelPool
      from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
      from .seq2seq_lm import Seq2SeqLMPool
      from .seq_classification_lm import SeqenceClassificationModelPool
fusion_bench/modelpool/base_pool.py
@@ -7,7 +7,7 @@ from omegaconf import DictConfig
  from torch import nn
  from torch.utils.data import Dataset

- from fusion_bench.mixins import BaseYAMLSerializableModel
+ from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
  from fusion_bench.utils import instantiate, timeit_context

  __all__ = ["BaseModelPool"]
@@ -15,7 +15,7 @@ __all__ = ["BaseModelPool"]
  log = logging.getLogger(__name__)


- class BaseModelPool(BaseYAMLSerializableModel):
+ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
      """
      A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.

fusion_bench/modelpool/openclip_vision/__init__.py
@@ -0,0 +1 @@
+ from .modelpool import OpenCLIPVisionModelPool
fusion_bench/modelpool/openclip_vision/modelpool.py
@@ -0,0 +1,255 @@
+ import logging
+ import pickle
+ import sys
+ from typing import Callable, Optional, Union, cast
+
+ import torch
+ from datasets import load_dataset
+ from omegaconf import DictConfig, OmegaConf
+ from torch import nn
+
+ from fusion_bench.modelpool import BaseModelPool
+ from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+ from fusion_bench.utils import instantiate
+ from fusion_bench.utils.expr import is_expr_match
+ from fusion_bench.utils.packages import _get_package_version, compare_versions
+
+ log = logging.getLogger(__name__)
+
+ # Add flag to track if warning has been shown
+ _openclip_version_warning_shown = False
+
+
+ def _check_and_redirect_open_clip_modeling():
+     global _openclip_version_warning_shown
+     if compare_versions(_get_package_version("open-clip-torch").__str__(), "2.0.2") > 0:
+         if not _openclip_version_warning_shown:
+             log.warning(
+                 "OpenCLIP version is greater than 2.0.2. This may cause issues with the modelpool."
+             )
+             _openclip_version_warning_shown = True
+         import open_clip.model
+         import open_clip.transformer
+
+         if not hasattr(open_clip.model, "VisualTransformer"):
+             open_clip.model.VisualTransformer = open_clip.model.VisionTransformer
+         if not hasattr(open_clip.model, "Transformer"):
+             open_clip.model.Transformer = open_clip.transformer.Transformer
+         if not hasattr(open_clip.model, "ResidualAttentionBlock"):
+             open_clip.model.ResidualAttentionBlock = (
+                 open_clip.transformer.ResidualAttentionBlock
+             )
+
+     try:
+         import src
+         import src.modeling
+     except ImportError:
+         if "src" not in sys.modules:
+             # redirect the import of `src` to `fusion_bench.models.open_clip`
+             import fusion_bench.models.open_clip as open_clip
+
+             sys.modules["src"] = open_clip
+             log.warning(
+                 "`src` is not imported."
+                 "Redirecting the import to `fusion_bench.models.open_clip`"
+             )
+         if "src.modeling" not in sys.modules:
+             # redirect the import of `src.modeling` to `fusion_bench.models.open_clip.modeling`
+             import fusion_bench.models.open_clip.modeling as open_clip_modeling
+
+             sys.modules["src.modeling"] = open_clip_modeling
+             log.warning(
+                 "`src.modeling` is not imported."
+                 "Redirecting the import to `fusion_bench.models.open_clip.modeling`"
+             )
+
+
+ def load_classifier_head(model_config: Union[str, DictConfig], *args, **kwargs):
+     if isinstance(model_config, str):
+         _check_and_redirect_open_clip_modeling()
+         log.info(f"Loading `ClassificationHead` from {model_config}")
+         weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+         head = torch.load(model_config, weights_only=weights_only, *args, **kwargs)
+     elif isinstance(model_config, nn.Module):
+         log.info(f"Returning existing model: {model_config}")
+         head = model_config
+     else:
+         head = instantiate(model_config, *args, **kwargs)
+     head = cast(ClassificationHead, head)
+     return head
+
+
+ class OpenCLIPVisionModelPool(BaseModelPool):
+     """
+     A model pool for managing OpenCLIP Vision models (models from task vector paper).
+     """
+
+     _train_processor = None
+     _test_processor = None
+
+     def __init__(
+         self,
+         models: DictConfig,
+         classification_heads: Optional[DictConfig] = None,
+         **kwargs,
+     ):
+         super().__init__(models, **kwargs)
+         self._classification_heads = classification_heads
+
+     @property
+     def train_processor(self):
+         if self._train_processor is None:
+             encoder: ImageEncoder = self.load_pretrained_or_first_model()
+             self._train_processor = encoder.train_preprocess
+             if self._test_processor is None:
+                 self._test_processor = encoder.val_preprocess
+         return self._train_processor
+
+     @property
+     def test_processor(self):
+         if self._test_processor is None:
+             encoder: ImageEncoder = self.load_pretrained_or_first_model()
+             if self._train_processor is None:
+                 self._train_processor = encoder.train_preprocess
+             self._test_processor = encoder.val_preprocess
+         return self._test_processor
+
+     def load_model(
+         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+     ) -> ImageEncoder:
+         R"""
+         The model config can be:
+
+         - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+         - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
+         - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
+         - Default, load the model using `instantiate` from hydra.
+         """
+         if (
+             isinstance(model_name_or_config, str)
+             and model_name_or_config in self._models
+         ):
+             model_config = self._models[model_name_or_config]
+         else:
+             model_config = model_name_or_config
+         if isinstance(model_config, DictConfig):
+             model_config = OmegaConf.to_container(model_config, resolve=True)
+
+         if isinstance(model_config, str):
+             # the model config is a string, which is the path to the model checkpoint in pickle format
+             # load the model using `torch.load`
+             # this is the original usage in the task arithmetic codebase
+             _check_and_redirect_open_clip_modeling()
+             log.info(f"loading ImageEncoder from {model_config}")
+             weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+             try:
+                 encoder = torch.load(
+                     model_config, weights_only=weights_only, *args, **kwargs
+                 )
+             except RuntimeError as e:
+                 encoder = pickle.load(open(model_config, "rb"))
+         elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
+             # the model config is a dictionary with the following keys:
+             # - model_name: str, the name of the model
+             # - pickle_path: str, the path to the binary file (pickle format)
+             # load the model from the binary file (pickle format)
+             # this is useful when you use a newer version of torchvision
+             _check_and_redirect_open_clip_modeling()
+             log.info(
+                 f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
+             )
+             weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+             try:
+                 encoder = torch.load(
+                     model_config["pickle_path"],
+                     weights_only=weights_only,
+                     *args,
+                     **kwargs,
+                 )
+             except RuntimeError as e:
+                 encoder = pickle.load(open(model_config["pickle_path"], "rb"))
+             _encoder = ImageEncoder(model_config["model_name"])
+             _encoder.load_state_dict(encoder.state_dict())
+             encoder = _encoder
+         elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
+             # the model config is a dictionary with the following keys:
+             # - model_name: str, the name of the model
+             # - state_dict_path: str, the path to the state dict file
+             # load the model from the state dict file
+             log.info(
+                 f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
+             )
+             encoder = ImageEncoder(model_config["model_name"])
+             encoder.load_state_dict(
+                 torch.load(
+                     model_config["state_dict_path"], weights_only=True, *args, **kwargs
+                 )
+             )
+         elif isinstance(model_config, nn.Module):
+             # the model config is an existing model
+             log.info(f"Returning existing model: {model_config}")
+             encoder = model_config
+         else:
+             encoder = super().load_model(model_name_or_config, *args, **kwargs)
+         encoder = cast(ImageEncoder, encoder)
+
+         # setup the train and test processors
+         if self._train_processor is None and hasattr(encoder, "train_preprocess"):
+             self._train_processor = encoder.train_preprocess
+         if self._test_processor is None and hasattr(encoder, "val_preprocess"):
+             self._test_processor = encoder.val_preprocess
+
+         return encoder
+
+     def load_classification_head(
+         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+     ) -> ClassificationHead:
+         R"""
+         The model config can be:
+
+         - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+         - Default, load the model using `instantiate` from hydra.
+         """
+         if (
+             isinstance(model_name_or_config, str)
+             and model_name_or_config in self._classification_heads
+         ):
+             model_config = self._classification_heads[model_name_or_config]
+         else:
+             model_config = model_name_or_config
+
+         head = load_classifier_head(model_config, *args, **kwargs)
+         return head
+
+     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+         dataset_config = self._train_datasets[dataset_name]
+         if isinstance(dataset_config, str):
+             log.info(
+                 f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+             )
+             dataset = load_dataset(dataset_config, split="train")
+         else:
+             dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
+         return dataset
+
+     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+         dataset_config = self._val_datasets[dataset_name]
+         if isinstance(dataset_config, str):
+             log.info(
+                 f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+             )
+             dataset = load_dataset(dataset_config, split="validation")
+         else:
+             dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
+         return dataset
+
+     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+         dataset_config = self._test_datasets[dataset_name]
+         if isinstance(dataset_config, str):
+             log.info(
+                 f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+             )
+             dataset = load_dataset(dataset_config, split="test")
+         else:
+             dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
+         return dataset
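
`load_model` accepts either a raw checkpoint path or a small dict describing how to rebuild the encoder. A hedged sketch of the `model_name` + `state_dict_path` form (the model keys and paths below are placeholders, not assets shipped with the package):

    from omegaconf import OmegaConf

    from fusion_bench.modelpool import OpenCLIPVisionModelPool

    models = OmegaConf.create(
        {
            "_pretrained_": {
                "model_name": "ViT-B-32",
                "state_dict_path": "checkpoints/ViT-B-32/zeroshot.pt",
            },
            "sun397": {
                "model_name": "ViT-B-32",
                "state_dict_path": "checkpoints/ViT-B-32/sun397/finetuned.pt",
            },
        }
    )
    pool = OpenCLIPVisionModelPool(models=models)
    # builds ImageEncoder("ViT-B-32"), then loads the state dict from the given path
    encoder = pool.load_model("sun397")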
fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py
@@ -27,6 +27,8 @@ from transformers.models.mistral.modeling_mistral import (
      MistralRotaryEmbedding,
  )

+ from fusion_bench.models.smile_moe.linear_from_hf_config import SmileLinear
+
  from .configuration_smile_mistral import SmileMistralConfig

  logger = logging.getLogger(__name__)
@@ -80,209 +82,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


- class SmileGate(nn.Module):
-     __constants__ = ["in_features", "num_experts", "k"]
-     in_features: int
-     num_experts: int
-     k: int
-     weight: Tensor
-
-     def __init__(
-         self,
-         in_features: int,
-         num_experts: int,
-         k: int,
-         device=None,
-         dtype=None,
-     ):
-         factory_kwargs = {"device": device, "dtype": dtype}
-         super().__init__()
-         self.input_features = in_features
-         self.num_experts = num_experts
-         self.k = k
-
-         self.weight = nn.Parameter(
-             torch.empty(num_experts * k, in_features, **factory_kwargs)
-         )
-
-     def forward(self, x: Tensor):
-         batch_size = x.size(0)
-         if self.num_experts == 1:
-             return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-         routing_weights = F.linear(x, self.weight).view(
-             batch_size, self.num_experts, self.k
-         )
-         routing_weights = routing_weights.norm(p=2, dim=2)
-         return routing_weights
-
-
- class SmileLinearExpert(nn.Module):
-     __constants__ = ["in_features", "out_features", "k"]
-     in_features: int
-     out_features: int
-     k: int
-
-     def __init__(
-         self,
-         in_features,
-         out_features,
-         k: int,
-         bias: bool,
-         device=None,
-         dtype=None,
-     ):
-         factory_kwargs = {"device": device, "dtype": dtype}
-         super().__init__()
-         self.in_features = in_features
-         self.out_features = out_features
-         self.k = k
-
-         self.u = nn.Parameter(torch.empty(out_features, k, **factory_kwargs))
-         self.svh = nn.Parameter(torch.empty(k, in_features, **factory_kwargs))
-
-         if bias:
-             self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
-         else:
-             self.register_parameter("bias", None)
-
-     def forward(self, x):
-         x = F.linear(x, self.svh)
-         x = F.linear(x, self.u, self.bias)
-         return x
-
-
- class SmileLinear(nn.Module):
-     @torch.no_grad()
-     def __init__(
-         self,
-         config: SmileMistralConfig,
-         in_features,
-         out_features,
-         bias: bool,
-         device=None,
-         dtype=None,
-     ):
-         factory_kwargs = {"device": device, "dtype": dtype}
-         super().__init__()
-         self.num_local_experts = config.num_local_experts
-         self.num_experts_per_tok = config.num_experts_per_tok
-         self.rank_of_expert = config.rank_of_expert
-         self.rank_of_router = config.rank_of_router
-         self.in_features = in_features
-         self.out_features = out_features
-
-         # construct the gate network
-         self.gate = SmileGate(
-             in_features=in_features,
-             num_experts=self.num_local_experts,
-             k=self.rank_of_router,
-             **factory_kwargs,
-         )
-
-         # the shared linear
-         self.shared_linear = nn.Linear(
-             in_features, out_features, bias=bias, **factory_kwargs
-         )
-
-         # construct experts
-         if self.rank_of_expert > 0:
-             self.experts = nn.ModuleList(
-                 [
-                     SmileLinearExpert(
-                         in_features=in_features,
-                         out_features=out_features,
-                         bias=bias,
-                         k=self.rank_of_expert,
-                         **factory_kwargs,
-                     )
-                     for _ in range(self.num_local_experts)
-                 ]
-             )
-         else:
-             self.experts = nn.ModuleList(
-                 [
-                     nn.Linear(in_features, out_features, bias=bias, **factory_kwargs)
-                     for _ in range(self.num_local_experts)
-                 ]
-             )
-
-     def forward(self, hidden_states: Tensor):
-         pretrained_out = self.shared_linear(hidden_states)
-
-         input_shape = hidden_states.size()
-         hidden_states = hidden_states.view(-1, self.in_features)
-
-         router_logits = self.gate(hidden_states)
-         routing_weights = F.softmax(router_logits, dim=1)
-         # sample the expert according to the routing weights
-         routing_weights, selected_experts = torch.topk(
-             routing_weights, self.num_experts_per_tok, dim=-1
-         )
-         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-         final_hidden_states = torch.zeros(
-             (hidden_states.size(0), self.out_features),
-             dtype=hidden_states.dtype,
-             device=hidden_states.device,
-         )
-
-         # One hot encode the selected experts to create an expert mask
-         # this will be used to easily index which expert is going to be sollicitated
-         expert_mask = torch.nn.functional.one_hot(
-             selected_experts, num_classes=self.num_local_experts
-         ).permute(2, 1, 0)
-
-         # Loop over all available experts in the model and perform the computation on each expert
-         for expert_idx in range(self.num_local_experts):
-             expert_layer = self.experts[expert_idx]
-             idx, top_x = torch.where(expert_mask[expert_idx])
-
-             # Index the correct hidden states and compute the expert hidden state for
-             # the current expert. We need to make sure to multiply the output hidden
-             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-             current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-             if current_state.numel() == 0:
-                 continue
-             current_hidden_states = (
-                 expert_layer(current_state) * routing_weights[top_x, idx, None]
-             )
-
-             # However `index_add_` only support torch tensors for indexing so we'll use
-             # the `top_x` tensor here.
-             final_hidden_states.index_add_(
-                 0, top_x, current_hidden_states.to(hidden_states.dtype)
-             )
-         final_hidden_states = final_hidden_states.reshape(
-             *input_shape[:-1], self.out_features
-         )
-         final_hidden_states = pretrained_out + final_hidden_states
-         return final_hidden_states
-
-     @property
-     def weight(self):
-         """
-         Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
-         """
-         return self.shared_linear.weight
-
-     @property
-     def bias(self):
-         return self.shared_linear.bias
-
-     def __repr__(self):
-         return (
-             f"SingularMoELinear("
-             f"in_features={self.shared_linear.in_features}, "
-             f"out_features={self.shared_linear.out_features}, "
-             f"num_local_experts={self.num_local_experts}, "
-             f"num_experts_per_tok={self.num_experts_per_tok}, "
-             f"rank_of_router={self.rank_of_router}, "
-             f"rank_of_expert={self.rank_of_expert}"
-             f")"
-         )
-
-
  class SmileMistralAttention(nn.Module):
      """
      Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
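
The removed `SmileGate`/`SmileLinearExpert`/`SmileLinear` classes are superseded by the shared implementation in `fusion_bench.models.smile_moe.linear_from_hf_config` (imported above). The routing rule they implement is unchanged: project the token features with a `(num_experts * k, in_features)` matrix and use the per-expert L2 norm of the k-dimensional projections as routing logits. A standalone sketch of that rule:

    import torch
    import torch.nn.functional as F

    batch, in_features, num_experts, k = 4, 16, 3, 2
    x = torch.randn(batch, in_features)                  # token features
    gate_weight = torch.randn(num_experts * k, in_features)

    # (batch, num_experts * k) -> (batch, num_experts, k) -> per-expert L2 norm
    logits = F.linear(x, gate_weight).view(batch, num_experts, k).norm(p=2, dim=2)
    routing_weights = F.softmax(logits, dim=1)            # (batch, num_experts)
    top_w, top_idx = torch.topk(routing_weights, 1, dim=-1)  # top-1 expert per token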
fusion_bench/models/modeling_smile_qwen2/__init__.py
@@ -0,0 +1,8 @@
+ from . import register
+ from .configuration_smile_qwen2 import SmileQwen2Config
+ from .modeling_smile_qwen2 import (
+     SmileQwen2ForCausalLM,
+     SmileQwen2ForQuestionAnswering,
+     SmileQwen2ForSequenceClassification,
+     SmileQwen2Model,
+ )
fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py
@@ -0,0 +1,21 @@
+ from transformers import PretrainedConfig
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+
+
+ class SmileQwen2Config(Qwen2Config):
+     model_type = "smile_qwen2"
+
+     def __init__(
+         self,
+         num_experts_per_tok: int = 1,
+         rank_of_router: int = None,
+         rank_of_expert: int = None,
+         num_local_experts: int = None,
+         **kwargs,
+     ):
+         self.num_experts_per_tok = num_experts_per_tok
+         self.rank_of_router = rank_of_router
+         self.rank_of_expert = rank_of_expert
+         self.num_local_experts = num_local_experts
+
+         super().__init__(**kwargs)
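
`SmileQwen2Config` extends `Qwen2Config` with the SMILE-MoE fields consumed by the upscaled layers. A hedged construction sketch (the expert count and ranks are illustrative, not defaults taken from the released configs):

    from fusion_bench.models.modeling_smile_qwen2 import (
        SmileQwen2Config,
        SmileQwen2ForCausalLM,
    )

    config = SmileQwen2Config(
        num_local_experts=2,    # one expert per fine-tuned donor model
        num_experts_per_tok=1,  # top-k routing
        rank_of_router=8,       # rank of the low-rank gate
        rank_of_expert=32,      # expert rank (non-positive meant full nn.Linear experts in the SmileLinear shown above)
    )
    model = SmileQwen2ForCausalLM(config)  # randomly initialized; weights are filled in by SMILE upscaling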