fusion-bench 0.2.9__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (727)
  1. fusion_bench/__init__.py +20 -0
  2. fusion_bench/__main__.py +4 -0
  3. fusion_bench/compat/__init__.py +0 -0
  4. fusion_bench/compat/method/__init__.py +109 -0
  5. fusion_bench/compat/method/base_algorithm.py +58 -0
  6. fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
  7. fusion_bench/compat/modelpool/__init__.py +116 -0
  8. fusion_bench/compat/modelpool/base_pool.py +328 -0
  9. fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
  10. fusion_bench/compat/taskpool/__init__.py +95 -0
  11. fusion_bench/compat/taskpool/base_pool.py +111 -0
  12. fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
  13. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
  14. fusion_bench/constants/__init__.py +2 -0
  15. fusion_bench/constants/paths.py +18 -0
  16. fusion_bench/dataset/__init__.py +29 -0
  17. fusion_bench/dataset/arc_agi/__init__.py +6 -0
  18. fusion_bench/dataset/arc_agi/arc.py +308 -0
  19. fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
  20. fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
  21. fusion_bench/dataset/arc_agi/messagers.py +1355 -0
  22. fusion_bench/dataset/arc_agi/np_cache.py +168 -0
  23. fusion_bench/dataset/arc_agi/preprocess.py +298 -0
  24. fusion_bench/dataset/arc_agi/representers.py +1019 -0
  25. fusion_bench/dataset/clip_dataset.py +71 -0
  26. fusion_bench/dataset/fer2013.py +12 -0
  27. fusion_bench/dataset/gpt2_glue.py +300 -0
  28. fusion_bench/dataset/gsm8k.py +60 -0
  29. fusion_bench/dataset/image_dataset.py +55 -0
  30. fusion_bench/dataset/imdb.py +11 -0
  31. fusion_bench/dataset/llama/__init__.py +1 -0
  32. fusion_bench/dataset/llama/alpaca.py +232 -0
  33. fusion_bench/dataset/llama/collate.py +120 -0
  34. fusion_bench/dataset/llama/metamathqa.py +50 -0
  35. fusion_bench/dataset/llama/openai.py +160 -0
  36. fusion_bench/dataset/llama/preference_700k.py +70 -0
  37. fusion_bench/dataset/llama/sharegpt.py +141 -0
  38. fusion_bench/dataset/llama/squad.py +125 -0
  39. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  40. fusion_bench/dataset/llama/ultrachat.py +58 -0
  41. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  42. fusion_bench/dataset/llama/wikitext.py +89 -0
  43. fusion_bench/dataset/nyuv2.py +119 -0
  44. fusion_bench/method/__init__.py +177 -0
  45. fusion_bench/method/ada_svd/__init__.py +2 -0
  46. fusion_bench/method/ada_svd/clip_vision.py +319 -0
  47. fusion_bench/method/adamerging/__init__.py +6 -0
  48. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
  49. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
  50. fusion_bench/method/adamerging/entropy_loss.py +25 -0
  51. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
  52. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
  53. fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
  54. fusion_bench/method/adamerging/llama_adamerging.py +335 -0
  55. fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
  56. fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
  57. fusion_bench/method/adamerging/utils.py +15 -0
  58. fusion_bench/method/analysis/__init__.py +2 -0
  59. fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
  60. fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
  61. fusion_bench/method/base_algorithm.py +44 -0
  62. fusion_bench/method/classification/__init__.py +3 -0
  63. fusion_bench/method/classification/clip_finetune.py +444 -0
  64. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  65. fusion_bench/method/concrete_subspace/__init__.py +6 -0
  66. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
  67. fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
  68. fusion_bench/method/dare/__init__.py +4 -0
  69. fusion_bench/method/dare/simple_average.py +31 -0
  70. fusion_bench/method/dare/task_arithmetic.py +82 -0
  71. fusion_bench/method/dare/ties_merging.py +100 -0
  72. fusion_bench/method/dare/utils.py +87 -0
  73. fusion_bench/method/dawe/__init__.py +2 -0
  74. fusion_bench/method/dawe/dawe_for_clip.py +274 -0
  75. fusion_bench/method/dawe/warppers/__init__.py +13 -0
  76. fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
  77. fusion_bench/method/depth_upscaling/__init__.py +3 -0
  78. fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
  79. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
  80. fusion_bench/method/dummy.py +35 -0
  81. fusion_bench/method/ensemble.py +98 -0
  82. fusion_bench/method/fisher_merging/__init__.py +4 -0
  83. fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
  84. fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
  85. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
  86. fusion_bench/method/linear/__init__.py +6 -0
  87. fusion_bench/method/linear/expo.py +118 -0
  88. fusion_bench/method/linear/linear_interpolation.py +60 -0
  89. fusion_bench/method/linear/llama_expo.py +229 -0
  90. fusion_bench/method/linear/simple_average_for_llama.py +54 -0
  91. fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
  92. fusion_bench/method/lm_finetune/__init__.py +3 -0
  93. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  94. fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
  95. fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
  96. fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
  97. fusion_bench/method/mixture_of_experts/__init__.py +7 -0
  98. fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
  99. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
  100. fusion_bench/method/model_recombination.py +121 -0
  101. fusion_bench/method/opcm/__init__.py +4 -0
  102. fusion_bench/method/opcm/opcm.py +277 -0
  103. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  104. fusion_bench/method/opcm/ties_merging.py +156 -0
  105. fusion_bench/method/opcm/utils.py +73 -0
  106. fusion_bench/method/opcm/weight_average.py +120 -0
  107. fusion_bench/method/pruning/__init__.py +5 -0
  108. fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
  109. fusion_bench/method/pruning/llama_random_prune.py +143 -0
  110. fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
  111. fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
  112. fusion_bench/method/pruning/prune_utils.py +165 -0
  113. fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
  114. fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
  115. fusion_bench/method/pruning/wanda_utils/data.py +135 -0
  116. fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
  117. fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
  118. fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
  119. fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
  120. fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
  121. fusion_bench/method/pwe_moe/__init__.py +5 -0
  122. fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
  123. fusion_bench/method/pwe_moe/module.py +316 -0
  124. fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
  125. fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
  126. fusion_bench/method/pwe_moe/utils.py +43 -0
  127. fusion_bench/method/rankone_moe/__init__.py +3 -0
  128. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  129. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  130. fusion_bench/method/regmean/__init__.py +4 -0
  131. fusion_bench/method/regmean/clip_regmean.py +131 -0
  132. fusion_bench/method/regmean/gpt2_regmean.py +147 -0
  133. fusion_bench/method/regmean/regmean.py +375 -0
  134. fusion_bench/method/simple_average.py +112 -0
  135. fusion_bench/method/slerp/__init__.py +2 -0
  136. fusion_bench/method/slerp/slerp.py +101 -0
  137. fusion_bench/method/slerp/slerp_utils.py +107 -0
  138. fusion_bench/method/smile_upscaling/__init__.py +3 -0
  139. fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
  140. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
  141. fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
  142. fusion_bench/method/sparse_we_moe/__init__.py +2 -0
  143. fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
  144. fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
  145. fusion_bench/method/sparselo/__init__.py +2 -0
  146. fusion_bench/method/sparselo/sparselo.py +955 -0
  147. fusion_bench/method/surgery/__init__.py +1 -0
  148. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  149. fusion_bench/method/tall_mask/__init__.py +0 -0
  150. fusion_bench/method/tall_mask/utils.py +234 -0
  151. fusion_bench/method/task_arithmetic/__init__.py +2 -0
  152. fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
  153. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  154. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  155. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  156. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  157. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
  158. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  159. fusion_bench/method/ties_merging/__init__.py +2 -0
  160. fusion_bench/method/ties_merging/ties_merging.py +117 -0
  161. fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
  162. fusion_bench/method/trust_region/__init__.py +2 -0
  163. fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
  164. fusion_bench/method/trust_region/utils.py +58 -0
  165. fusion_bench/method/we_moe/__init__.py +2 -0
  166. fusion_bench/method/we_moe/clip_we_moe.py +161 -0
  167. fusion_bench/method/we_moe/we_moe.py +247 -0
  168. fusion_bench/method/weighted_average/__init__.py +3 -0
  169. fusion_bench/method/weighted_average/llama.py +113 -0
  170. fusion_bench/method/weighted_average/weighted_average.py +102 -0
  171. fusion_bench/metrics/__init__.py +0 -0
  172. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  173. fusion_bench/metrics/nyuv2/__init__.py +11 -0
  174. fusion_bench/metrics/nyuv2/depth.py +45 -0
  175. fusion_bench/metrics/nyuv2/loss.py +31 -0
  176. fusion_bench/metrics/nyuv2/noise.py +16 -0
  177. fusion_bench/metrics/nyuv2/normal.py +48 -0
  178. fusion_bench/metrics/nyuv2/segmentation.py +43 -0
  179. fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
  180. fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
  181. fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
  182. fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
  183. fusion_bench/mixins/__init__.py +28 -0
  184. fusion_bench/mixins/clip_classification.py +252 -0
  185. fusion_bench/mixins/fabric_training.py +320 -0
  186. fusion_bench/mixins/lightning_fabric.py +174 -0
  187. fusion_bench/mixins/optim/__init__.py +0 -0
  188. fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
  189. fusion_bench/mixins/rich_live.py +21 -0
  190. fusion_bench/mixins/serialization.py +132 -0
  191. fusion_bench/mixins/simple_profiler.py +79 -0
  192. fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
  193. fusion_bench/modelpool/__init__.py +42 -0
  194. fusion_bench/modelpool/base_pool.py +268 -0
  195. fusion_bench/modelpool/causal_lm/__init__.py +2 -0
  196. fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
  197. fusion_bench/modelpool/clip_vision/__init__.py +1 -0
  198. fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
  199. fusion_bench/modelpool/huggingface_automodel.py +20 -0
  200. fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
  201. fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
  202. fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
  203. fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
  204. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  205. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  206. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  207. fusion_bench/models/__init__.py +3 -0
  208. fusion_bench/models/chat_templates/__init__.py +1 -0
  209. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  210. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  211. fusion_bench/models/hf_clip.py +199 -0
  212. fusion_bench/models/linearized/__init__.py +0 -0
  213. fusion_bench/models/linearized/linearized_model_utils.py +91 -0
  214. fusion_bench/models/linearized/vision_model.py +122 -0
  215. fusion_bench/models/llama/__init__.py +16 -0
  216. fusion_bench/models/llama/model_utils/__init__.py +0 -0
  217. fusion_bench/models/llama/model_utils/embedding.py +87 -0
  218. fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
  219. fusion_bench/models/llama/model_utils/misc.py +112 -0
  220. fusion_bench/models/llama/model_utils/mod.py +52 -0
  221. fusion_bench/models/llama/model_utils/visual.py +241 -0
  222. fusion_bench/models/llama/patcher.py +78 -0
  223. fusion_bench/models/llama/tokenizer_loader.py +153 -0
  224. fusion_bench/models/masks/__init__.py +2 -0
  225. fusion_bench/models/masks/mask_model.py +160 -0
  226. fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
  227. fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
  228. fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
  229. fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
  230. fusion_bench/models/modeling_losparse_llama/register.py +8 -0
  231. fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
  232. fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
  233. fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
  234. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
  235. fusion_bench/models/modeling_smile_mistral/register.py +8 -0
  236. fusion_bench/models/nyuv2/__init__.py +0 -0
  237. fusion_bench/models/nyuv2/aspp.py +82 -0
  238. fusion_bench/models/nyuv2/lightning_module.py +176 -0
  239. fusion_bench/models/nyuv2/resnet.py +405 -0
  240. fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
  241. fusion_bench/models/parameter_dict.py +75 -0
  242. fusion_bench/models/rankone_moe.py +410 -0
  243. fusion_bench/models/separate_io.py +105 -0
  244. fusion_bench/models/smile_moe/__init__.py +0 -0
  245. fusion_bench/models/smile_moe/linear.py +256 -0
  246. fusion_bench/models/sparse_we_moe.py +459 -0
  247. fusion_bench/models/surgery/__init__.py +1 -0
  248. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  249. fusion_bench/models/utils.py +80 -0
  250. fusion_bench/models/we_moe.py +247 -0
  251. fusion_bench/models/wrappers/__init__.py +0 -0
  252. fusion_bench/models/wrappers/ensemble.py +183 -0
  253. fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
  254. fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
  255. fusion_bench/optim/__init__.py +2 -0
  256. fusion_bench/optim/exception.py +47 -0
  257. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  258. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  259. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  260. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  261. fusion_bench/optim/mezo.py +118 -0
  262. fusion_bench/programs/__init__.py +20 -0
  263. fusion_bench/programs/base_program.py +9 -0
  264. fusion_bench/programs/fabric_fusion_program.py +299 -0
  265. fusion_bench/scripts/__init__.py +0 -0
  266. fusion_bench/scripts/cli.py +43 -0
  267. fusion_bench/scripts/clip/__init__.py +0 -0
  268. fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
  269. fusion_bench/scripts/imgui.py +218 -0
  270. fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
  271. fusion_bench/scripts/webui.py +405 -0
  272. fusion_bench/taskpool/__init__.py +39 -0
  273. fusion_bench/taskpool/base_pool.py +35 -0
  274. fusion_bench/taskpool/clip_vision/__init__.py +4 -0
  275. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  276. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
  277. fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
  278. fusion_bench/taskpool/dummy.py +58 -0
  279. fusion_bench/taskpool/gpt2_text_classification.py +149 -0
  280. fusion_bench/taskpool/llama/__init__.py +1 -0
  281. fusion_bench/taskpool/llama/reward_model.py +157 -0
  282. fusion_bench/taskpool/llama/test_generation.py +185 -0
  283. fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
  284. fusion_bench/tasks/__init__.py +2 -0
  285. fusion_bench/tasks/base_task.py +18 -0
  286. fusion_bench/tasks/classification.py +75 -0
  287. fusion_bench/tasks/clip_classification/__init__.py +183 -0
  288. fusion_bench/tasks/clip_classification/cifar10.py +33 -0
  289. fusion_bench/tasks/clip_classification/cifar100.py +146 -0
  290. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
  291. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  292. fusion_bench/tasks/clip_classification/dtd.py +60 -0
  293. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  294. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  295. fusion_bench/tasks/clip_classification/eurosat.py +18 -0
  296. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  297. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  298. fusion_bench/tasks/clip_classification/flower102.py +106 -0
  299. fusion_bench/tasks/clip_classification/food101.py +105 -0
  300. fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
  301. fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
  302. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  303. fusion_bench/tasks/clip_classification/mnist.py +5 -0
  304. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  305. fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
  306. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  307. fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
  308. fusion_bench/tasks/clip_classification/resisc45.py +68 -0
  309. fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
  310. fusion_bench/tasks/clip_classification/stl10.py +17 -0
  311. fusion_bench/tasks/clip_classification/sun397.py +404 -0
  312. fusion_bench/tasks/clip_classification/svhn.py +5 -0
  313. fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
  314. fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
  315. fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
  316. fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
  317. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
  318. fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
  319. fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
  320. fusion_bench/utils/__init__.py +14 -0
  321. fusion_bench/utils/auto.py +31 -0
  322. fusion_bench/utils/cache_utils.py +58 -0
  323. fusion_bench/utils/data.py +165 -0
  324. fusion_bench/utils/devices.py +231 -0
  325. fusion_bench/utils/dict.py +43 -0
  326. fusion_bench/utils/dtype.py +146 -0
  327. fusion_bench/utils/expr.py +90 -0
  328. fusion_bench/utils/fabric.py +17 -0
  329. fusion_bench/utils/functools.py +37 -0
  330. fusion_bench/utils/hydra_utils.py +28 -0
  331. fusion_bench/utils/instantiate.py +450 -0
  332. fusion_bench/utils/json.py +93 -0
  333. fusion_bench/utils/lazy_imports.py +74 -0
  334. fusion_bench/utils/misc.py +18 -0
  335. fusion_bench/utils/packages.py +84 -0
  336. fusion_bench/utils/parameters.py +323 -0
  337. fusion_bench/utils/path.py +22 -0
  338. fusion_bench/utils/plot/__init__.py +0 -0
  339. fusion_bench/utils/plot/color_data.py +1726 -0
  340. fusion_bench/utils/plot/token.py +52 -0
  341. fusion_bench/utils/plot/token_notebook.py +127 -0
  342. fusion_bench/utils/pylogger.py +55 -0
  343. fusion_bench/utils/rich_utils.py +201 -0
  344. fusion_bench/utils/set.py +8 -0
  345. fusion_bench/utils/state_dict_arithmetic.py +297 -0
  346. fusion_bench/utils/strenum/__init__.py +326 -0
  347. fusion_bench/utils/strenum/_name_mangler.py +127 -0
  348. fusion_bench/utils/strenum/_version.py +556 -0
  349. fusion_bench/utils/tensorboard.py +51 -0
  350. fusion_bench/utils/timer.py +49 -0
  351. fusion_bench/utils/type.py +34 -0
  352. fusion_bench-0.2.9.dist-info/LICENSE +21 -0
  353. fusion_bench-0.2.9.dist-info/METADATA +258 -0
  354. fusion_bench-0.2.9.dist-info/RECORD +727 -0
  355. fusion_bench-0.2.9.dist-info/WHEEL +5 -0
  356. fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
  357. fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
  358. fusion_bench_config/README.md +12 -0
  359. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
  360. fusion_bench_config/dataset/image_classification/README.md +6 -0
  361. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  362. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  363. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
  364. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
  365. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  366. fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
  367. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  368. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  369. fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
  370. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  371. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  372. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  373. fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
  374. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  375. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  376. fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
  377. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  378. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  379. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  380. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  381. fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
  382. fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
  383. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  384. fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
  385. fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
  386. fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
  387. fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
  388. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  389. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  390. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
  391. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
  392. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  393. fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
  394. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  395. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  396. fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
  397. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  398. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  399. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  400. fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
  401. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  402. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  403. fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
  404. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  405. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  406. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  407. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  408. fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
  409. fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
  410. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  411. fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
  412. fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
  413. fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
  414. fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
  415. fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
  416. fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
  417. fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
  418. fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
  419. fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
  420. fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
  421. fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
  422. fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
  423. fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
  424. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  425. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  426. fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
  427. fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
  428. fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
  429. fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
  430. fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
  431. fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
  432. fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
  433. fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
  434. fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
  435. fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
  436. fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
  437. fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
  438. fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
  439. fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
  440. fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
  441. fusion_bench_config/fabric/auto.yaml +16 -0
  442. fusion_bench_config/fabric/llama_ddp.yaml +18 -0
  443. fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
  444. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  445. fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
  446. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
  447. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  448. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  449. fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
  450. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  451. fusion_bench_config/fabric_model_fusion.yaml +20 -0
  452. fusion_bench_config/hydra/default.yaml +8 -0
  453. fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
  454. fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
  455. fusion_bench_config/llama_full_finetune.yaml +19 -0
  456. fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
  457. fusion_bench_config/llama_model_fusion.yaml +17 -0
  458. fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
  459. fusion_bench_config/method/adamerging/clip.yaml +23 -0
  460. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
  461. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
  462. fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
  463. fusion_bench_config/method/adamerging.yaml +23 -0
  464. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
  465. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
  466. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  467. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  468. fusion_bench_config/method/clip_finetune.yaml +26 -0
  469. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
  470. fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
  471. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
  472. fusion_bench_config/method/dare/simple_average.yaml +5 -0
  473. fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
  474. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  475. fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
  476. fusion_bench_config/method/depth_upscaling.yaml +5 -0
  477. fusion_bench_config/method/dummy.yaml +1 -0
  478. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
  479. fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
  480. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
  481. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
  482. fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
  483. fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
  484. fusion_bench_config/method/linear/expo.yaml +8 -0
  485. fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
  486. fusion_bench_config/method/linear/llama_expo.yaml +19 -0
  487. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
  488. fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
  489. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
  490. fusion_bench_config/method/linear/weighted_average.yaml +6 -0
  491. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
  492. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  493. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
  494. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
  495. fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
  496. fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
  497. fusion_bench_config/method/model_recombination.yaml +4 -0
  498. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  499. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  500. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  501. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  502. fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
  503. fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
  504. fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
  505. fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
  506. fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
  507. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  508. fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
  509. fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
  510. fusion_bench_config/method/regmean/regmean.yaml +4 -0
  511. fusion_bench_config/method/simple_average.yaml +1 -0
  512. fusion_bench_config/method/slerp/slerp.yaml +6 -0
  513. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
  514. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
  515. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
  516. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
  517. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
  518. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
  519. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  520. fusion_bench_config/method/task_arithmetic.yaml +2 -0
  521. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  522. fusion_bench_config/method/ties_merging.yaml +8 -0
  523. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
  524. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
  525. fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
  526. fusion_bench_config/model/clip-vit/README.md +38 -0
  527. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
  528. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  529. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  530. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  531. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  532. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
  533. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
  534. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  535. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
  536. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  537. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  538. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  539. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
  540. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  541. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
  542. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  543. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  544. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  545. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  546. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
  547. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
  548. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  549. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
  550. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
  551. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
  552. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  553. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  554. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  555. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  556. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
  557. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
  558. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  559. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
  560. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  561. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  562. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  563. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
  564. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  565. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
  566. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  567. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  568. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  569. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  570. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
  571. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
  572. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  573. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
  574. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
  575. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
  576. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  577. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  578. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  579. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  580. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
  581. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
  582. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  583. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
  584. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  585. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  586. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  587. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
  588. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  589. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
  590. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  591. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  592. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  593. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  594. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
  595. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
  596. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  597. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
  598. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
  599. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  600. fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
  601. fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
  602. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
  603. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
  604. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
  605. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
  606. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
  607. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
  608. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
  609. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
  610. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
  611. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
  612. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
  613. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
  614. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
  615. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
  616. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
  617. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
  618. fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
  619. fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
  620. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
  621. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
  622. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
  623. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
  624. fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
  625. fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
  626. fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
  627. fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
  628. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
  629. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
  630. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
  631. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  632. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  633. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  634. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  635. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  636. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
  637. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
  638. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
  639. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
  640. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
  641. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  642. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  643. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  644. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  645. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
  646. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
  647. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
  648. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
  649. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
  650. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
  651. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
  652. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  653. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
  654. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  655. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
  656. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
  657. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  658. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  659. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  660. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  661. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
  662. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  663. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  664. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
  665. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  666. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  667. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
  668. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
  669. fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
  670. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
  671. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
  672. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
  673. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
  674. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
  675. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  676. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  677. fusion_bench_config/modelpool/automodelpool.yaml +12 -0
  678. fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
  679. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
  680. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
  681. fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
  682. fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
  683. fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
  684. fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
  685. fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
  686. fusion_bench_config/nyuv2_config.yaml +17 -0
  687. fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
  688. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
  689. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  690. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
  691. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
  692. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
  693. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
  694. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
  695. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  696. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  697. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  698. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  699. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  700. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  701. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  702. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  703. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  704. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  705. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  706. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  707. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  708. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  709. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  710. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  711. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  712. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  713. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  714. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  715. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  716. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  717. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  718. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  719. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
  720. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
  721. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  722. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
  723. fusion_bench_config/taskpool/dummy.yaml +2 -0
  724. fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
  725. fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
  726. fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
  727. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
