fusion-bench 0.2.9__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (727) hide show
  1. fusion_bench/__init__.py +20 -0
  2. fusion_bench/__main__.py +4 -0
  3. fusion_bench/compat/__init__.py +0 -0
  4. fusion_bench/compat/method/__init__.py +109 -0
  5. fusion_bench/compat/method/base_algorithm.py +58 -0
  6. fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
  7. fusion_bench/compat/modelpool/__init__.py +116 -0
  8. fusion_bench/compat/modelpool/base_pool.py +328 -0
  9. fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
  10. fusion_bench/compat/taskpool/__init__.py +95 -0
  11. fusion_bench/compat/taskpool/base_pool.py +111 -0
  12. fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
  13. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
  14. fusion_bench/constants/__init__.py +2 -0
  15. fusion_bench/constants/paths.py +18 -0
  16. fusion_bench/dataset/__init__.py +29 -0
  17. fusion_bench/dataset/arc_agi/__init__.py +6 -0
  18. fusion_bench/dataset/arc_agi/arc.py +308 -0
  19. fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
  20. fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
  21. fusion_bench/dataset/arc_agi/messagers.py +1355 -0
  22. fusion_bench/dataset/arc_agi/np_cache.py +168 -0
  23. fusion_bench/dataset/arc_agi/preprocess.py +298 -0
  24. fusion_bench/dataset/arc_agi/representers.py +1019 -0
  25. fusion_bench/dataset/clip_dataset.py +71 -0
  26. fusion_bench/dataset/fer2013.py +12 -0
  27. fusion_bench/dataset/gpt2_glue.py +300 -0
  28. fusion_bench/dataset/gsm8k.py +60 -0
  29. fusion_bench/dataset/image_dataset.py +55 -0
  30. fusion_bench/dataset/imdb.py +11 -0
  31. fusion_bench/dataset/llama/__init__.py +1 -0
  32. fusion_bench/dataset/llama/alpaca.py +232 -0
  33. fusion_bench/dataset/llama/collate.py +120 -0
  34. fusion_bench/dataset/llama/metamathqa.py +50 -0
  35. fusion_bench/dataset/llama/openai.py +160 -0
  36. fusion_bench/dataset/llama/preference_700k.py +70 -0
  37. fusion_bench/dataset/llama/sharegpt.py +141 -0
  38. fusion_bench/dataset/llama/squad.py +125 -0
  39. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  40. fusion_bench/dataset/llama/ultrachat.py +58 -0
  41. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  42. fusion_bench/dataset/llama/wikitext.py +89 -0
  43. fusion_bench/dataset/nyuv2.py +119 -0
  44. fusion_bench/method/__init__.py +177 -0
  45. fusion_bench/method/ada_svd/__init__.py +2 -0
  46. fusion_bench/method/ada_svd/clip_vision.py +319 -0
  47. fusion_bench/method/adamerging/__init__.py +6 -0
  48. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
  49. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
  50. fusion_bench/method/adamerging/entropy_loss.py +25 -0
  51. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
  52. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
  53. fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
  54. fusion_bench/method/adamerging/llama_adamerging.py +335 -0
  55. fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
  56. fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
  57. fusion_bench/method/adamerging/utils.py +15 -0
  58. fusion_bench/method/analysis/__init__.py +2 -0
  59. fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
  60. fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
  61. fusion_bench/method/base_algorithm.py +44 -0
  62. fusion_bench/method/classification/__init__.py +3 -0
  63. fusion_bench/method/classification/clip_finetune.py +444 -0
  64. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  65. fusion_bench/method/concrete_subspace/__init__.py +6 -0
  66. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
  67. fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
  68. fusion_bench/method/dare/__init__.py +4 -0
  69. fusion_bench/method/dare/simple_average.py +31 -0
  70. fusion_bench/method/dare/task_arithmetic.py +82 -0
  71. fusion_bench/method/dare/ties_merging.py +100 -0
  72. fusion_bench/method/dare/utils.py +87 -0
  73. fusion_bench/method/dawe/__init__.py +2 -0
  74. fusion_bench/method/dawe/dawe_for_clip.py +274 -0
  75. fusion_bench/method/dawe/warppers/__init__.py +13 -0
  76. fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
  77. fusion_bench/method/depth_upscaling/__init__.py +3 -0
  78. fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
  79. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
  80. fusion_bench/method/dummy.py +35 -0
  81. fusion_bench/method/ensemble.py +98 -0
  82. fusion_bench/method/fisher_merging/__init__.py +4 -0
  83. fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
  84. fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
  85. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
  86. fusion_bench/method/linear/__init__.py +6 -0
  87. fusion_bench/method/linear/expo.py +118 -0
  88. fusion_bench/method/linear/linear_interpolation.py +60 -0
  89. fusion_bench/method/linear/llama_expo.py +229 -0
  90. fusion_bench/method/linear/simple_average_for_llama.py +54 -0
  91. fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
  92. fusion_bench/method/lm_finetune/__init__.py +3 -0
  93. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  94. fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
  95. fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
  96. fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
  97. fusion_bench/method/mixture_of_experts/__init__.py +7 -0
  98. fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
  99. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
  100. fusion_bench/method/model_recombination.py +121 -0
  101. fusion_bench/method/opcm/__init__.py +4 -0
  102. fusion_bench/method/opcm/opcm.py +277 -0
  103. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  104. fusion_bench/method/opcm/ties_merging.py +156 -0
  105. fusion_bench/method/opcm/utils.py +73 -0
  106. fusion_bench/method/opcm/weight_average.py +120 -0
  107. fusion_bench/method/pruning/__init__.py +5 -0
  108. fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
  109. fusion_bench/method/pruning/llama_random_prune.py +143 -0
  110. fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
  111. fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
  112. fusion_bench/method/pruning/prune_utils.py +165 -0
  113. fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
  114. fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
  115. fusion_bench/method/pruning/wanda_utils/data.py +135 -0
  116. fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
  117. fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
  118. fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
  119. fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
  120. fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
  121. fusion_bench/method/pwe_moe/__init__.py +5 -0
  122. fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
  123. fusion_bench/method/pwe_moe/module.py +316 -0
  124. fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
  125. fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
  126. fusion_bench/method/pwe_moe/utils.py +43 -0
  127. fusion_bench/method/rankone_moe/__init__.py +3 -0
  128. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  129. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  130. fusion_bench/method/regmean/__init__.py +4 -0
  131. fusion_bench/method/regmean/clip_regmean.py +131 -0
  132. fusion_bench/method/regmean/gpt2_regmean.py +147 -0
  133. fusion_bench/method/regmean/regmean.py +375 -0
  134. fusion_bench/method/simple_average.py +112 -0
  135. fusion_bench/method/slerp/__init__.py +2 -0
  136. fusion_bench/method/slerp/slerp.py +101 -0
  137. fusion_bench/method/slerp/slerp_utils.py +107 -0
  138. fusion_bench/method/smile_upscaling/__init__.py +3 -0
  139. fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
  140. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
  141. fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
  142. fusion_bench/method/sparse_we_moe/__init__.py +2 -0
  143. fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
  144. fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
  145. fusion_bench/method/sparselo/__init__.py +2 -0
  146. fusion_bench/method/sparselo/sparselo.py +955 -0
  147. fusion_bench/method/surgery/__init__.py +1 -0
  148. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  149. fusion_bench/method/tall_mask/__init__.py +0 -0
  150. fusion_bench/method/tall_mask/utils.py +234 -0
  151. fusion_bench/method/task_arithmetic/__init__.py +2 -0
  152. fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
  153. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  154. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  155. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  156. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  157. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
  158. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  159. fusion_bench/method/ties_merging/__init__.py +2 -0
  160. fusion_bench/method/ties_merging/ties_merging.py +117 -0
  161. fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
  162. fusion_bench/method/trust_region/__init__.py +2 -0
  163. fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
  164. fusion_bench/method/trust_region/utils.py +58 -0
  165. fusion_bench/method/we_moe/__init__.py +2 -0
  166. fusion_bench/method/we_moe/clip_we_moe.py +161 -0
  167. fusion_bench/method/we_moe/we_moe.py +247 -0
  168. fusion_bench/method/weighted_average/__init__.py +3 -0
  169. fusion_bench/method/weighted_average/llama.py +113 -0
  170. fusion_bench/method/weighted_average/weighted_average.py +102 -0
  171. fusion_bench/metrics/__init__.py +0 -0
  172. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  173. fusion_bench/metrics/nyuv2/__init__.py +11 -0
  174. fusion_bench/metrics/nyuv2/depth.py +45 -0
  175. fusion_bench/metrics/nyuv2/loss.py +31 -0
  176. fusion_bench/metrics/nyuv2/noise.py +16 -0
  177. fusion_bench/metrics/nyuv2/normal.py +48 -0
  178. fusion_bench/metrics/nyuv2/segmentation.py +43 -0
  179. fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
  180. fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
  181. fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
  182. fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
  183. fusion_bench/mixins/__init__.py +28 -0
  184. fusion_bench/mixins/clip_classification.py +252 -0
  185. fusion_bench/mixins/fabric_training.py +320 -0
  186. fusion_bench/mixins/lightning_fabric.py +174 -0
  187. fusion_bench/mixins/optim/__init__.py +0 -0
  188. fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
  189. fusion_bench/mixins/rich_live.py +21 -0
  190. fusion_bench/mixins/serialization.py +132 -0
  191. fusion_bench/mixins/simple_profiler.py +79 -0
  192. fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
  193. fusion_bench/modelpool/__init__.py +42 -0
  194. fusion_bench/modelpool/base_pool.py +268 -0
  195. fusion_bench/modelpool/causal_lm/__init__.py +2 -0
  196. fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
  197. fusion_bench/modelpool/clip_vision/__init__.py +1 -0
  198. fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
  199. fusion_bench/modelpool/huggingface_automodel.py +20 -0
  200. fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
  201. fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
  202. fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
  203. fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
  204. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  205. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  206. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  207. fusion_bench/models/__init__.py +3 -0
  208. fusion_bench/models/chat_templates/__init__.py +1 -0
  209. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  210. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  211. fusion_bench/models/hf_clip.py +199 -0
  212. fusion_bench/models/linearized/__init__.py +0 -0
  213. fusion_bench/models/linearized/linearized_model_utils.py +91 -0
  214. fusion_bench/models/linearized/vision_model.py +122 -0
  215. fusion_bench/models/llama/__init__.py +16 -0
  216. fusion_bench/models/llama/model_utils/__init__.py +0 -0
  217. fusion_bench/models/llama/model_utils/embedding.py +87 -0
  218. fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
  219. fusion_bench/models/llama/model_utils/misc.py +112 -0
  220. fusion_bench/models/llama/model_utils/mod.py +52 -0
  221. fusion_bench/models/llama/model_utils/visual.py +241 -0
  222. fusion_bench/models/llama/patcher.py +78 -0
  223. fusion_bench/models/llama/tokenizer_loader.py +153 -0
  224. fusion_bench/models/masks/__init__.py +2 -0
  225. fusion_bench/models/masks/mask_model.py +160 -0
  226. fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
  227. fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
  228. fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
  229. fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
  230. fusion_bench/models/modeling_losparse_llama/register.py +8 -0
  231. fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
  232. fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
  233. fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
  234. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
  235. fusion_bench/models/modeling_smile_mistral/register.py +8 -0
  236. fusion_bench/models/nyuv2/__init__.py +0 -0
  237. fusion_bench/models/nyuv2/aspp.py +82 -0
  238. fusion_bench/models/nyuv2/lightning_module.py +176 -0
  239. fusion_bench/models/nyuv2/resnet.py +405 -0
  240. fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
  241. fusion_bench/models/parameter_dict.py +75 -0
  242. fusion_bench/models/rankone_moe.py +410 -0
  243. fusion_bench/models/separate_io.py +105 -0
  244. fusion_bench/models/smile_moe/__init__.py +0 -0
  245. fusion_bench/models/smile_moe/linear.py +256 -0
  246. fusion_bench/models/sparse_we_moe.py +459 -0
  247. fusion_bench/models/surgery/__init__.py +1 -0
  248. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  249. fusion_bench/models/utils.py +80 -0
  250. fusion_bench/models/we_moe.py +247 -0
  251. fusion_bench/models/wrappers/__init__.py +0 -0
  252. fusion_bench/models/wrappers/ensemble.py +183 -0
  253. fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
  254. fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
  255. fusion_bench/optim/__init__.py +2 -0
  256. fusion_bench/optim/exception.py +47 -0
  257. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  258. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  259. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  260. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  261. fusion_bench/optim/mezo.py +118 -0
  262. fusion_bench/programs/__init__.py +20 -0
  263. fusion_bench/programs/base_program.py +9 -0
  264. fusion_bench/programs/fabric_fusion_program.py +299 -0
  265. fusion_bench/scripts/__init__.py +0 -0
  266. fusion_bench/scripts/cli.py +43 -0
  267. fusion_bench/scripts/clip/__init__.py +0 -0
  268. fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
  269. fusion_bench/scripts/imgui.py +218 -0
  270. fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
  271. fusion_bench/scripts/webui.py +405 -0
  272. fusion_bench/taskpool/__init__.py +39 -0
  273. fusion_bench/taskpool/base_pool.py +35 -0
  274. fusion_bench/taskpool/clip_vision/__init__.py +4 -0
  275. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  276. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
  277. fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
  278. fusion_bench/taskpool/dummy.py +58 -0
  279. fusion_bench/taskpool/gpt2_text_classification.py +149 -0
  280. fusion_bench/taskpool/llama/__init__.py +1 -0
  281. fusion_bench/taskpool/llama/reward_model.py +157 -0
  282. fusion_bench/taskpool/llama/test_generation.py +185 -0
  283. fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
  284. fusion_bench/tasks/__init__.py +2 -0
  285. fusion_bench/tasks/base_task.py +18 -0
  286. fusion_bench/tasks/classification.py +75 -0
  287. fusion_bench/tasks/clip_classification/__init__.py +183 -0
  288. fusion_bench/tasks/clip_classification/cifar10.py +33 -0
  289. fusion_bench/tasks/clip_classification/cifar100.py +146 -0
  290. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
  291. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  292. fusion_bench/tasks/clip_classification/dtd.py +60 -0
  293. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  294. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  295. fusion_bench/tasks/clip_classification/eurosat.py +18 -0
  296. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  297. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  298. fusion_bench/tasks/clip_classification/flower102.py +106 -0
  299. fusion_bench/tasks/clip_classification/food101.py +105 -0
  300. fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
  301. fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
  302. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  303. fusion_bench/tasks/clip_classification/mnist.py +5 -0
  304. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  305. fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
  306. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  307. fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
  308. fusion_bench/tasks/clip_classification/resisc45.py +68 -0
  309. fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
  310. fusion_bench/tasks/clip_classification/stl10.py +17 -0
  311. fusion_bench/tasks/clip_classification/sun397.py +404 -0
  312. fusion_bench/tasks/clip_classification/svhn.py +5 -0
  313. fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
  314. fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
  315. fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
  316. fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
  317. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
  318. fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
  319. fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
  320. fusion_bench/utils/__init__.py +14 -0
  321. fusion_bench/utils/auto.py +31 -0
  322. fusion_bench/utils/cache_utils.py +58 -0
  323. fusion_bench/utils/data.py +165 -0
  324. fusion_bench/utils/devices.py +231 -0
  325. fusion_bench/utils/dict.py +43 -0
  326. fusion_bench/utils/dtype.py +146 -0
  327. fusion_bench/utils/expr.py +90 -0
  328. fusion_bench/utils/fabric.py +17 -0
  329. fusion_bench/utils/functools.py +37 -0
  330. fusion_bench/utils/hydra_utils.py +28 -0
  331. fusion_bench/utils/instantiate.py +450 -0
  332. fusion_bench/utils/json.py +93 -0
  333. fusion_bench/utils/lazy_imports.py +74 -0
  334. fusion_bench/utils/misc.py +18 -0
  335. fusion_bench/utils/packages.py +84 -0
  336. fusion_bench/utils/parameters.py +323 -0
  337. fusion_bench/utils/path.py +22 -0
  338. fusion_bench/utils/plot/__init__.py +0 -0
  339. fusion_bench/utils/plot/color_data.py +1726 -0
  340. fusion_bench/utils/plot/token.py +52 -0
  341. fusion_bench/utils/plot/token_notebook.py +127 -0
  342. fusion_bench/utils/pylogger.py +55 -0
  343. fusion_bench/utils/rich_utils.py +201 -0
  344. fusion_bench/utils/set.py +8 -0
  345. fusion_bench/utils/state_dict_arithmetic.py +297 -0
  346. fusion_bench/utils/strenum/__init__.py +326 -0
  347. fusion_bench/utils/strenum/_name_mangler.py +127 -0
  348. fusion_bench/utils/strenum/_version.py +556 -0
  349. fusion_bench/utils/tensorboard.py +51 -0
  350. fusion_bench/utils/timer.py +49 -0
  351. fusion_bench/utils/type.py +34 -0
  352. fusion_bench-0.2.9.dist-info/LICENSE +21 -0
  353. fusion_bench-0.2.9.dist-info/METADATA +258 -0
  354. fusion_bench-0.2.9.dist-info/RECORD +727 -0
  355. fusion_bench-0.2.9.dist-info/WHEEL +5 -0
  356. fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
  357. fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
  358. fusion_bench_config/README.md +12 -0
  359. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
  360. fusion_bench_config/dataset/image_classification/README.md +6 -0
  361. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  362. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  363. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
  364. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
  365. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  366. fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
  367. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  368. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  369. fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
  370. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  371. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  372. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  373. fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
  374. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  375. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  376. fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
  377. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  378. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  379. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  380. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  381. fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
  382. fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
  383. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  384. fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
  385. fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
  386. fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
  387. fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
  388. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  389. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  390. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
  391. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
  392. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  393. fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
  394. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  395. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  396. fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
  397. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  398. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  399. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  400. fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
  401. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  402. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  403. fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
  404. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  405. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  406. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  407. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  408. fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
  409. fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
  410. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  411. fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
  412. fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
  413. fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
  414. fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
  415. fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
  416. fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
  417. fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
  418. fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
  419. fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
  420. fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
  421. fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
  422. fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
  423. fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
  424. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  425. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  426. fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
  427. fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
  428. fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
  429. fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
  430. fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
  431. fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
  432. fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
  433. fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
  434. fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
  435. fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
  436. fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
  437. fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
  438. fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
  439. fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
  440. fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
  441. fusion_bench_config/fabric/auto.yaml +16 -0
  442. fusion_bench_config/fabric/llama_ddp.yaml +18 -0
  443. fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
  444. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  445. fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
  446. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
  447. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  448. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  449. fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
  450. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  451. fusion_bench_config/fabric_model_fusion.yaml +20 -0
  452. fusion_bench_config/hydra/default.yaml +8 -0
  453. fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
  454. fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
  455. fusion_bench_config/llama_full_finetune.yaml +19 -0
  456. fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
  457. fusion_bench_config/llama_model_fusion.yaml +17 -0
  458. fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
  459. fusion_bench_config/method/adamerging/clip.yaml +23 -0
  460. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
  461. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
  462. fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
  463. fusion_bench_config/method/adamerging.yaml +23 -0
  464. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
  465. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
  466. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  467. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  468. fusion_bench_config/method/clip_finetune.yaml +26 -0
  469. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
  470. fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
  471. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
  472. fusion_bench_config/method/dare/simple_average.yaml +5 -0
  473. fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
  474. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  475. fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
  476. fusion_bench_config/method/depth_upscaling.yaml +5 -0
  477. fusion_bench_config/method/dummy.yaml +1 -0
  478. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
  479. fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
  480. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
  481. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
  482. fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
  483. fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
  484. fusion_bench_config/method/linear/expo.yaml +8 -0
  485. fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
  486. fusion_bench_config/method/linear/llama_expo.yaml +19 -0
  487. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
  488. fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
  489. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
  490. fusion_bench_config/method/linear/weighted_average.yaml +6 -0
  491. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
  492. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  493. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
  494. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
  495. fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
  496. fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
  497. fusion_bench_config/method/model_recombination.yaml +4 -0
  498. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  499. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  500. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  501. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  502. fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
  503. fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
  504. fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
  505. fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
  506. fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
  507. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  508. fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
  509. fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
  510. fusion_bench_config/method/regmean/regmean.yaml +4 -0
  511. fusion_bench_config/method/simple_average.yaml +1 -0
  512. fusion_bench_config/method/slerp/slerp.yaml +6 -0
  513. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
  514. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
  515. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
  516. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
  517. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
  518. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
  519. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  520. fusion_bench_config/method/task_arithmetic.yaml +2 -0
  521. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  522. fusion_bench_config/method/ties_merging.yaml +8 -0
  523. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
  524. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
  525. fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
  526. fusion_bench_config/model/clip-vit/README.md +38 -0
  527. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
  528. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  529. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  530. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  531. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  532. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
  533. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
  534. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  535. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
  536. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  537. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  538. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  539. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
  540. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  541. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
  542. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  543. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  544. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  545. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  546. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
  547. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
  548. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  549. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
  550. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
  551. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
  552. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  553. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  554. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  555. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  556. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
  557. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
  558. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  559. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
  560. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  561. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  562. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  563. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
  564. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  565. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
  566. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  567. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  568. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  569. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  570. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
  571. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
  572. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  573. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
  574. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
  575. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
  576. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  577. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  578. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  579. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  580. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
  581. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
  582. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  583. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
  584. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  585. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  586. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  587. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
  588. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  589. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
  590. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  591. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  592. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  593. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  594. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
  595. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
  596. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  597. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
  598. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
  599. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  600. fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
  601. fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
  602. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
  603. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
  604. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
  605. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
  606. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
  607. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
  608. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
  609. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
  610. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
  611. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
  612. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
  613. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
  614. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
  615. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
  616. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
  617. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
  618. fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
  619. fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
  620. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
  621. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
  622. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
  623. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
  624. fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
  625. fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
  626. fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
  627. fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
  628. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
  629. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
  630. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
  631. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  632. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  633. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  634. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  635. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  636. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
  637. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
  638. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
  639. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
  640. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
  641. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  642. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  643. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  644. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  645. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
  646. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
  647. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
  648. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
  649. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
  650. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
  651. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
  652. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  653. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
  654. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  655. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
  656. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
  657. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  658. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  659. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  660. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  661. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
  662. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  663. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  664. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
  665. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  666. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  667. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
  668. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
  669. fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
  670. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
  671. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
  672. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
  673. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
  674. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
  675. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  676. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  677. fusion_bench_config/modelpool/automodelpool.yaml +12 -0
  678. fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
  679. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
  680. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
  681. fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
  682. fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
  683. fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
  684. fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
  685. fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
  686. fusion_bench_config/nyuv2_config.yaml +17 -0
  687. fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
  688. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
  689. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  690. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
  691. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
  692. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
  693. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
  694. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
  695. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  696. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  697. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  698. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  699. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  700. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  701. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  702. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  703. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  704. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  705. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  706. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  707. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  708. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  709. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  710. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  711. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  712. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  713. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  714. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  715. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  716. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  717. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  718. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  719. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
  720. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
  721. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  722. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
  723. fusion_bench_config/taskpool/dummy.yaml +2 -0
  724. fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
  725. fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
  726. fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
  727. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
