fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (209)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/ada_svd/clip_vision.py +4 -1
  9. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  10. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  11. fusion_bench/method/gossip/__init__.py +3 -0
  12. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  13. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  14. fusion_bench/method/gossip/entropy_loss.py +25 -0
  15. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  16. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  17. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  18. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  19. fusion_bench/method/gossip/utils.py +74 -0
  20. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  21. fusion_bench/method/opcm/opcm.py +16 -7
  22. fusion_bench/method/pwe_moe/module.py +1 -1
  23. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  24. fusion_bench/method/regmean/regmean.py +25 -17
  25. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  26. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
  27. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
  28. fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  30. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  31. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  32. fusion_bench/method/we_moe/we_moe.py +14 -15
  33. fusion_bench/mixins/__init__.py +6 -3
  34. fusion_bench/mixins/hydra_config.py +49 -0
  35. fusion_bench/mixins/openclip_classification.py +11 -0
  36. fusion_bench/mixins/simple_profiler.py +4 -2
  37. fusion_bench/modelpool/__init__.py +3 -1
  38. fusion_bench/modelpool/base_pool.py +2 -2
  39. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  40. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  41. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
  42. fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
  43. fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
  44. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
  45. fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
  46. fusion_bench/models/open_clip/__init__.py +6 -0
  47. fusion_bench/models/open_clip/modeling.py +176 -0
  48. fusion_bench/models/open_clip/utils.py +311 -0
  49. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  50. fusion_bench/models/parameter_dict.py +54 -13
  51. fusion_bench/models/rankone_moe.py +2 -88
  52. fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
  53. fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
  54. fusion_bench/models/smile_moe/utils/__init__.py +24 -0
  55. fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
  56. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  57. fusion_bench/taskpool/__init__.py +7 -3
  58. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  59. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  60. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  61. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  62. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  63. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  64. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  65. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  66. fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
  67. fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
  68. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  69. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  70. fusion_bench/utils/data.py +12 -0
  71. fusion_bench/utils/devices.py +14 -0
  72. fusion_bench/utils/instantiate.py +12 -0
  73. fusion_bench/utils/misc.py +9 -2
  74. fusion_bench/utils/packages.py +14 -0
  75. fusion_bench/utils/parameters.py +1 -1
  76. fusion_bench/utils/tensorboard.py +1 -1
  77. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
  78. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
  79. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
  80. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  81. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  82. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  83. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  84. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  85. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  86. fusion_bench_config/fabric/auto.yaml +0 -1
  87. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  88. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  89. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  90. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  91. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  92. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  93. fusion_bench_config/llama_full_finetune.yaml +0 -2
  94. fusion_bench_config/llama_model_fusion.yaml +0 -2
  95. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  96. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  97. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  98. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  99. fusion_bench_config/method/adamerging.yaml +2 -2
  100. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  101. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  102. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  103. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  104. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  105. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  106. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  107. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  108. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  109. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  110. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  111. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  112. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  113. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  114. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  115. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  116. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  117. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  118. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  119. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  120. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  121. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  122. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  123. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  124. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  125. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  126. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  127. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  128. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  129. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  130. fusion_bench_config/method/model_recombination.yaml +0 -1
  131. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  132. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  133. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  134. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  135. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  136. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  137. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  138. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  139. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  140. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  141. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
  142. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
  143. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  144. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  145. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  146. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  147. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  148. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  149. fusion_bench_config/method/ties_merging.yaml +1 -1
  150. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  151. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  153. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  154. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  155. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  156. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  157. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  158. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  159. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  160. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  161. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  162. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  171. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  172. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  173. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  174. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  175. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  178. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  179. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  180. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  181. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  182. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  183. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  184. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  185. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  186. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  187. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  188. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  189. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  190. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  191. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  192. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  193. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  194. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  195. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  196. fusion_bench_config/nyuv2_config.yaml +0 -2
  197. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  198. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  199. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  200. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  201. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
  202. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  203. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  204. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  205. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  206. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  207. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
  208. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
  209. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
fusion_bench/models/smile_moe/utils/__init__.py
@@ -0,0 +1,24 @@
+ from typing import List
+
+ import torch
+ from torch import Tensor
+
+ from .svd_utils import svd
+
+ __all__ = ["svd_utils", "_is_all_zeros"]
+
+
+ def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
+     """
+     Check if a tensor or a list of tensors are all zeros.
+
+     Args:
+         tensor (Tensor | List[Tensor]): A tensor or a list of tensors.
+
+     Returns:
+         bool: True if all elements are zeros, False otherwise.
+     """
+     if isinstance(tensor, Tensor):
+         return torch.allclose(tensor, torch.zeros_like(tensor))
+     else:
+         return all(_is_all_zeros(t) for t in tensor)
fusion_bench/models/smile_moe/utils/svd_utils.py
@@ -0,0 +1,46 @@
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from torch import Tensor
+
+
+ def _svd(w: Tensor, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Perform Singular Value Decomposition (SVD) on a tensor.
+
+     Args:
+         w (Tensor): The input tensor.
+         full_matrices (bool): Whether to compute the full-sized U and V matrices.
+
+     Returns:
+         Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+     """
+     u, s, vh = torch.linalg.svd(
+         w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+     )
+     v = vh.T
+     return u, s, v
+
+
+ def svd(
+     w: Tensor,
+     full_matrices: bool = True,
+     accelerator: Optional[Union[torch.device, str]] = None,
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Perform SVD on a tensor, optionally using a specified accelerator.
+
+     Args:
+         w (Tensor): The input tensor.
+         full_matrices (bool): Whether to compute the full-sized U and V matrices.
+         accelerator (Optional[Union[torch.device, str]]): The device to perform the computation on.
+
+     Returns:
+         Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+     """
+     if accelerator is None:
+         return _svd(w, full_matrices=full_matrices)
+     original_device = w.device
+     w = w.to(accelerator)
+     u, s, v = _svd(w)
+     return u.to(original_device), s.to(original_device), v.to(original_device)
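For orientation, the new `svd` helper simply moves the input to the requested device, runs `torch.linalg.svd`, and moves the factors back to the original device; note that the accelerated branch calls `_svd(w)` without forwarding `full_matrices`, so it always computes the default full-matrices decomposition. A minimal usage sketch (the tensor shape, rank, and device string below are illustrative, not taken from the package):

import torch

from fusion_bench.models.smile_moe.utils import _is_all_zeros, svd

# illustrative weight matrix; any 2-D tensor works
w = torch.randn(768, 3072)

# CPU decomposition with reduced matrices
u, s, v = svd(w, full_matrices=False)

# offload the decomposition to a GPU when available; factors come back on w's original device
if torch.cuda.is_available():
    u, s, v = svd(w, accelerator="cuda")

# rank-k reconstruction from the leading singular directions (k chosen arbitrarily here)
k = 64
w_k = u[:, :k] @ torch.diag(s[:k]) @ v[:, :k].T

# _is_all_zeros also accepts a list of tensors and checks each one recursively
print(_is_all_zeros([torch.zeros(4, 4), torch.zeros(2)]))  # True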
fusion_bench/scripts/nyuv2_mtl_train.py
@@ -1,5 +1,5 @@
  R"""
- This script is used to train a multi-task learning (MTL) model on the NYUv2 dataset.
+ This script is used to train a multi-task learning (MTL) model on the NYUv2 dataset.
  """

  import importlib
fusion_bench/taskpool/__init__.py
@@ -10,12 +10,14 @@ _import_structure = {
      "clip_vision": [
          "CLIPVisionModelTaskPool",
          "SparseWEMoECLIPVisionModelTaskPool",
-         "RankoneWEMoECLIPVisionModelTaskPool",
+         "RankoneMoECLIPVisionModelTaskPool",
      ],
      "dummy": ["DummyTaskPool"],
      "gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
-     "nyuv2_taskpool": ["NYUv2TaskPool"],
      "llama": ["LlamaTestGenerationTaskPool"],
+     "lm_eval_harness": ["LMEvalHarnessTaskPool"],
+     "nyuv2_taskpool": ["NYUv2TaskPool"],
+     "openclip_vision": ["OpenCLIPVisionModelTaskPool"],
  }


@@ -23,13 +25,15 @@ if TYPE_CHECKING:
      from .base_pool import BaseTaskPool
      from .clip_vision import (
          CLIPVisionModelTaskPool,
-         RankoneWEMoECLIPVisionModelTaskPool,
+         RankoneMoECLIPVisionModelTaskPool,
          SparseWEMoECLIPVisionModelTaskPool,
      )
      from .dummy import DummyTaskPool
      from .gpt2_text_classification import GPT2TextClassificationTaskPool
      from .llama import LlamaTestGenerationTaskPool
+     from .lm_eval_harness import LMEvalHarnessTaskPool
      from .nyuv2_taskpool import NYUv2TaskPool
+     from .openclip_vision import OpenCLIPVisionModelTaskPool

  else:
      sys.modules[__name__] = LazyImporter(
fusion_bench/taskpool/clip_vision/__init__.py
@@ -1,4 +1,5 @@
  # flake8: noqa F401
  from .clip_rankone_moe_taskpool import RankoneMoECLIPVisionModelTaskPool
+ from .clip_smile_taskpool import SmileCLIPVisionModelTaskPool
  from .clip_sparse_wemoe_taskpool import SparseWEMoECLIPVisionModelTaskPool
  from .taskpool import CLIPVisionModelTaskPool
fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py
@@ -12,36 +12,7 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
  from fusion_bench.models.rankone_moe import RankOneMoE

  from .taskpool import CLIPVisionModelTaskPool
-
-
- class LayerWiseRoutingWeightSaver:
-     def __init__(self, save_path: Path, max_num: Optional[int] = None):
-         self.save_path = save_path
-         self.max_num = max_num
-         self.routing_weights = []
-
-     def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-         assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-         # (batch_size, num_tokens, num_experts)
-         routing_weights = output.detach().cpu()
-         if self.max_num is not None and self.max_num > 0:
-             if len(self.routing_weights) > self.max_num:
-                 return
-             elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                 self.routing_weights.append(
-                     routing_weights[: self.max_num - len(self.routing_weights)]
-                 )
-             else:
-                 self.routing_weights.append(routing_weights)
-         else:
-             self.routing_weights.append(routing_weights)
-
-     def save_routing_weights(self):
-         routing_weights = torch.cat(self.routing_weights, dim=0)
-         if self.save_path is not None:
-             self.save_path.parent.mkdir(parents=True, exist_ok=True)
-             print(f"Saving routing weights to {self.save_path}")
-             torch.save(routing_weights, self.save_path)
+ from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver


  class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -109,4 +80,5 @@ class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
              # remove hooks for saving layer-wise routing weights
              for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
                  self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+                 self._layer_wise_routing_weights_save_hook_handles.pop(i)
                  handle.remove()
fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py
@@ -0,0 +1,102 @@
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import Tensor
+ from torch.utils.hooks import RemovableHandle
+ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
+ from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+
+ from fusion_bench.method.smile_upscaling import SmileMoELinear
+ from fusion_bench.models.hf_clip import HFCLIPClassifier
+
+ from .taskpool import CLIPVisionModelTaskPool
+ from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
+
+
+ class SmileCLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
+
+     # hooks and handles for saving layer-wise routing weights
+     _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
+     _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}
+
+     def __init__(
+         self,
+         linear_module_names: Union[List[str], str],
+         layer_wise_routing_weights_save_path: Optional[str],
+         layer_wise_routing_weights_max_num: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         Initialize the SMILECLIPVisionModelTaskPool.
+
+         Args:
+             linear_module_names (Union[List[str], str]): The names of the linear modules to save the layer-wise routing weights for.
+             layer_wise_routing_weights_save_path (Optional[str]): The path to save the layer-wise routing weights.
+             layer_wise_routing_weights_max_num (Optional[int]): The maximum number of layer-wise routing weights to save.
+         """
+         # linear module names
+         assert linear_module_names is not None, "linear_module_names must be provided"
+         self.linear_module_names = (
+             [linear_module_names]
+             if isinstance(linear_module_names, str)
+             else list(linear_module_names)
+         )
+         # save path for layer-wise routing weights
+         self._layer_wise_routing_weights_save_path = (
+             layer_wise_routing_weights_save_path
+         )
+         self.layer_wise_routing_weights_save_path = (
+             Path(layer_wise_routing_weights_save_path)
+             if layer_wise_routing_weights_save_path is not None
+             else None
+         )
+         self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
+         super().__init__(**kwargs)
+
+     def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
+         super().on_task_evaluation_begin(classifier, task_name)
+         if self.layer_wise_routing_weights_save_path is not None:
+             # setup hooks for saving layer-wise routing weights
+             assert isinstance(
+                 classifier.clip_model.vision_model,
+                 (CLIPVisionTransformer, CLIPVisionModel),
+             ), "Vision model is expected to be a CLIPVisionTransformer"
+             vision_model = classifier.clip_model.vision_model
+             if isinstance(vision_model, CLIPVisionModel):
+                 vision_model = vision_model.vision_model
+             # assign forward hooks for each layer
+
+             for i, layer in enumerate(vision_model.encoder.layers):
+                 for linear_module_name in self.linear_module_names:
+                     linear_module = layer.get_submodule(linear_module_name)
+                     assert isinstance(
+                         linear_module,
+                         (SmileMoELinear),
+                     ), f"Linear module is expected to be a SmileMoELinear, but got {type(linear_module)}"
+                     # layer-wise routing weights
+                     hook = LayerWiseRoutingWeightSaver(
+                         self.layer_wise_routing_weights_save_path
+                         / task_name
+                         / f"layer_{i}_{linear_module_name}.pt",
+                         max_num=self.layer_wise_routing_weights_max_num,
+                     )
+                     self._layer_wise_routing_weights_save_hooks[
+                         (i, linear_module_name)
+                     ] = hook
+                     self._layer_wise_routing_weights_save_hook_handles[
+                         (i, linear_module_name)
+                     ] = linear_module.gate.register_forward_hook(hook)
+
+     def on_task_evaluation_end(self):
+         super().on_task_evaluation_end()
+         if self.layer_wise_routing_weights_save_path is not None:
+             # remove hooks for saving layer-wise routing weights
+             for (
+                 key,
+                 handle,
+             ) in self._layer_wise_routing_weights_save_hook_handles.items():
+                 self._layer_wise_routing_weights_save_hooks[key].save_routing_weights()
+                 self._layer_wise_routing_weights_save_hook_handles.pop(key)
+                 handle.remove()
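The new `SmileCLIPVisionModelTaskPool` walks every encoder layer, resolves each configured name with `layer.get_submodule(...)`, and keys its hook/handle dictionaries by the `(layer_index, module_name)` pair so each gate gets its own output file. A small stand-in sketch of that addressing pattern (the toy layer, the dotted path, and the dummy hook below are illustrative; the real code requires the resolved module to be a `SmileMoELinear` and attaches a `LayerWiseRoutingWeightSaver` to its `gate`):

from torch import nn

class ToyLayer(nn.Module):
    # stand-in for a CLIP encoder layer; only the dotted submodule path matters here
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(8, 8))

layers = nn.ModuleList([ToyLayer() for _ in range(2)])
linear_module_names = ["mlp.0"]  # illustrative dotted path; a real config might use e.g. "self_attn.q_proj"

hooks, handles = {}, {}
for i, layer in enumerate(layers):
    for name in linear_module_names:
        module = layer.get_submodule(name)          # dotted-path lookup, as in the diff
        key = (i, name)                             # one hook per (layer index, module name)
        hooks[key] = lambda m, args, output: None   # placeholder for LayerWiseRoutingWeightSaver
        handles[key] = module.register_forward_hook(hooks[key])

# ... evaluation runs here ..., then tear down, as on_task_evaluation_end does
for key, handle in list(handles.items()):
    handle.remove()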
fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py
@@ -15,36 +15,7 @@ from fusion_bench.models.sparse_we_moe import (
  )

  from .taskpool import CLIPVisionModelTaskPool
-
-
- class LayerWiseRoutingWeightSaver:
-     def __init__(self, save_path: Path, max_num: Optional[int] = None):
-         self.save_path = save_path
-         self.max_num = max_num
-         self.routing_weights = []
-
-     def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-         assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-         # (batch_size, num_tokens, num_experts)
-         routing_weights = output.detach().cpu()
-         if self.max_num is not None and self.max_num > 0:
-             if len(self.routing_weights) > self.max_num:
-                 return
-             elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                 self.routing_weights.append(
-                     routing_weights[: self.max_num - len(self.routing_weights)]
-                 )
-             else:
-                 self.routing_weights.append(routing_weights)
-         else:
-             self.routing_weights.append(routing_weights)
-
-     def save_routing_weights(self):
-         routing_weights = torch.cat(self.routing_weights, dim=0)
-         if self.save_path is not None:
-             self.save_path.parent.mkdir(parents=True, exist_ok=True)
-             print(f"Saving routing weights to {self.save_path}")
-             torch.save(routing_weights, self.save_path)
+ from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver


  class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -117,4 +88,5 @@ class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
              # remove hooks for saving layer-wise routing weights
              for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
                  self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+                 self._layer_wise_routing_weights_save_hook_handles.pop(i)
                  handle.remove()
fusion_bench/taskpool/clip_vision/taskpool.py
@@ -32,8 +32,7 @@ from fusion_bench.mixins import LightningFabricMixin
  from fusion_bench.models.hf_clip import HFCLIPClassifier
  from fusion_bench.taskpool import BaseTaskPool
  from fusion_bench.tasks.clip_classification import get_classnames_and_templates
- from fusion_bench.utils import instantiate
- from fusion_bench.utils.parameters import count_parameters
+ from fusion_bench.utils import count_parameters, instantiate

  if TYPE_CHECKING:
      from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
fusion_bench/taskpool/clip_vision/utils/__init__.py
File without changes
fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py
@@ -0,0 +1,65 @@
+ from pathlib import Path
+ from typing import List, Optional, Tuple
+
+ import torch
+ from torch import Tensor
+
+
+ def _number_of_samples(routing_weights: List[Tensor]):
+     count = 0
+     for routing_weight in routing_weights:
+         count += routing_weight.size(0)
+     return count
+
+
+ class LayerWiseRoutingWeightSaver:
+     """
+     A hook for saving layer-wise routing weights.
+     """
+
+     save_path: Path
+     "The path to save the layer-wise routing weights."
+     max_num: Optional[int]
+     "The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved."
+     routing_weights: List[Tensor]
+     "The list of layer-wise routing weights."
+
+     def __init__(self, save_path: Path, max_num: Optional[int] = None):
+         """
+         Args:
+             save_path (Path): The path to save the layer-wise routing weights.
+             max_num (Optional[int]): The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved.
+         """
+         self.save_path = save_path
+         self.max_num = max_num
+         self.routing_weights = []
+
+     def __call__(self, module, input: Tuple[Tensor], output: Tensor):
+         assert isinstance(output, Tensor), "Output is expected to be a Tensor"
+         # (batch_size, num_tokens, num_experts)
+         routing_weights = output.detach().cpu()
+         if self.max_num is not None and self.max_num > 0:
+             if _number_of_samples(self.routing_weights) > self.max_num:
+                 return
+             elif (
+                 routing_weights.size(0) + _number_of_samples(self.routing_weights)
+                 > self.max_num
+             ):
+                 self.routing_weights.append(
+                     routing_weights[
+                         : self.max_num - _number_of_samples(self.routing_weights)
+                     ]
+                 )
+             else:
+                 self.routing_weights.append(routing_weights)
+         else:
+             self.routing_weights.append(routing_weights)
+
+     def save_routing_weights(self):
+         routing_weights = torch.cat(self.routing_weights, dim=0)
+         if self.save_path is not None:
+             self.save_path.parent.mkdir(parents=True, exist_ok=True)
+             print(
+                 f"Saving routing weights to {self.save_path}. Size: {routing_weights.size()}"
+             )
+             torch.save(routing_weights, self.save_path)
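The factored-out `LayerWiseRoutingWeightSaver` now caps its buffer by the total number of samples (via `_number_of_samples`) rather than by the number of stored batches, which is how the two inlined copies removed above behaved. A self-contained sketch of attaching it to an arbitrary gating module; the stand-in gate, batch sizes, and output path are illustrative:

from pathlib import Path

import torch
from torch import nn

from fusion_bench.taskpool.clip_vision.utils.routing_analysis_utils import (
    LayerWiseRoutingWeightSaver,
)

# stand-in gate: any module whose forward output is the routing-weight tensor
gate = nn.Linear(16, 4)

saver = LayerWiseRoutingWeightSaver(Path("outputs/demo_gate_routing.pt"), max_num=100)
handle = gate.register_forward_hook(saver)

for _ in range(10):
    gate(torch.randn(32, 16))  # the hook keeps at most max_num samples in total

handle.remove()
saver.save_routing_weights()  # concatenates along dim 0 and writes the .pt file

weights = torch.load("outputs/demo_gate_routing.pt")
print(weights.shape)  # torch.Size([100, 4]) with the settings above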
fusion_bench/taskpool/gpt2_text_classification.py
@@ -139,11 +139,40 @@ class GPT2TextClassificationTaskPool(BaseTaskPool, LightningFabricMixin):
          return dataloader

      @override
-     def evaluate(self, model: GPT2Model):
+     def evaluate(self, model: GPT2Model, name: str = None):
+         """Evaluate the model on the test datasets.
+
+         Args:
+             model (GPT2Model): The model to evaluate.
+             name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.
+
+         Returns:
+             dict: A dictionary containing the evaluation results for each task.
+         """
          report = {}
+         if name is not None:
+             report["name"] = name
          for task_name in (pbar := tqdm(self._test_datasets, desc="Evaluating tasks")):
              pbar.set_description(f"Evaluating task {task_name}")
              dataloader = self.get_test_dataloader(task_name)
              result = self.evaluate_single_task(task_name, model, dataloader)
              report[task_name] = result
+
+         # calculate the average accuracy and loss
+         if "average" not in report:
+             report["average"] = {}
+         accuracies = [
+             value["accuracy"]
+             for key, value in report.items()
+             if isinstance(value, dict) and "accuracy" in value
+         ]
+         if len(accuracies) > 0:
+             average_accuracy = sum(accuracies) / len(accuracies)
+             report["average"]["accuracy"] = average_accuracy
+         losses = [value["loss"] for key, value in report.items() if "loss" in value]
+         if len(losses) > 0:
+             average_loss = sum(losses) / len(losses)
+             report["average"]["loss"] = average_loss
+
+         log.info(f"Evaluation Result: {report}")
          return report
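The extended `evaluate` report therefore holds one dict per task plus optional "name" and "average" entries, and the averaging pass only considers dict entries that carry an "accuracy" (or "loss") field. A toy illustration of the resulting structure; task names and numbers are made up, and the loss filter here adds an isinstance guard for clarity where the diff relies on `"loss" in value` alone:

report = {
    "name": "merged-gpt2",                     # only present when a model name is passed
    "cola": {"accuracy": 0.81, "loss": 0.52},  # one entry per evaluated task
    "sst2": {"accuracy": 0.90, "loss": 0.31},
}
report["average"] = {}
accuracies = [v["accuracy"] for v in report.values() if isinstance(v, dict) and "accuracy" in v]
if accuracies:
    report["average"]["accuracy"] = sum(accuracies) / len(accuracies)
losses = [v["loss"] for v in report.values() if isinstance(v, dict) and "loss" in v]
if losses:
    report["average"]["loss"] = sum(losses) / len(losses)
print(report["average"])  # {'accuracy': 0.855, 'loss': 0.415} up to float rounding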
fusion_bench/taskpool/lm_eval_harness/__init__.py
@@ -0,0 +1,3 @@
+ from .taskpool import LMEvalHarnessTaskPool
+
+ __all__ = ["LMEvalHarnessTaskPool"]
fusion_bench/taskpool/lm_eval_harness/taskpool.py
@@ -0,0 +1,87 @@
+ import logging
+ import os
+ from typing import List, Literal, Optional, Union, TYPE_CHECKING
+
+ import lightning.fabric
+ import lm_eval
+ import lm_eval.models
+ from lm_eval.__main__ import check_argument_types, cli_evaluate, setup_parser
+ from omegaconf import DictConfig, ListConfig
+
+ from fusion_bench import BaseTaskPool
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.utils.strenum import _version
+
+
+ log = logging.getLogger(__name__)
+
+
+ class LMEvalHarnessTaskPool(BaseTaskPool, LightningFabricMixin):
+     def __init__(
+         self,
+         tasks: Union[str, List[str]],
+         apply_chat_template: bool = False,
+         include_path: Optional[str] = None,
+         batch_size: int = 1,
+         metadata: Optional[DictConfig] = None,
+         verbosity: Optional[
+             Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
+         ] = None,
+         output_path: Optional[str] = None,
+         log_samples: bool = False,
+         _usage_: Optional[str] = None,
+         _version_: Optional[str] = None,
+         **kwargs,
+     ):
+         super().__init__(_usage_=_usage_, _version_=_version_)
+         self.tasks = tasks
+         self.include_path = include_path
+         self.batch_size = batch_size
+         self.metadata = metadata
+         self.apply_chat_template = apply_chat_template
+         self.verbosity = verbosity
+         self.kwargs = kwargs
+         self.output_path = output_path
+         self.log_samples = log_samples
+
+     def evaluate(self, model, *command_line_args, **kwargs):
+         command_line_args = []
+         if self.include_path is not None:
+             command_line_args.extend(["--include_path", self.include_path])
+         if isinstance(self.tasks, (list, ListConfig)):
+             command_line_args.extend(["--tasks", ",".join(self.tasks)])
+         elif isinstance(self.tasks, str):
+             command_line_args.extend(["--tasks", self.tasks])
+         if self.apply_chat_template:
+             command_line_args.extend(
+                 ["--apply_chat_template", str(self.apply_chat_template)]
+             )
+         if self.batch_size is not None:
+             command_line_args.extend(["--batch_size", str(self.batch_size)])
+         if self.verbosity is not None:
+             command_line_args.extend(["--verbosity", str(self.verbosity)])
+         if self.metadata is not None:
+             command_line_args.extend(["--metadata", str(self.metadata)])
+         if self.output_path is None:
+             command_line_args.extend(
+                 [
+                     "--output_path",
+                     os.path.join(self.log_dir, "lm_eval_results"),
+                 ]
+             )
+         else:
+             command_line_args.extend(["--output_path", self.output_path])
+         if self.log_samples:
+             command_line_args.extend(["--log_samples"])
+         for key, value in kwargs.items():
+             command_line_args.extend([f"--{key}", str(value)])
+
+         parser = setup_parser()
+         check_argument_types(parser)
+         args = parser.parse_args(args=command_line_args)
+         log.info("LM-Eval Harness arguments:\n%s", args)
+
+         if not lightning.fabric.is_wrapped(model):
+             model = self.fabric.setup(model)
+         args.model = lm_eval.models.huggingface.HFLM(pretrained=model)
+         cli_evaluate(args)
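`LMEvalHarnessTaskPool.evaluate` rebuilds an lm-eval command line from its constructor options, wraps the (Fabric-prepared) model in `lm_eval.models.huggingface.HFLM`, and hands off to `cli_evaluate`, so results land under `output_path` (or `<log_dir>/lm_eval_results` by default) rather than being returned. A hedged sketch of driving it directly from Python, assuming the LightningFabricMixin can construct a default Fabric outside the usual Hydra-driven entry point; the checkpoint and task list are illustrative, and in the packaged configs this task pool is instantiated from `fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml`:

from transformers import AutoModelForCausalLM

from fusion_bench.taskpool.lm_eval_harness import LMEvalHarnessTaskPool

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

taskpool = LMEvalHarnessTaskPool(
    tasks=["hellaswag", "arc_easy"],        # joined into --tasks hellaswag,arc_easy
    batch_size=8,                           # forwarded as --batch_size 8
    output_path="outputs/lm_eval_results",  # otherwise defaults to <log_dir>/lm_eval_results
)
# wraps the model with Fabric (if needed) and HFLM, then runs lm-eval's cli_evaluate
taskpool.evaluate(model)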
fusion_bench/taskpool/openclip_vision/__init__.py
@@ -0,0 +1 @@
+ from .openclip_taskpool import OpenCLIPVisionModelTaskPool