@@ -0,0 +1,484 @@
1
+ """
2
+ This implementation is largely based on the implementation from https://github.com/yule-BUAA/MergeLM/
3
+ """
4
+
5
+ import logging
6
+ import re
7
+ from collections import defaultdict
8
+ from typing import Dict, List
9
+
10
+ import torch
11
+ from torch import Tensor, nn
12
+ from tqdm.autonotebook import tqdm
13
+
14
+ from fusion_bench.method import BaseAlgorithm
15
+ from fusion_bench.modelpool import BaseModelPool
16
+
17
+ log = logging.getLogger(__name__)
18
+
19
+
20
def get_param_names_to_merge(
    input_param_names: List[str], exclude_param_names_regex: list
) -> List[str]:
    """
    Get the names of parameters that need to be merged.

    A name is dropped when at least one exclusion pattern matches at the
    beginning of the name (``re.match`` semantics). The remaining names are
    returned in their original order.

    Args:
        input_param_names (List[str]): List of input parameter names.
        exclude_param_names_regex (list): List of regular expressions (strings or
            compiled patterns) for parameter names to be excluded.

    Returns:
        List[str]: List of parameter names to be merged.
    """
    # compile every pattern once instead of resolving it again for each name
    exclude_patterns = [re.compile(pattern) for pattern in exclude_param_names_regex]
    return [
        param_name
        for param_name in input_param_names
        # a generator lets any() stop at the first pattern that matches
        if not any(pattern.match(param_name) for pattern in exclude_patterns)
    ]