@@ -0,0 +1,20 @@
1
+ # flake8: noqa: F401
2
+ from . import (
3
+ constants,
4
+ dataset,
5
+ method,
6
+ metrics,
7
+ mixins,
8
+ modelpool,
9
+ models,
10
+ optim,
11
+ programs,
12
+ taskpool,
13
+ tasks,
14
+ utils,
15
+ )
16
+ from .method import BaseAlgorithm, BaseModelFusionAlgorithm
17
+ from .modelpool import BaseModelPool
18
+ from .models import separate_io
19
+ from .taskpool import BaseTaskPool
20
+ from .utils import parse_dtype, print_parameters, timeit_context
@@ -0,0 +1,4 @@
1
+ from fusion_bench.scripts.cli import main
2
+
3
+ if __name__ == "__main__":
4
+ main()
File without changes
@@ -0,0 +1,109 @@
1
+ import warnings
2
+
3
+ from omegaconf import DictConfig
4
+
5
+ from .base_algorithm import ModelFusionAlgorithm
6
+
7
+
8
class AlgorithmFactory:
    """
    Factory class to create and manage different model fusion algorithms.

    This class provides methods to create algorithms based on a given configuration,
    register new algorithms, and list available algorithms.

    The factory is a deprecated compatibility layer: `create_algorithm` emits a
    `DeprecationWarning` on every call and points users to
    `fusion_bench.method.BaseModelFusionAlgorithm`.
    """

    # Registry: algorithm name -> algorithm class, or a dotted import path (str).
    # A path starting with "." is resolved relative to `fusion_bench.method`.
    # NOTE: the attribute name is misspelled; it is kept as-is so that any code
    # already reading `AlgorithmFactory._aglorithms` keeps working.
    _aglorithms = {
        # single task learning (fine-tuning)
        "clip_finetune": ".classification.clip_finetune.ImageClassificationFineTuningForCLIP",
        # analysis
        # model merging methods
        "clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
        "clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
        "singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
        "clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
        # plug-and-play model merging methods
        "clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
        "clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",
        "clip_concrete_layer_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteLayerWiseAdaMergingForCLIP",
        # model mixing methods
        "clip_weight_ensembling_moe": ".we_moe.clip_we_moe.CLIPWeightEnsemblingMoEAlgorithm",
        "sparse_clip_weight_ensembling_moe": "fusion_bench.method.SparseCLIPWeightEnsemblingMoEAlgorithm",
        "smile_mistral_upscaling": ".smile_upscaling.smile_mistral_upscaling.SmileMistralUpscalingAlgorithm",
        "rankone_moe": ".rankone_moe.clip_rankone_moe.CLIPRankOneMoEAlgorithm",
    }

    @staticmethod
    def create_algorithm(method_config: "DictConfig") -> "ModelFusionAlgorithm":
        """
        Create an instance of a model fusion algorithm based on the provided configuration.

        Args:
            method_config (DictConfig): The configuration for the algorithm. Must contain
                a 'name' entry that selects the algorithm from the registry. The whole
                configuration is passed to the algorithm's constructor.

        Returns:
            ModelFusionAlgorithm: An instance of the specified algorithm.

        Raises:
            ValueError: If 'name' is missing from the configuration or does not match
                any registered algorithm name.
        """
        warnings.warn(
            "AlgorithmFactory.create_algorithm() is deprecated and will be removed in future versions. "
            "Please implement new model fusion algorithm using `fusion_bench.method.BaseModelFusionAlgorithm` instead.",
            DeprecationWarning,
            # attribute the warning to the caller rather than to this factory
            stacklevel=2,
        )

        # `.get` keeps a missing name from surfacing as an attribute error,
        # so the documented ValueError is what callers actually see.
        algorithm_name = method_config.get("name")
        if algorithm_name is None:
            raise ValueError("Algorithm name not specified")

        registry = AlgorithmFactory._aglorithms
        if algorithm_name not in registry:
            raise ValueError(
                f"Unknown algorithm: {algorithm_name}, available algorithms: {list(registry)}. "
                "You can register a new algorithm using `AlgorithmFactory.register_algorithm()` method."
            )

        algorithm_cls = registry[algorithm_name]
        if isinstance(algorithm_cls, str):
            # Imported lazily: only string entries need to be resolved.
            from fusion_bench.utils import import_object

            if algorithm_cls.startswith("."):
                algorithm_cls = f"fusion_bench.method.{algorithm_cls[1:]}"
            algorithm_cls = import_object(algorithm_cls)
        return algorithm_cls(method_config)

    @staticmethod
    def register_algorithm(name: str, algorithm_cls):
        """
        Register a new algorithm with the factory.

        An existing entry with the same name is overwritten.

        Args:
            name (str): The name of the algorithm.
            algorithm_cls: The class of the algorithm to register (a dotted import
                path string is also accepted by `create_algorithm`).
        """
        AlgorithmFactory._aglorithms[name] = algorithm_cls

    @classmethod
    def available_algorithms(cls):
        """
        Get a list of available algorithms.

        Returns:
            list: A list of available algorithm names.
        """
        return list(cls._aglorithms)