44
+
45
+
46
def get_param_squared_gradients(
    model: nn.Module, param_names_to_merge: List[str]
) -> Dict[str, Tensor]:
    """
    Get the squared gradients of parameters.

    The gradients currently stored on the model (``.grad``) are read, detached
    and squared element-wise. Only entries of ``model.state_dict(keep_vars=True)``
    whose name is in ``param_names_to_merge`` are returned.

    Args:
        model (nn.Module): The model, after a backward pass has populated the gradients.
        param_names_to_merge (List[str]): List of parameter names to be merged.

    Returns:
        Dict[str, Tensor]: Dictionary of parameter names and their squared gradients.
            A selected entry that has no gradient (``.grad is None``) is reported
            as an all-zero tensor of the same shape.
    """
    # set membership keeps the name filter O(1) per state-dict entry
    names_to_merge = set(param_names_to_merge)
    param_squared_gradients = {}
    for param_name, param_value in model.state_dict(keep_vars=True).items():
        if param_name not in names_to_merge:
            continue
        grad = param_value.grad
        if grad is None:
            # no gradient reached this tensor (e.g. frozen or unused in the forward pass):
            # report a zero contribution instead of failing on ``None.detach()``
            param_squared_gradients[param_name] = torch.zeros_like(param_value.detach())
        else:
            param_squared_gradients[param_name] = grad.detach() ** 2
    return param_squared_gradients
65
+
66
+
67
def get_models_fisher_norm(
    models_to_merge_param_dict: dict, models_to_merge_fisher_weights_list: list
) -> Tensor:
    """
    Get normalization of Fisher weights of all the models that need to be merged.

    For every parameter name, the L2 norm of each model's Fisher weights is computed;
    the per-parameter norms of a model are then combined with another L2 norm, which
    gives one value per model.

    Args:
        models_to_merge_param_dict (dict): Dictionary of list, where key is the parameter name,
            value is a list of the corresponding parameters of all the models that need to be merged.
            Only the keys are used here.
        models_to_merge_fisher_weights_list (list): List of dictionaries with length len(models_to_merge),
            each dictionary records the Fisher weights (matrix or vector) of parameters for each model that needs to be merged.

    Returns:
        Tensor: L2 norm over all the parameters of models that need to be merged, shape (num_models_to_merge, ).
    """
    # one Tensor of shape (num_models_to_merge, ) per parameter name
    per_param_norms = []
    for param_name in models_to_merge_param_dict:
        # Tensor, shape (num_models_to_merge, *fisher_weight_shape)
        models_fisher = torch.stack(
            [
                model_to_merge_fisher_weights[param_name]
                for model_to_merge_fisher_weights in models_to_merge_fisher_weights_list
            ],
            dim=0,
        )
        # flatten everything except the model dimension, so that 0-dim (scalar)
        # parameters are reduced the same way as tensors of any other rank
        flat_fisher = models_fisher.reshape(models_fisher.shape[0], -1)
        # Tensor, shape (num_models_to_merge, ), L2 norm of this parameter for each model
        per_param_norms.append(torch.linalg.vector_norm(flat_fisher, dim=1))

    # Tensor, shape (num_models_to_merge, num_parameters)
    stacked_norms = torch.stack(per_param_norms, dim=1)
    # Tensor, shape (num_models_to_merge, ), compute L2 norm over all the parameters
    return torch.linalg.vector_norm(stacked_norms, dim=1)