91
+
92
+
93
def load_algorithm_from_config(method_config: DictConfig):
    """
    Build a model fusion algorithm from its configuration.

    This is a module-level convenience wrapper: the configuration is forwarded
    unchanged to `AlgorithmFactory.create_algorithm`, and the object it builds
    is returned as-is.

    Args:
        method_config (DictConfig): The configuration for the algorithm. Its 'name'
            entry selects which registered algorithm is instantiated.

    Returns:
        An instance of the specified algorithm.

    Raises:
        ValueError: Propagated from `AlgorithmFactory.create_algorithm` when the
            name does not match any registered algorithm.
    """
    factory = AlgorithmFactory
    return factory.create_algorithm(method_config)
@@ -0,0 +1,58 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import TYPE_CHECKING, Optional
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ if TYPE_CHECKING:
7
+ from fusion_bench.programs.base_program import BaseHydraProgram
8
+
9
+ __all__ = ["ModelFusionAlgorithm"]
10
+
11
+
12
class ModelFusionAlgorithm(ABC):
    """
    Abstract base class for model fusion algorithms (for v0.1.x versions, deprecated).
    For implementing new method, use `fusion_bench.method.BaseModelFusionAlgorithm` instead.

    This class provides a template for implementing model fusion algorithms.
    Subclasses must implement the `run` method to define the fusion logic.

    Attributes:
        config (DictConfig): Configuration for the algorithm.
    """

    _program: "BaseHydraProgram" = None
    """A reference to the program that is running the algorithm."""

    def __init__(self, algorithm_config: Optional["DictConfig"] = None):
        """
        Initialize the model fusion algorithm with the given configuration.

        Args:
            algorithm_config (Optional[DictConfig]): Configuration for the algorithm.
                Defaults to an empty configuration if not provided.
                Get access to the configuration using `self.config`.
        """
        if algorithm_config is None:
            # Give every instance its own empty config instead of sharing one.
            algorithm_config = DictConfig({})
        self.config = algorithm_config

    @abstractmethod
    def run(self, modelpool):
        """
        Fuse the models in the given model pool.

        This method must be implemented by subclasses to define the fusion logic.
        `modelpool` is an object responsible for managing the models to be fused and optional datasets to be used for fusion.

        Args:
            modelpool: The pool of models to fuse.

        Returns:
            The fused model.

        Examples:
            >>> algorithm = SimpleAverageAlgorithm()
            >>> modelpool = ModelPool()
            >>> merged_model = algorithm.run(modelpool)
        """