107
+
108
+
109
def merging_with_fisher_weights(
    models_to_merge_param_dict: Dict[str, List[Tensor]],
    models_to_merge_fisher_weights_list: list,
    fisher_scaling_coefficients: torch.Tensor,
    normalize_fisher_weight: bool = True,
    minimal_fisher_weight: float = 1e-6,
) -> Dict[str, Tensor]:
    """
    Merge parameters of different models with computed Fisher weights.

    Each merged parameter is the element-wise weighted average
    ``sum_i(c_i * F_i * theta_i) / sum_i(c_i * F_i)``, where ``theta_i`` is the
    parameter of model ``i``, ``F_i`` its Fisher weights (plus ``minimal_fisher_weight``)
    and ``c_i`` its scaling coefficient. When ``normalize_fisher_weight`` is set,
    ``c_i`` is additionally multiplied by the normalized inverse of the model's
    overall Fisher norm.

    Args:
        models_to_merge_param_dict (Dict[str, List[Tensor]]): Dictionary of list, where key is the parameter name,
            value is a list of the corresponding parameters of all the models that need to be merged.
        models_to_merge_fisher_weights_list (list): List of dictionaries with length len(models_to_merge),
            each dictionary records the Fisher weights (matrix or vector) of parameters for each model that needs to be merged.
        fisher_scaling_coefficients (torch.Tensor): Scaling coefficients to merge Fisher weights.
        normalize_fisher_weight (bool): Whether to normalize Fisher weights (L2 norm) or not.
        minimal_fisher_weight (float): The minimal value in Fisher weights, used for tackling the potential numerical issues.

    Returns:
        Dict[str, Tensor]: Dictionary of merged parameters.
    """
    # dict, dictionary of model parameters
    merged_params = {}

    # per-model normalization factors do not depend on the parameter being merged,
    # so they are computed once here rather than inside the loop
    normalized_models_fisher_norm = None
    if normalize_fisher_weight:
        # Tensor, shape (num_models_to_merge, ), L2 norm over all the parameters of models that need to be merged
        models_fisher_norm = get_models_fisher_norm(
            models_to_merge_param_dict=models_to_merge_param_dict,
            models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
        )
        # Tensor, shape (num_models_to_merge, ), inverse norms rescaled to sum to one
        inverse_models_fisher_norm = 1.0 / (models_fisher_norm + minimal_fisher_weight)
        normalized_models_fisher_norm = (
            inverse_models_fisher_norm / inverse_models_fisher_norm.sum()
        )

    for param_name, param_value_list in models_to_merge_param_dict.items():
        # shape (num_models_to_merge, *parameter_shape)
        param_values = torch.stack(param_value_list, dim=0)
        # Tensor, shape (num_models_to_merge, *fisher_weight_shape), use minimal_fisher_weight to solve the potential numerical issues
        models_to_merge_fisher_weights = (
            torch.stack(
                [
                    model_to_merge_fisher_weights[param_name]
                    for model_to_merge_fisher_weights in models_to_merge_fisher_weights_list
                ],
                dim=0,
            )
            + minimal_fisher_weight
        )

        # target shape (num_models_to_merge, 1, 1, ...) so that per-model scalars
        # broadcast against the stacked parameters
        broadcast_shape = (-1,) + (1,) * (param_values.dim() - 1)
        reshaped_scaling_coefficients = fisher_scaling_coefficients.reshape(
            broadcast_shape
        ).to(param_values.device)

        if normalized_models_fisher_norm is not None:
            reshaped_scaling_coefficients = (
                reshaped_scaling_coefficients
                * normalized_models_fisher_norm.reshape(broadcast_shape)
            )

        # shape (num_models_to_merge, *parameter_shape), shared by numerator and denominator
        weighted_fisher = reshaped_scaling_coefficients * models_to_merge_fisher_weights

        # shape (*parameter_shape)
        numerator = (weighted_fisher * param_values).sum(dim=0)
        # shape (*parameter_shape)
        denominator = weighted_fisher.sum(dim=0)

        merged_params[param_name] = numerator / denominator
    return merged_params
189
+
190
+
191
def fisher_merging(
    models_to_merge: List[nn.Module],
    trainers: list,
    exclude_param_names_regex: list,
    nums_fisher_examples: List[int],
    fisher_scaling_coefficients: list = None,
    normalize_fisher_weight: bool = True,
    minimal_fisher_weight: float = 1e-6,
) -> Dict[str, Tensor]:
    """
    Fisher merging method.

    For every model, batches from its trainer's train dataloader are run through the
    model, squared gradients are accumulated as Fisher weights and divided by the number
    of counted examples. The parameters of all models are then merged with
    ``merging_with_fisher_weights``.

    Args:
        models_to_merge (List[nn.Module]): List of individual models that need to be merged.
        trainers (list): List of trainers of individual models. Each trainer is used through
            ``get_train_dataloader()``, ``_prepare_inputs()`` and ``_train_batch_size``.
        exclude_param_names_regex (list): List of regular expressions for parameter names to be excluded.
        nums_fisher_examples (List[int]): List of numbers of examples to compute Fisher weights.
        fisher_scaling_coefficients (list, optional): Scaling coefficients to merge Fisher weights. Defaults to None,
            in which case all models contribute equally.
        normalize_fisher_weight (bool): Whether to normalize Fisher weights (L2 norm) or not. Defaults to True.
        minimal_fisher_weight (float): The minimal value in Fisher weights, used for tackling the potential numerical issues. Defaults to 1e-6.

    Returns:
        Dict[str, Tensor]: Dictionary of merged parameters.
    """
    # dictionary of list, where key is the parameter name,
    # value is a list of the corresponding parameters of all the models that need to be merged
    models_to_merge_param_dict = defaultdict(list)

    # list of dictionaries with length len(models_to_merge),
    # each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
    models_to_merge_fisher_weights_list = []

    assert (
        len(models_to_merge) == len(trainers) == len(nums_fisher_examples)
    ), "sizes of lists are not identical!"

    for model_idx, (model_to_merge, trainer, num_fisher_examples) in enumerate(
        zip(models_to_merge, trainers, nums_fisher_examples)
    ):
        param_dict = dict(model_to_merge.named_parameters())
        # exclude parameter whose name matches element in exclude_param_names_regex
        param_names_to_merge = get_param_names_to_merge(
            input_param_names=list(param_dict.keys()),
            exclude_param_names_regex=exclude_param_names_regex,
        )

        for param_name in param_names_to_merge:
            models_to_merge_param_dict[param_name].append(param_dict[param_name])

        # running sum of the per-batch fisher weights; accumulating on the fly avoids
        # keeping one full set of squared gradients alive for every processed batch
        model_to_merge_fisher_weights = {}

        num_computed_examples = 0
        train_dataloader = trainer.get_train_dataloader()
        if num_fisher_examples % trainer._train_batch_size != 0:
            log.warning(
                f"warning: the number of examples for computing fisher cannot be fully divided by the batch size for model {model_idx}, "
                "which may lead to a slightly different number of the actually used examples."
            )
        for _, inputs in tqdm(
            enumerate(train_dataloader),
            desc=f"computing fisher weights for model {model_idx}",
        ):
            if num_computed_examples >= num_fisher_examples:
                break
            inputs = trainer._prepare_inputs(inputs)
            outputs = model_to_merge(**inputs)
            # Tensor, expected shape (batch_size, num_label_classes)
            logits = outputs.logits
            if logits.shape[-1] == 1:
                # regression task: use the label information (the model's own loss) to obtain gradients
                fisher_objective = outputs.loss
            else:
                # classification task
                # use detach() to detach from the computation graph
                # Tensor, shape (batch_size, num_label_classes)
                labels_probabilities = torch.softmax(logits, dim=-1).detach()
                labels_log_probabilities = torch.log_softmax(logits, dim=-1)
                # sqrt labels_probabilities, since torch.sqrt(labels_probabilities) would be squared in the following squared gradients
                labels_expectations = (
                    torch.sqrt(labels_probabilities) * labels_log_probabilities
                )
                # sum over label classes and batch dimension
                fisher_objective = labels_expectations.sum(dim=-1).sum(dim=0)

            model_to_merge.zero_grad()
            fisher_objective.backward()
            # dict, fisher weights of a batch
            batch_fisher_weights = get_param_squared_gradients(
                model=model_to_merge, param_names_to_merge=param_names_to_merge
            )

            for key, batch_value in batch_fisher_weights.items():
                if key in model_to_merge_fisher_weights:
                    model_to_merge_fisher_weights[key] += batch_value
                else:
                    model_to_merge_fisher_weights[key] = batch_value
            # every batch is counted as a full batch, even if the last one is smaller
            num_computed_examples += trainer._train_batch_size

        # average over the number of counted examples
        for key in model_to_merge_fisher_weights:
            model_to_merge_fisher_weights[key] /= num_computed_examples
        models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)

    # if fisher_scaling_coefficients is None, then set the fisher weights of different models to contribute equally
    if fisher_scaling_coefficients is None:
        fisher_scaling_coefficients = torch.ones(len(models_to_merge)) / len(
            models_to_merge
        )
    else:
        assert isinstance(
            fisher_scaling_coefficients, list
        ), "wrong type of fisher_scaling_coefficients, should be list!"
        assert len(fisher_scaling_coefficients) == len(
            models_to_merge
        ), "mismatched length of fisher_scaling_coefficients!"
        fisher_scaling_coefficients = torch.tensor(
            fisher_scaling_coefficients, dtype=torch.float32
        )

    # merging with fisher weights
    merged_params = merging_with_fisher_weights(
        models_to_merge_param_dict=models_to_merge_param_dict,
        models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
        fisher_scaling_coefficients=fisher_scaling_coefficients,
        normalize_fisher_weight=normalize_fisher_weight,
        minimal_fisher_weight=minimal_fisher_weight,
    )

    return merged_params
333
+
334
+
335
def filter_state_dict(
    state_dict: Dict[str, Tensor],
    param_names: List[str],
) -> Dict[str, Tensor]:
    """
    Select a subset of entries from a state dict.

    The returned dictionary is a new dict whose keys follow the order of
    `param_names`; the values are the same objects stored in `state_dict`
    (nothing is copied).

    Args:
        state_dict (Dict[str, Tensor]): State dict of a model.
        param_names (List[str]): Names of the entries to keep.

    Returns:
        Dict[str, Tensor]: A dict holding only the requested entries.

    Raises:
        KeyError: If a requested name is not present in `state_dict`.
    """
    return {name: state_dict[name] for name in param_names}