@@ -0,0 +1,34 @@
1
+ import logging
2
+
3
+ from omegaconf import DictConfig
4
+ from transformers import AutoModelForSeq2SeqLM
5
+
6
+ from fusion_bench.utils import timeit_context
7
+
8
+ from .base_pool import ModelPool
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
class AutoModelForSeq2SeqLMPool(ModelPool):
    """Model pool whose models are loaded with `AutoModelForSeq2SeqLM.from_pretrained`."""

    def load_model(self, model_config: str | DictConfig):
        """
        Load a Seq2Seq language model described by `model_config`.

        Two entries of the configuration are read:

        - name: used in the message passed to `timeit_context`.
        - path: handed to `AutoModelForSeq2SeqLM.from_pretrained`.

        Args:
            model_config (str | DictConfig): Either the model configuration itself, or a
                string that is first turned into a configuration via `self.get_model_config`.

        Returns:
            model: The loaded Seq2Seq language model.
        """
        # A bare string is treated as a lookup key for the pool's own configuration.
        resolved_config = (
            self.get_model_config(model_config)
            if isinstance(model_config, str)
            else model_config
        )
        model_path = resolved_config["path"]
        with timeit_context(f"Loading model {resolved_config['name']}"):
            model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        return model
@@ -0,0 +1,116 @@
1
+ # flake8: noqa F401
2
+ import warnings
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ from fusion_bench.modelpool.huggingface_gpt2_classification import (
7
+ GPT2ForSequenceClassificationPool,
8
+ )
9
+ from fusion_bench.modelpool.PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
10
+
11
+ from .AutoModelForSeq2SeqLM import AutoModelForSeq2SeqLMPool
12
+ from .base_pool import DictModelPool, ListModelPool, ModelPool, to_modelpool
13
+ from .huggingface_clip_vision import HuggingFaceClipVisionPool
14
+
15
+
16
class ModelPoolFactory:
    """
    Factory class to create and manage different model pools.

    This class provides methods to create model pools based on a given configuration,
    register new model pools, and list available model pools.

    The factory is a deprecated compatibility layer: `create_modelpool` emits a
    `DeprecationWarning` on every call and points users to
    `fusion_bench.modelpool.BaseModelPool`.
    """

    # Registry: model pool type -> model pool class, or a dotted import path (str).
    # A path starting with "." is resolved relative to `fusion_bench.compat.modelpool`.
    # NOTE(review): the "." entries below name modules (`huggingface_automodel`,
    # `huggingface_llm`) that do not appear under `fusion_bench/compat/modelpool/`
    # in this package's file listing - confirm these entries can still be resolved.
    _modelpool = {
        "NYUv2ModelPool": "fusion_bench.modelpool.nyuv2_modelpool.NYUv2ModelPool",
        "huggingface_clip_vision": HuggingFaceClipVisionPool,
        "HF_GPT2ForSequenceClassification": GPT2ForSequenceClassificationPool,
        "AutoModelPool": ".huggingface_automodel.AutoModelPool",
        # CausalLM
        "AutoModelForCausalLMPool": ".huggingface_llm.AutoModelForCausalLMPool",
        "LLamaForCausalLMPool": ".huggingface_llm.LLamaForCausalLMPool",
        "MistralForCausalLMPool": ".huggingface_llm.MistralForCausalLMPool",
        # Seq2SeqLM
        "AutoModelForSeq2SeqLMPool": AutoModelForSeq2SeqLMPool,
        "PeftModelForSeq2SeqLMPool": PeftModelForSeq2SeqLMPool,
    }

    @staticmethod
    def create_modelpool(modelpool_config: DictConfig) -> ModelPool:
        """
        Create an instance of a model pool based on the provided configuration.
        This is for v0.1.x versions, deprecated.
        For implementing new model pool, use `fusion_bench.modelpool.BaseModelPool` instead.

        Args:
            modelpool_config (DictConfig): The configuration for the model pool.
                Must contain a 'type' entry that selects the model pool from the
                registry. The whole configuration is passed to the pool's constructor.

        Returns:
            ModelPool: An instance of the specified model pool.

        Raises:
            ValueError: If 'type' is missing from the configuration or does not match
                any registered model pool type.
        """
        warnings.warn(
            "ModelPoolFactory.create_modelpool() is deprecated and will be removed in future versions. "
            "Please implement new model pool using `fusion_bench.modelpool.BaseModelPool` instead.",
            DeprecationWarning,
            # attribute the warning to the caller rather than to this factory
            stacklevel=2,
        )

        modelpool_type = modelpool_config.get("type")
        if modelpool_type is None:
            raise ValueError("Model pool type not specified")

        registry = ModelPoolFactory._modelpool
        if modelpool_type not in registry:
            raise ValueError(
                f"Unknown model pool: {modelpool_type}, available model pools: {list(registry)}. "
                "You can register a new model pool using `ModelPoolFactory.register_modelpool()` method."
            )

        modelpool_cls = registry[modelpool_type]
        if isinstance(modelpool_cls, str):
            # Imported lazily: only string entries need to be resolved.
            from fusion_bench.utils import import_object

            if modelpool_cls.startswith("."):
                modelpool_cls = f"fusion_bench.compat.modelpool.{modelpool_cls[1:]}"
            modelpool_cls = import_object(modelpool_cls)
        return modelpool_cls(modelpool_config)

    @staticmethod
    def register_modelpool(name: str, modelpool_cls):
        """
        Register a new model pool with the factory.

        An existing entry with the same name is overwritten.

        Args:
            name (str): The name of the model pool.
            modelpool_cls: The class of the model pool to register (a dotted import
                path string is also accepted by `create_modelpool`).
        """
        ModelPoolFactory._modelpool[name] = modelpool_cls

    @classmethod
    def available_modelpools(cls):
        """
        Get a list of available model pools.

        Returns:
            list: A list of available model pool names.
        """
        return list(cls._modelpool)