353
+
354
+
355
class FisherMergingAlgorithm(BaseAlgorithm):
    """
    Base implementation of the Fisher Merging Algorithm.

    This class extends `BaseAlgorithm`. For every model in the model pool it gathers the
    parameters selected for merging together with their Fisher weights, combines them with
    `merging_with_fisher_weights`, and loads the merged parameters into the pretrained model.

    The Fisher weights themselves are produced by `get_fisher_weights`, which raises
    `NotImplementedError` here and has to be provided by a subclass.

    Configuration:
        exclude_param_names_regex (list): forwarded to `get_param_names_to_merge`.
        normalize_fisher_weight (bool): forwarded to `merging_with_fisher_weights`.
        minimal_fisher_weight (float): forwarded to `merging_with_fisher_weights`.
        num_fisher_examples (int): stored on the instance; not read by this base class.

    Methods:
        run(modelpool: BaseModelPool) -> nn.Module:
            Executes the Fisher merging process on the model pool and returns the merged model.
        get_fisher_weights(...) -> Dict[str, Tensor]:
            Hook computing the Fisher weights of one model (must be overridden).
        on_fisher_merging_start():
            Hook called once at the beginning of `run` (no-op by default).
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "exclude_param_names_regex": "exclude_param_names_regex",
        "normalize_fisher_weight": "normalize_fisher_weight",
        "minimal_fisher_weight": "minimal_fisher_weight",
        "num_fisher_examples": "num_fisher_examples",
    }

    def __init__(
        self,
        *,
        exclude_param_names_regex: list,
        normalize_fisher_weight: bool,
        minimal_fisher_weight: float,
        num_fisher_examples: int,
    ):
        """
        Store the Fisher merging options on the instance.

        Args:
            exclude_param_names_regex (list): Patterns used to exclude parameters from merging.
            normalize_fisher_weight (bool): Whether to normalize the Fisher weights.
            minimal_fisher_weight (float): Minimal value for the Fisher weights.
            num_fisher_examples (int): Number of examples for computing the Fisher weights.
        """
        super().__init__()
        self.exclude_param_names_regex = exclude_param_names_regex
        self.normalize_fisher_weight = normalize_fisher_weight
        self.minimal_fisher_weight = minimal_fisher_weight
        self.num_fisher_examples = num_fisher_examples

    def run(self, modelpool: BaseModelPool) -> nn.Module:
        """
        Run the Fisher Merging Algorithm.

        For each model yielded by `modelpool.named_models()`, this method collects the
        parameters to merge and the Fisher weights returned by `get_fisher_weights`. All
        models contribute with equal scaling coefficients. The merged parameters are loaded
        (with `strict=False`) into the model loaded under the name `"_pretrained_"`.

        Args:
            modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned
                models. A dict, list or tuple is wrapped into a `BaseModelPool` first.

        Returns:
            nn.Module: The pretrained model with the merged parameters loaded.

        Raises:
            AssertionError: If the model pool is empty or has no pretrained model.
        """
        log.info("Running Fisher Merging Algorithm")
        if isinstance(modelpool, (dict, list, tuple)):
            modelpool = BaseModelPool(modelpool)

        assert len(modelpool) > 0, "model pool is empty"
        assert (
            modelpool.has_pretrained
        ), "no pretrained model (base model) in the model pool"

        self.modelpool = modelpool
        self.on_fisher_merging_start()

        # dictionary of list, where key is the parameter name,
        # value is a list of the corresponding parameters of all the models that need to be merged
        models_to_merge_param_dict = defaultdict(list)

        # list of dictionaries, one per merged model,
        # each dictionary records the fisher weights (matrix or vector) of parameters for that model
        models_to_merge_fisher_weights_list = []

        # the parameter names are derived once, from the first model, and reused for all models
        param_names_to_merge = None

        for name, model in modelpool.named_models():
            param_dict = model.state_dict()
            if param_names_to_merge is None:
                param_names_to_merge = get_param_names_to_merge(
                    input_param_names=list(param_dict.keys()),
                    exclude_param_names_regex=self.config.get(
                        "exclude_param_names_regex", []
                    ),
                )

            for param_name in param_names_to_merge:
                models_to_merge_param_dict[param_name].append(param_dict[param_name])

            model_to_merge_fisher_weights = self.get_fisher_weights(
                model_name=name,
                model=model,
                train_dataset=modelpool.load_train_dataset(name),
                param_names_to_merge=param_names_to_merge,
            )

            models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)

        # size the (uniform) scaling coefficients from the models that were actually processed,
        # so they always line up with `models_to_merge_fisher_weights_list`
        num_models = len(models_to_merge_fisher_weights_list)
        merged_params = merging_with_fisher_weights(
            models_to_merge_param_dict=models_to_merge_param_dict,
            models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
            fisher_scaling_coefficients=torch.ones(num_models) / num_models,
            normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
            minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
        )

        merged_model = modelpool.load_model("_pretrained_")
        # strict=False: parameters excluded from merging keep their pretrained values
        merged_model.load_state_dict(merged_params, strict=False)
        return merged_model

    def get_fisher_weights(
        self,
        model_name: str,
        model: nn.Module,
        train_dataset,
        param_names_to_merge: List[str],
    ) -> Dict[str, Tensor]:
        """
        Compute the Fisher weights for the given model and training dataset.

        This base implementation always raises; subclasses must override it.

        Args:
            model_name (str): The name of the model.
            model (nn.Module): The model module.
            train_dataset: The training dataset.
            param_names_to_merge (List[str]): List of parameter names to merge.

        Returns:
            Dict[str, Tensor]: The computed Fisher weights for each parameter.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def on_fisher_merging_start(self):
        """
        Hook called by `run` after `self.modelpool` is set and before any model is processed.

        Subclasses can override it to initialize whatever they need for computing Fisher
        weights. The default implementation does nothing.
        """
        pass
@@ -0,0 +1,193 @@
1
+ import logging
2
+ import os
3
+ from copy import deepcopy
4
+ from functools import cache
5
+ from typing import Dict, List, cast
6
+
7
+ import lightning as L
8
+ import torch
9
+ from omegaconf import DictConfig
10
+ from torch import Tensor, nn
11
+ from torch.nn.modules import Module
12
+ from torch.utils.data import DataLoader
13
+ from tqdm.autonotebook import tqdm
14
+ from transformers import GPT2ForSequenceClassification, GPT2Model
15
+ from transformers.data import default_data_collator
16
+ from transformers.models.gpt2.modeling_gpt2 import Conv1D
17
+
18
+ from fusion_bench.mixins import LightningFabricMixin
19
+ from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
20
+ from fusion_bench.utils import timeit_context
21
+
22
+ from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients
23
+
24
+
25
class FisherMergingAlgorithmForGPT2(
    FisherMergingAlgorithm,
    LightningFabricMixin,
):
    """
    Implements the Fisher Merging Algorithm for GPT-2 models on text classification tasks.

    This class extends the FisherMergingAlgorithm and provides `get_fisher_weights` for
    GPT-2 backbones: each backbone is plugged into the classifier loaded for its task, and
    the Fisher weights are computed from the resulting logits on the task's training data.

    Attributes:
        classifiers (dict): Per-instance mapping from model name to its classifier.
        modelpool (GPT2ForSequenceClassificationPool): The model pool containing the GPT-2 models.
        cache_dir (str): Directory to cache data (stored; not read by the code in this class).
        batch_size (int): Batch size for data loading.
        num_workers (int): Number of workers for data loading.
    """

    # declared without a class-level value on purpose: a class-level dict would be shared
    # (and mutated) by every instance; the dict is created per instance instead
    classifiers: Dict[str, GPT2ForSequenceClassification]
    modelpool: GPT2ForSequenceClassificationPool = None
    _config_mapping = FisherMergingAlgorithm._config_mapping | {
        "cache_dir": "cache_dir",
        "batch_size": "batch_size",
        "num_workers": "num_workers",
    }

    def __init__(
        self,
        cache_dir: str,
        batch_size: int,
        num_workers: int,
        **kwargs,
    ):
        """
        Initialize the FisherMergingAlgorithmForGPT2 with the given configuration.

        Args:
            cache_dir (str): Directory to cache data.
            batch_size (int): Batch size for data loading.
            num_workers (int): Number of workers for data loading.
            **kwargs: Additional keyword arguments, forwarded to `FisherMergingAlgorithm`.
        """
        self.cache_dir = cache_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.classifiers = {}
        super().__init__(**kwargs)

    def on_fisher_merging_start(self):
        """
        Setup the classifiers for each model in the model pool before starting the Fisher merging process.

        Every classifier is frozen, stripped of its own transformer backbone and moved to the
        fabric device. The mapping is rebuilt from scratch on each call and stored on the
        instance, so nothing leaks between instances or between runs.
        """
        classifiers = {}
        for model_name in self.modelpool.model_names:
            classifier = cast(
                GPT2ForSequenceClassification,
                self.modelpool.load_classifier(model_name),
            ).requires_grad_(False)
            # the backbone is plugged in per call by `compute_logits`
            classifier.transformer = None
            classifiers[model_name] = classifier.to(self.fabric.device)
        self.classifiers = classifiers

    def compute_logits(self, module: GPT2Model, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        The given module is installed as the `transformer` of the task's classifier, which
        is then called on the batch.

        Args:
            module (GPT2Model): The GPT-2 model module.
            batch (dict): The input batch; `"input_ids"` and `"attention_mask"` are read.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits (asserted to be 2-dimensional).
        """
        self.classifiers[task].transformer = module
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        outputs = self.classifiers[task](input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        assert logits.dim() == 2
        return logits

    def get_fisher_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset,
        param_names_to_merge: List[str],
    ) -> Dict[str, Tensor]:
        """
        Compute the Fisher weights for the given model and training dataset.

        Batches are drawn (shuffled) from `train_dataset` until at least
        `num_fisher_examples` examples have been processed or the data is exhausted. The
        per-batch squared gradients are summed on the CPU and finally divided by the
        number of processed examples.

        Args:
            model_name (str): The name of the model.
            model (Module): The model module.
            train_dataset: The training dataset.
            param_names_to_merge (List[str]): List of parameter names to merge.

        Returns:
            Dict[str, Tensor]: The computed Fisher weights for each parameter (CPU tensors).
        """
        config = self.config
        batch_size = config.batch_size

        # setup dataloader
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=default_data_collator,
            num_workers=config.num_workers,
            pin_memory=True,
        )
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)
        num_fisher_examples = config.num_fisher_examples
        if num_fisher_examples % batch_size != 0:
            print(
                f"warning: the number of examples for computing fisher cannot be fully divided by the batch size for model {model_name}, "
                "which may lead to a slightly different number of the actually used examples."
            )

        num_computed_examples = 0
        # running sum of the per-batch fisher weights; accumulating on the fly avoids
        # holding one full copy of the weights per batch in memory
        model_to_merge_fisher_weights: Dict[str, Tensor] = {}
        for batch in tqdm(
            train_dataloader,
            desc=f"computing fisher weights for model {model_name}",
            # ceiling division: a trailing partial chunk still needs one more batch
            total=(num_fisher_examples + batch_size - 1) // batch_size,
        ):
            if num_computed_examples >= num_fisher_examples:
                break
            # Tensor, shape (batch_size, num_label_classes)
            logits = self.compute_logits(model, batch, model_name)

            # compute fisher weights for classification task
            # use detach() to detach from the computation graph
            # Tensor, shape (batch_size, num_label_classes)
            labels_probabilities = torch.softmax(logits, dim=-1).detach()
            labels_log_probabilities = torch.log_softmax(logits, dim=-1)
            # sqrt labels_probabilities, since torch.sqrt(labels_probabilities) would be squared in the following squared gradients
            labels_expectations = (
                torch.sqrt(labels_probabilities) * labels_log_probabilities
            )
            # sum over label classes and batch dimension
            sum_labels_expectations = labels_expectations.sum(dim=-1).sum(dim=0)
            model.zero_grad()
            sum_labels_expectations.backward()
            # dict, fisher weights of a batch
            batch_fisher_weights = get_param_squared_gradients(
                model=model, param_names_to_merge=param_names_to_merge
            )

            # move fisher weights to cpu to save GPU memory, then add them to the running sum
            for key, weights in batch_fisher_weights.items():
                weights = weights.detach().cpu()
                if key in model_to_merge_fisher_weights:
                    model_to_merge_fisher_weights[key] += weights
                else:
                    model_to_merge_fisher_weights[key] = weights

            num_computed_examples += batch["input_ids"].size(0)

        # mean over the processed examples
        for key in model_to_merge_fisher_weights:
            model_to_merge_fisher_weights[key] /= num_computed_examples
        return model_to_merge_fisher_weights
@@ -0,0 +1,6 @@
1
+ # flake8: noqa F401
2
+ from .expo import ExPOAlgorithm
3
+ from .linear_interpolation import LinearInterpolationAlgorithm
4
+ from .llama_expo import ExPOAlgorithmForLlama
5
+ from .simple_average_for_llama import SimpleAverageForLlama
6
+ from .task_arithmetic_for_llama import TaskArithmeticForLlama