98
+
99
+
100
def load_modelpool_from_config(modelpool_config: DictConfig):
    """
    Build a model pool from its configuration.

    This is a module-level convenience wrapper: the configuration is forwarded
    unchanged to `ModelPoolFactory.create_modelpool`, and the object it builds
    is returned as-is.

    Args:
        modelpool_config (DictConfig): The configuration for the model pool. Its 'type'
            entry selects which registered model pool is instantiated.

    Returns:
        An instance of the specified model pool.

    Raises:
        ValueError: Propagated from `ModelPoolFactory.create_modelpool` when 'type' is
            missing from the configuration or does not match any registered model pool.
    """
    factory = ModelPoolFactory
    return factory.create_modelpool(modelpool_config)
@@ -0,0 +1,328 @@
1
+ import logging
2
+ from abc import ABC
3
+ from copy import deepcopy
4
+ from typing import Dict, List, Optional, Union
5
+
6
+ import torch
7
+ from omegaconf import DictConfig
8
+ from torch import nn
9
+
10
+ from fusion_bench.modelpool import BaseModelPool
11
+ from fusion_bench.utils import timeit_context
12
+
13
# Public names exported by `from <this module> import *`.
__all__ = ["ModelPool", "DictModelPool", "ListModelPool", "to_modelpool"]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
16
+
17
+
18
class ModelPool(ABC):
    """
    This is the base class for all model pools.

    It belongs to the v0.1.x API and is deprecated.
    Please implement new algorithms using `fusion_bench.modelpool.BaseModelPool`.

    Subclasses are expected to override `load_model` (and, where applicable,
    `get_train_dataset` / `get_test_dataset`), which raise
    `NotImplementedError` here.
    """

    # Every model name declared in the configuration, including special names
    # such as `_pretrained_`. Stays `None` when the pool is created without a
    # `models` entry in its configuration.
    _model_names = None

    def __init__(self, modelpool_config: Optional[DictConfig] = None):
        """
        Initialize the ModelPool with the given configuration.

        Args:
            modelpool_config (Optional[DictConfig]): The configuration for the model pool.

        Raises:
            AssertionError: If two entries of `models` share the same name.
        """
        super().__init__()
        self.config = modelpool_config

        # check for duplicate model names
        if self.config is not None and self.config.get("models", None) is not None:
            model_names = [model["name"] for model in self.config["models"]]
            assert len(model_names) == len(
                set(model_names)
            ), "Duplicate model names found in model pool"
            self._model_names = model_names

    def _model_configs(self):
        """Return the configured model entries, or an empty list when none are configured."""
        if self.config is None:
            return []
        model_configs = self.config.get("models", None)
        return [] if model_configs is None else model_configs

    def __len__(self):
        """
        Return the number of models in the model pool, excluding special models such as `_pretrained_`.

        Returns:
            int: The number of models in the model pool.
        """
        return len(self.model_names)

    @property
    def model_names(self) -> List[str]:
        """
        This property returns a list of model names from the configuration, excluding any names that start or end with an underscore.
        To obtain all model names, including those starting or ending with an underscore, use the `_model_names` attribute.

        Returns:
            list: A list of model names. Empty when no models are configured.
        """
        # A pool built without a `models` configuration has no names at all;
        # report that as an empty list instead of failing on `None`.
        if self._model_names is None:
            return []
        # `startswith` / `endswith` (rather than indexing) also tolerate an
        # empty name.
        return [
            name
            for name in self._model_names
            if not (name.startswith("_") or name.endswith("_"))
        ]

    @property
    def has_pretrained(self):
        """
        Check if the pretrained model is available in the model pool.

        Returns:
            bool: True if a model named `_pretrained_` is configured, False otherwise
                (including when the pool has no configuration).
        """
        return any(
            model_config.get("name", None) == "_pretrained_"
            for model_config in self._model_configs()
        )

    def get_model_config(self, model_name: str):
        """
        Retrieves the configuration for a specific model from the model pool.

        Args:
            model_name (str): The name of the model for which to retrieve the configuration.

        Returns:
            dict: The configuration dictionary for the specified model.

        Raises:
            ValueError: If the specified model is not found in the model pool.
        """
        for model in self._model_configs():
            if model["name"] == model_name:
                return model
        raise ValueError(f"Model {model_name} not found in model pool")

    def load_model(self, model_config: Union[str, DictConfig]) -> nn.Module:
        """
        Load the model from the model pool.

        Models are loaded lazily, so subclasses must implement this method to
        actually produce the model.

        Args:
            model_config (Union[str, DictConfig]): The model name or the configuration dictionary for the model to load.

        Returns:
            Any: The loaded model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def load_pretrained_or_first_model(self, *args, **kwargs):
        """
        Load the pretrained model if available, otherwise load the first model in the list.

        This method checks if a pretrained model is available. If it is, it loads the pretrained model.
        If not, it loads the first model from the list of model names.
        Extra positional and keyword arguments are forwarded to `load_model`.

        Returns:
            nn.Module: The loaded model.
        """
        if self.has_pretrained:
            model = self.load_model("_pretrained_", *args, **kwargs)
        else:
            model = self.load_model(self.model_names[0], *args, **kwargs)
        return model

    def save_model(self, model: nn.Module, path: str):
        """
        Save the state dictionary of the model to the specified path.

        Args:
            model (nn.Module): The model whose state dictionary is to be saved.
            path (str): The path where the state dictionary will be saved.
        """
        with timeit_context(f"Saving the state dict of model to {path}"):
            torch.save(model.state_dict(), path)

    def models(self):
        """
        Generator that yields models from the model pool.

        Special models (see `model_names`) are skipped.

        Yields:
            nn.Module: The next model in the model pool.
        """
        for model_name in self.model_names:
            yield self.load_model(model_name)

    def named_models(self):
        """
        Generator that yields model names and models from the model pool.

        Special models (see `model_names`) are skipped.

        Yields:
            tuple: A tuple containing the model name and the model.
        """
        for model_name in self.model_names:
            yield model_name, self.load_model(model_name)

    def get_train_dataset(self, model_name: str):
        """
        Get the training dataset for the model.

        Args:
            model_name (str): The name of the model for which to get the training dataset.

        Returns:
            Any: The training dataset for the model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_test_dataset(self, model_name: str):
        """
        Get the testing dataset for the model.

        Args:
            model_name (str): The name of the model for which to get the testing dataset.

        Returns:
            Any: The testing dataset for the model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def setup_taskpool(self, taskpool):
        """
        Setup the taskpool before evaluation.
        Such as setting the fabric, processor, tokenizer, etc.

        The base implementation does nothing.

        Args:
            taskpool (Any): The taskpool to setup.
        """
        pass

    def to_modellist(self) -> List[nn.Module]:
        """
        Convert the model pool to a list of models.

        Returns:
            list: A list of models, one per entry of `model_names`.
        """
        return [self.load_model(m) for m in self.model_names]

    def to_modeldict(self) -> Dict[str, nn.Module]:
        """
        Convert the model pool to a dictionary of models.

        Returns:
            dict: A dictionary mapping each entry of `model_names` to its model.
        """
        return {m: self.load_model(m) for m in self.model_names}
211
+
212
+
213
class ListModelPool(ModelPool):
    """
    A model pool backed by an in-memory list of models.

    Models are registered under generated names `model_0`, `model_1`, ...
    When `has_pretraned` is set, the first model is registered as
    `_pretrained_` and the numbering starts with the second model.
    """

    def __init__(
        self,
        models: List[nn.Module],
        has_pretraned: bool = False,
    ):
        """
        Build the pool from a list of models.

        Args:
            models (List[nn.Module]): The models to wrap.
            has_pretraned (bool): Whether the first model in the list is the pretrained model.
        """
        named_models = {}
        remaining = models
        if has_pretraned:
            # The leading model takes the special pretrained slot; only the
            # rest receive numbered names.
            named_models["_pretrained_"] = models[0]
            remaining = models[1:]
        for index, module in enumerate(remaining):
            named_models[f"model_{index}"] = module

        self.model_dict = named_models
        # One config entry per registered model, in registration order.
        entries = [{"name": key} for key in named_models]
        super().__init__(DictConfig({"models": entries}))

    def load_model(self, model_config: str | DictConfig, copy=True) -> nn.Module:
        """
        Fetch a model from the pool.

        Args:
            model_config (str | DictConfig): The model name or the configuration dictionary for the model to load.
            copy (bool): Whether to return a deep copy of the stored model, defaults to `True`.

        Returns:
            nn.Module: The requested model (a deep copy unless `copy` is false).
        """
        entry = (
            self.get_model_config(model_config)
            if isinstance(model_config, str)
            else model_config
        )
        stored = self.model_dict[entry["name"]]
        return deepcopy(stored) if copy else stored
265
+
266
+
267
class DictModelPool(ModelPool):
    """
    A model pool backed by an in-memory dictionary mapping names to models.
    """

    def __init__(self, model_dict: Dict[str, nn.Module]):
        """
        Build the pool from a dictionary of models.

        Args:
            model_dict (Dict[str, nn.Module]): The models to wrap, keyed by model name.
        """
        # The configuration only records the names; the models themselves
        # stay in `self.model_dict`.
        entries = [{"name": key} for key in model_dict]
        self.model_dict = model_dict
        super().__init__(DictConfig({"models": entries}))

    def load_model(self, model_config: str | DictConfig, copy=True) -> nn.Module:
        """
        Fetch a model from the pool.

        Args:
            model_config (str | DictConfig): The model name or the configuration dictionary for the model to load.
            copy (bool): Whether to return a deep copy of the stored model.

        Returns:
            nn.Module: The requested model (a deep copy unless `copy` is false).
        """
        entry = (
            self.get_model_config(model_config)
            if isinstance(model_config, str)
            else model_config
        )
        stored = self.model_dict[entry["name"]]
        return deepcopy(stored) if copy else stored
304
+
305
+
306
def to_modelpool(obj: List[nn.Module], **kwargs):
    """
    Wrap the given object in a model pool.

    Accepted inputs, checked in this order:
      - an existing `ModelPool` / `BaseModelPool`, returned unchanged;
      - a list or tuple of `nn.Module`, wrapped in a `ListModelPool`;
      - a dict whose values are all `nn.Module`, wrapped in a `DictModelPool`;
      - a single `nn.Module`, wrapped in a one-element `ListModelPool`.

    Args:
        obj (List[nn.Module]): The object to convert to a model pool.
        **kwargs: Forwarded to the constructor of the pool being created.

    Returns:
        ModelPool: The converted model pool.

    Raises:
        ValueError: If the object cannot be converted to a model pool.
    """

    def _all_modules(candidates):
        return all(isinstance(item, nn.Module) for item in candidates)

    # Already a pool: hand it back untouched.
    if isinstance(obj, (ModelPool, BaseModelPool)):
        return obj
    if isinstance(obj, (list, tuple)) and _all_modules(obj):
        return ListModelPool(models=obj, **kwargs)
    if isinstance(obj, Dict) and _all_modules(obj.values()):
        return DictModelPool(model_dict=obj, **kwargs)
    if isinstance(obj, nn.Module):
        return ListModelPool(models=[obj], **kwargs)
    raise ValueError(f"Invalid modelpool object: {obj}")