fusion_bench-0.2.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (727)
  1. fusion_bench/__init__.py +20 -0
  2. fusion_bench/__main__.py +4 -0
  3. fusion_bench/compat/__init__.py +0 -0
  4. fusion_bench/compat/method/__init__.py +109 -0
  5. fusion_bench/compat/method/base_algorithm.py +58 -0
  6. fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
  7. fusion_bench/compat/modelpool/__init__.py +116 -0
  8. fusion_bench/compat/modelpool/base_pool.py +328 -0
  9. fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
  10. fusion_bench/compat/taskpool/__init__.py +95 -0
  11. fusion_bench/compat/taskpool/base_pool.py +111 -0
  12. fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
  13. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
  14. fusion_bench/constants/__init__.py +2 -0
  15. fusion_bench/constants/paths.py +18 -0
  16. fusion_bench/dataset/__init__.py +29 -0
  17. fusion_bench/dataset/arc_agi/__init__.py +6 -0
  18. fusion_bench/dataset/arc_agi/arc.py +308 -0
  19. fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
  20. fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
  21. fusion_bench/dataset/arc_agi/messagers.py +1355 -0
  22. fusion_bench/dataset/arc_agi/np_cache.py +168 -0
  23. fusion_bench/dataset/arc_agi/preprocess.py +298 -0
  24. fusion_bench/dataset/arc_agi/representers.py +1019 -0
  25. fusion_bench/dataset/clip_dataset.py +71 -0
  26. fusion_bench/dataset/fer2013.py +12 -0
  27. fusion_bench/dataset/gpt2_glue.py +300 -0
  28. fusion_bench/dataset/gsm8k.py +60 -0
  29. fusion_bench/dataset/image_dataset.py +55 -0
  30. fusion_bench/dataset/imdb.py +11 -0
  31. fusion_bench/dataset/llama/__init__.py +1 -0
  32. fusion_bench/dataset/llama/alpaca.py +232 -0
  33. fusion_bench/dataset/llama/collate.py +120 -0
  34. fusion_bench/dataset/llama/metamathqa.py +50 -0
  35. fusion_bench/dataset/llama/openai.py +160 -0
  36. fusion_bench/dataset/llama/preference_700k.py +70 -0
  37. fusion_bench/dataset/llama/sharegpt.py +141 -0
  38. fusion_bench/dataset/llama/squad.py +125 -0
  39. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  40. fusion_bench/dataset/llama/ultrachat.py +58 -0
  41. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  42. fusion_bench/dataset/llama/wikitext.py +89 -0
  43. fusion_bench/dataset/nyuv2.py +119 -0
  44. fusion_bench/method/__init__.py +177 -0
  45. fusion_bench/method/ada_svd/__init__.py +2 -0
  46. fusion_bench/method/ada_svd/clip_vision.py +319 -0
  47. fusion_bench/method/adamerging/__init__.py +6 -0
  48. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
  49. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
  50. fusion_bench/method/adamerging/entropy_loss.py +25 -0
  51. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
  52. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
  53. fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
  54. fusion_bench/method/adamerging/llama_adamerging.py +335 -0
  55. fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
  56. fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
  57. fusion_bench/method/adamerging/utils.py +15 -0
  58. fusion_bench/method/analysis/__init__.py +2 -0
  59. fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
  60. fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
  61. fusion_bench/method/base_algorithm.py +44 -0
  62. fusion_bench/method/classification/__init__.py +3 -0
  63. fusion_bench/method/classification/clip_finetune.py +444 -0
  64. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  65. fusion_bench/method/concrete_subspace/__init__.py +6 -0
  66. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
  67. fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
  68. fusion_bench/method/dare/__init__.py +4 -0
  69. fusion_bench/method/dare/simple_average.py +31 -0
  70. fusion_bench/method/dare/task_arithmetic.py +82 -0
  71. fusion_bench/method/dare/ties_merging.py +100 -0
  72. fusion_bench/method/dare/utils.py +87 -0
  73. fusion_bench/method/dawe/__init__.py +2 -0
  74. fusion_bench/method/dawe/dawe_for_clip.py +274 -0
  75. fusion_bench/method/dawe/warppers/__init__.py +13 -0
  76. fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
  77. fusion_bench/method/depth_upscaling/__init__.py +3 -0
  78. fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
  79. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
  80. fusion_bench/method/dummy.py +35 -0
  81. fusion_bench/method/ensemble.py +98 -0
  82. fusion_bench/method/fisher_merging/__init__.py +4 -0
  83. fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
  84. fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
  85. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
  86. fusion_bench/method/linear/__init__.py +6 -0
  87. fusion_bench/method/linear/expo.py +118 -0
  88. fusion_bench/method/linear/linear_interpolation.py +60 -0
  89. fusion_bench/method/linear/llama_expo.py +229 -0
  90. fusion_bench/method/linear/simple_average_for_llama.py +54 -0
  91. fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
  92. fusion_bench/method/lm_finetune/__init__.py +3 -0
  93. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  94. fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
  95. fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
  96. fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
  97. fusion_bench/method/mixture_of_experts/__init__.py +7 -0
  98. fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
  99. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
  100. fusion_bench/method/model_recombination.py +121 -0
  101. fusion_bench/method/opcm/__init__.py +4 -0
  102. fusion_bench/method/opcm/opcm.py +277 -0
  103. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  104. fusion_bench/method/opcm/ties_merging.py +156 -0
  105. fusion_bench/method/opcm/utils.py +73 -0
  106. fusion_bench/method/opcm/weight_average.py +120 -0
  107. fusion_bench/method/pruning/__init__.py +5 -0
  108. fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
  109. fusion_bench/method/pruning/llama_random_prune.py +143 -0
  110. fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
  111. fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
  112. fusion_bench/method/pruning/prune_utils.py +165 -0
  113. fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
  114. fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
  115. fusion_bench/method/pruning/wanda_utils/data.py +135 -0
  116. fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
  117. fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
  118. fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
  119. fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
  120. fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
  121. fusion_bench/method/pwe_moe/__init__.py +5 -0
  122. fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
  123. fusion_bench/method/pwe_moe/module.py +316 -0
  124. fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
  125. fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
  126. fusion_bench/method/pwe_moe/utils.py +43 -0
  127. fusion_bench/method/rankone_moe/__init__.py +3 -0
  128. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  129. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  130. fusion_bench/method/regmean/__init__.py +4 -0
  131. fusion_bench/method/regmean/clip_regmean.py +131 -0
  132. fusion_bench/method/regmean/gpt2_regmean.py +147 -0
  133. fusion_bench/method/regmean/regmean.py +375 -0
  134. fusion_bench/method/simple_average.py +112 -0
  135. fusion_bench/method/slerp/__init__.py +2 -0
  136. fusion_bench/method/slerp/slerp.py +101 -0
  137. fusion_bench/method/slerp/slerp_utils.py +107 -0
  138. fusion_bench/method/smile_upscaling/__init__.py +3 -0
  139. fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
  140. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
  141. fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
  142. fusion_bench/method/sparse_we_moe/__init__.py +2 -0
  143. fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
  144. fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
  145. fusion_bench/method/sparselo/__init__.py +2 -0
  146. fusion_bench/method/sparselo/sparselo.py +955 -0
  147. fusion_bench/method/surgery/__init__.py +1 -0
  148. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  149. fusion_bench/method/tall_mask/__init__.py +0 -0
  150. fusion_bench/method/tall_mask/utils.py +234 -0
  151. fusion_bench/method/task_arithmetic/__init__.py +2 -0
  152. fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
  153. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  154. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  155. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  156. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  157. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
  158. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  159. fusion_bench/method/ties_merging/__init__.py +2 -0
  160. fusion_bench/method/ties_merging/ties_merging.py +117 -0
  161. fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
  162. fusion_bench/method/trust_region/__init__.py +2 -0
  163. fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
  164. fusion_bench/method/trust_region/utils.py +58 -0
  165. fusion_bench/method/we_moe/__init__.py +2 -0
  166. fusion_bench/method/we_moe/clip_we_moe.py +161 -0
  167. fusion_bench/method/we_moe/we_moe.py +247 -0
  168. fusion_bench/method/weighted_average/__init__.py +3 -0
  169. fusion_bench/method/weighted_average/llama.py +113 -0
  170. fusion_bench/method/weighted_average/weighted_average.py +102 -0
  171. fusion_bench/metrics/__init__.py +0 -0
  172. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  173. fusion_bench/metrics/nyuv2/__init__.py +11 -0
  174. fusion_bench/metrics/nyuv2/depth.py +45 -0
  175. fusion_bench/metrics/nyuv2/loss.py +31 -0
  176. fusion_bench/metrics/nyuv2/noise.py +16 -0
  177. fusion_bench/metrics/nyuv2/normal.py +48 -0
  178. fusion_bench/metrics/nyuv2/segmentation.py +43 -0
  179. fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
  180. fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
  181. fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
  182. fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
  183. fusion_bench/mixins/__init__.py +28 -0
  184. fusion_bench/mixins/clip_classification.py +252 -0
  185. fusion_bench/mixins/fabric_training.py +320 -0
  186. fusion_bench/mixins/lightning_fabric.py +174 -0
  187. fusion_bench/mixins/optim/__init__.py +0 -0
  188. fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
  189. fusion_bench/mixins/rich_live.py +21 -0
  190. fusion_bench/mixins/serialization.py +132 -0
  191. fusion_bench/mixins/simple_profiler.py +79 -0
  192. fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
  193. fusion_bench/modelpool/__init__.py +42 -0
  194. fusion_bench/modelpool/base_pool.py +268 -0
  195. fusion_bench/modelpool/causal_lm/__init__.py +2 -0
  196. fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
  197. fusion_bench/modelpool/clip_vision/__init__.py +1 -0
  198. fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
  199. fusion_bench/modelpool/huggingface_automodel.py +20 -0
  200. fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
  201. fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
  202. fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
  203. fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
  204. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  205. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  206. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  207. fusion_bench/models/__init__.py +3 -0
  208. fusion_bench/models/chat_templates/__init__.py +1 -0
  209. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  210. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  211. fusion_bench/models/hf_clip.py +199 -0
  212. fusion_bench/models/linearized/__init__.py +0 -0
  213. fusion_bench/models/linearized/linearized_model_utils.py +91 -0
  214. fusion_bench/models/linearized/vision_model.py +122 -0
  215. fusion_bench/models/llama/__init__.py +16 -0
  216. fusion_bench/models/llama/model_utils/__init__.py +0 -0
  217. fusion_bench/models/llama/model_utils/embedding.py +87 -0
  218. fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
  219. fusion_bench/models/llama/model_utils/misc.py +112 -0
  220. fusion_bench/models/llama/model_utils/mod.py +52 -0
  221. fusion_bench/models/llama/model_utils/visual.py +241 -0
  222. fusion_bench/models/llama/patcher.py +78 -0
  223. fusion_bench/models/llama/tokenizer_loader.py +153 -0
  224. fusion_bench/models/masks/__init__.py +2 -0
  225. fusion_bench/models/masks/mask_model.py +160 -0
  226. fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
  227. fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
  228. fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
  229. fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
  230. fusion_bench/models/modeling_losparse_llama/register.py +8 -0
  231. fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
  232. fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
  233. fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
  234. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
  235. fusion_bench/models/modeling_smile_mistral/register.py +8 -0
  236. fusion_bench/models/nyuv2/__init__.py +0 -0
  237. fusion_bench/models/nyuv2/aspp.py +82 -0
  238. fusion_bench/models/nyuv2/lightning_module.py +176 -0
  239. fusion_bench/models/nyuv2/resnet.py +405 -0
  240. fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
  241. fusion_bench/models/parameter_dict.py +75 -0
  242. fusion_bench/models/rankone_moe.py +410 -0
  243. fusion_bench/models/separate_io.py +105 -0
  244. fusion_bench/models/smile_moe/__init__.py +0 -0
  245. fusion_bench/models/smile_moe/linear.py +256 -0
  246. fusion_bench/models/sparse_we_moe.py +459 -0
  247. fusion_bench/models/surgery/__init__.py +1 -0
  248. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  249. fusion_bench/models/utils.py +80 -0
  250. fusion_bench/models/we_moe.py +247 -0
  251. fusion_bench/models/wrappers/__init__.py +0 -0
  252. fusion_bench/models/wrappers/ensemble.py +183 -0
  253. fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
  254. fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
  255. fusion_bench/optim/__init__.py +2 -0
  256. fusion_bench/optim/exception.py +47 -0
  257. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  258. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  259. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  260. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  261. fusion_bench/optim/mezo.py +118 -0
  262. fusion_bench/programs/__init__.py +20 -0
  263. fusion_bench/programs/base_program.py +9 -0
  264. fusion_bench/programs/fabric_fusion_program.py +299 -0
  265. fusion_bench/scripts/__init__.py +0 -0
  266. fusion_bench/scripts/cli.py +43 -0
  267. fusion_bench/scripts/clip/__init__.py +0 -0
  268. fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
  269. fusion_bench/scripts/imgui.py +218 -0
  270. fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
  271. fusion_bench/scripts/webui.py +405 -0
  272. fusion_bench/taskpool/__init__.py +39 -0
  273. fusion_bench/taskpool/base_pool.py +35 -0
  274. fusion_bench/taskpool/clip_vision/__init__.py +4 -0
  275. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  276. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
  277. fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
  278. fusion_bench/taskpool/dummy.py +58 -0
  279. fusion_bench/taskpool/gpt2_text_classification.py +149 -0
  280. fusion_bench/taskpool/llama/__init__.py +1 -0
  281. fusion_bench/taskpool/llama/reward_model.py +157 -0
  282. fusion_bench/taskpool/llama/test_generation.py +185 -0
  283. fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
  284. fusion_bench/tasks/__init__.py +2 -0
  285. fusion_bench/tasks/base_task.py +18 -0
  286. fusion_bench/tasks/classification.py +75 -0
  287. fusion_bench/tasks/clip_classification/__init__.py +183 -0
  288. fusion_bench/tasks/clip_classification/cifar10.py +33 -0
  289. fusion_bench/tasks/clip_classification/cifar100.py +146 -0
  290. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
  291. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  292. fusion_bench/tasks/clip_classification/dtd.py +60 -0
  293. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  294. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  295. fusion_bench/tasks/clip_classification/eurosat.py +18 -0
  296. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  297. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  298. fusion_bench/tasks/clip_classification/flower102.py +106 -0
  299. fusion_bench/tasks/clip_classification/food101.py +105 -0
  300. fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
  301. fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
  302. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  303. fusion_bench/tasks/clip_classification/mnist.py +5 -0
  304. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  305. fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
  306. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  307. fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
  308. fusion_bench/tasks/clip_classification/resisc45.py +68 -0
  309. fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
  310. fusion_bench/tasks/clip_classification/stl10.py +17 -0
  311. fusion_bench/tasks/clip_classification/sun397.py +404 -0
  312. fusion_bench/tasks/clip_classification/svhn.py +5 -0
  313. fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
  314. fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
  315. fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
  316. fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
  317. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
  318. fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
  319. fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
  320. fusion_bench/utils/__init__.py +14 -0
  321. fusion_bench/utils/auto.py +31 -0
  322. fusion_bench/utils/cache_utils.py +58 -0
  323. fusion_bench/utils/data.py +165 -0
  324. fusion_bench/utils/devices.py +231 -0
  325. fusion_bench/utils/dict.py +43 -0
  326. fusion_bench/utils/dtype.py +146 -0
  327. fusion_bench/utils/expr.py +90 -0
  328. fusion_bench/utils/fabric.py +17 -0
  329. fusion_bench/utils/functools.py +37 -0
  330. fusion_bench/utils/hydra_utils.py +28 -0
  331. fusion_bench/utils/instantiate.py +450 -0
  332. fusion_bench/utils/json.py +93 -0
  333. fusion_bench/utils/lazy_imports.py +74 -0
  334. fusion_bench/utils/misc.py +18 -0
  335. fusion_bench/utils/packages.py +84 -0
  336. fusion_bench/utils/parameters.py +323 -0
  337. fusion_bench/utils/path.py +22 -0
  338. fusion_bench/utils/plot/__init__.py +0 -0
  339. fusion_bench/utils/plot/color_data.py +1726 -0
  340. fusion_bench/utils/plot/token.py +52 -0
  341. fusion_bench/utils/plot/token_notebook.py +127 -0
  342. fusion_bench/utils/pylogger.py +55 -0
  343. fusion_bench/utils/rich_utils.py +201 -0
  344. fusion_bench/utils/set.py +8 -0
  345. fusion_bench/utils/state_dict_arithmetic.py +297 -0
  346. fusion_bench/utils/strenum/__init__.py +326 -0
  347. fusion_bench/utils/strenum/_name_mangler.py +127 -0
  348. fusion_bench/utils/strenum/_version.py +556 -0
  349. fusion_bench/utils/tensorboard.py +51 -0
  350. fusion_bench/utils/timer.py +49 -0
  351. fusion_bench/utils/type.py +34 -0
  352. fusion_bench-0.2.9.dist-info/LICENSE +21 -0
  353. fusion_bench-0.2.9.dist-info/METADATA +258 -0
  354. fusion_bench-0.2.9.dist-info/RECORD +727 -0
  355. fusion_bench-0.2.9.dist-info/WHEEL +5 -0
  356. fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
  357. fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
  358. fusion_bench_config/README.md +12 -0
  359. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
  360. fusion_bench_config/dataset/image_classification/README.md +6 -0
  361. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  362. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  363. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
  364. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
  365. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  366. fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
  367. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  368. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  369. fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
  370. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  371. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  372. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  373. fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
  374. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  375. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  376. fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
  377. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  378. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  379. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  380. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  381. fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
  382. fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
  383. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  384. fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
  385. fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
  386. fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
  387. fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
  388. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  389. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  390. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
  391. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
  392. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  393. fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
  394. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  395. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  396. fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
  397. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  398. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  399. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  400. fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
  401. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  402. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  403. fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
  404. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  405. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  406. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  407. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  408. fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
  409. fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
  410. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  411. fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
  412. fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
  413. fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
  414. fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
  415. fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
  416. fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
  417. fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
  418. fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
  419. fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
  420. fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
  421. fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
  422. fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
  423. fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
  424. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  425. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  426. fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
  427. fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
  428. fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
  429. fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
  430. fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
  431. fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
  432. fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
  433. fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
  434. fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
  435. fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
  436. fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
  437. fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
  438. fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
  439. fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
  440. fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
  441. fusion_bench_config/fabric/auto.yaml +16 -0
  442. fusion_bench_config/fabric/llama_ddp.yaml +18 -0
  443. fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
  444. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  445. fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
  446. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
  447. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  448. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  449. fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
  450. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  451. fusion_bench_config/fabric_model_fusion.yaml +20 -0
  452. fusion_bench_config/hydra/default.yaml +8 -0
  453. fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
  454. fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
  455. fusion_bench_config/llama_full_finetune.yaml +19 -0
  456. fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
  457. fusion_bench_config/llama_model_fusion.yaml +17 -0
  458. fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
  459. fusion_bench_config/method/adamerging/clip.yaml +23 -0
  460. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
  461. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
  462. fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
  463. fusion_bench_config/method/adamerging.yaml +23 -0
  464. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
  465. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
  466. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  467. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  468. fusion_bench_config/method/clip_finetune.yaml +26 -0
  469. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
  470. fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
  471. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
  472. fusion_bench_config/method/dare/simple_average.yaml +5 -0
  473. fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
  474. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  475. fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
  476. fusion_bench_config/method/depth_upscaling.yaml +5 -0
  477. fusion_bench_config/method/dummy.yaml +1 -0
  478. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
  479. fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
  480. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
  481. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
  482. fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
  483. fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
  484. fusion_bench_config/method/linear/expo.yaml +8 -0
  485. fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
  486. fusion_bench_config/method/linear/llama_expo.yaml +19 -0
  487. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
  488. fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
  489. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
  490. fusion_bench_config/method/linear/weighted_average.yaml +6 -0
  491. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
  492. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  493. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
  494. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
  495. fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
  496. fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
  497. fusion_bench_config/method/model_recombination.yaml +4 -0
  498. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  499. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  500. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  501. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  502. fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
  503. fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
  504. fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
  505. fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
  506. fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
  507. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  508. fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
  509. fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
  510. fusion_bench_config/method/regmean/regmean.yaml +4 -0
  511. fusion_bench_config/method/simple_average.yaml +1 -0
  512. fusion_bench_config/method/slerp/slerp.yaml +6 -0
  513. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
  514. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
  515. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
  516. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
  517. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
  518. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
  519. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  520. fusion_bench_config/method/task_arithmetic.yaml +2 -0
  521. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  522. fusion_bench_config/method/ties_merging.yaml +8 -0
  523. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
  524. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
  525. fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
  526. fusion_bench_config/model/clip-vit/README.md +38 -0
  527. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
  528. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  529. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  530. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  531. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  532. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
  533. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
  534. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  535. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
  536. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  537. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  538. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  539. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
  540. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  541. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
  542. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  543. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  544. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  545. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  546. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
  547. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
  548. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  549. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
  550. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
  551. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
  552. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  553. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  554. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  555. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  556. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
  557. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
  558. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  559. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
  560. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  561. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  562. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  563. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
  564. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  565. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
  566. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  567. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  568. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  569. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  570. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
  571. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
  572. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  573. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
  574. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
  575. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
  576. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  577. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  578. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  579. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  580. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
  581. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
  582. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  583. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
  584. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  585. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  586. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  587. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
  588. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  589. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
  590. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  591. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  592. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  593. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  594. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
  595. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
  596. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  597. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
  598. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
  599. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  600. fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
  601. fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
  602. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
  603. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
  604. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
  605. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
  606. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
  607. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
  608. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
  609. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
  610. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
  611. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
  612. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
  613. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
  614. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
  615. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
  616. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
  617. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
  618. fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
  619. fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
  620. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
  621. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
  622. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
  623. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
  624. fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
  625. fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
  626. fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
  627. fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
  628. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
  629. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
  630. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
  631. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  632. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  633. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  634. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  635. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  636. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
  637. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
  638. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
  639. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
  640. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
  641. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  642. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  643. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  644. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  645. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
  646. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
  647. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
  648. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
  649. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
  650. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
  651. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
  652. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  653. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
  654. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  655. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
  656. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
  657. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  658. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  659. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  660. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  661. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
  662. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  663. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  664. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
  665. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  666. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  667. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
  668. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
  669. fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
  670. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
  671. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
  672. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
  673. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
  674. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
  675. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  676. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  677. fusion_bench_config/modelpool/automodelpool.yaml +12 -0
  678. fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
  679. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
  680. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
  681. fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
  682. fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
  683. fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
  684. fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
  685. fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
  686. fusion_bench_config/nyuv2_config.yaml +17 -0
  687. fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
  688. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
  689. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  690. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
  691. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
  692. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
  693. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
  694. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
  695. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  696. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  697. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  698. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  699. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  700. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  701. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  702. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  703. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  704. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  705. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  706. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  707. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  708. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  709. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  710. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  711. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  712. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  713. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  714. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  715. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  716. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  717. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  718. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  719. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
  720. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
  721. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  722. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
  723. fusion_bench_config/taskpool/dummy.yaml +2 -0
  724. fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
  725. fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
  726. fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
  727. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
@@ -0,0 +1,117 @@
1
+ R"""
2
+ Overview of Ties-Merging:
3
+
4
+ 1. Trim: For each task t, we trim the redundant parameters from the task vector $\tau_t$ to create $\hat{\tau}_t$ by keeping the top-k% values according to their magnitude and trimming the bottom $(100 - k)\%$ of the redundant parameters by resetting them to 0. This can be decomposed further as $\hat{\tau}_t = \hat{\gamma}_t \odot \hat{\mu}_t$.
5
+
6
+ 2. Elect: Next, we create an aggregate elected sign vector $\gamma_m$ for the merged model that resolves the disagreements in the sign for each parameter p across different models. To create the elected sign vector, we choose the sign with the highest total magnitude across all relevant models. For each parameter $p \in \{1, 2, \ldots, d\}$, we separate the values $\{\hat{\tau}_t^p\}_{t=1}^n$ based on their sign $(+1$ or $-1)$ and take their sum to calculate the total mass (i.e., total magnitude) in the positive and the negative direction. We then assign $\gamma_m^p$ as the sign with greater total movement. This can be efficiently computed using $\gamma_m^p = \text{sgn}(\sum_{t=1}^n \hat{\tau}_t^p)$.
7
+
8
+ 3. Disjoint Merge: Then, for each parameter p, we compute a disjoint mean by only keeping the parameter values from the models whose signs are the same as the aggregated elected sign and calculate their mean. Formally, let $A_p = \{t \in [n] \mid \hat{\gamma}_t^p = \gamma_m^p\}$, then $\tau_m^p = \frac{1}{|A_p|}\sum_{t\in A_p} \hat{\tau}_t^p$. Note that the disjoint mean always ignores the zero values.
9
+ """
+ 
+ import logging
+ from typing import Dict, List, Literal, Mapping, Union  # noqa: F401
+ 
+ import torch
+ from torch import Tensor, nn
+ 
+ from fusion_bench.compat.modelpool import to_modelpool
+ from fusion_bench.method import BaseAlgorithm
+ from fusion_bench.modelpool import BaseModelPool
+ from fusion_bench.utils.type import StateDictType
+ 
+ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_state_dict
+ 
+ log = logging.getLogger(__name__)
+ 
+ 
+ class TiesMergingAlgorithm(BaseAlgorithm):
+     """
+     TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+ 
+     Attributes:
+         scaling_factor (float): The scaling factor to apply to the merged task vector.
+         threshold (float): The threshold for resetting values in the task vector.
+         remove_keys (List[str]): List of keys to remove from the state dictionary.
+         merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
+     """
+ 
+     _config_mapping = BaseAlgorithm._config_mapping | {
+         "scaling_factor": "scaling_factor",
+         "threshold": "threshold",
+         "remove_keys": "remove_keys",
+         "merge_func": "merge_func",
+     }
+ 
+     def __init__(
+         self,
+         scaling_factor: float,
+         threshold: float,
+         remove_keys: List[str],
+         merge_func: Literal["sum", "mean", "max"],
+         **kwargs,
+     ):
+         """
+         Initialize the TiesMergingAlgorithm with the given parameters.
+ 
+         Args:
+             scaling_factor (float): The scaling factor to apply to the merged task vector.
+             threshold (float): The threshold for resetting values in the task vector.
+             remove_keys (List[str]): List of keys to remove from the state dictionary.
+             merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
+             **kwargs: Additional keyword arguments for the base class.
+         """
+         self.scaling_factor = scaling_factor
+         self.threshold = threshold
+         self.remove_keys = remove_keys
+         self.merge_func = merge_func
+         super().__init__(**kwargs)
+ 
+     @torch.no_grad()
+     def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
+         """
+         Run the TIES merging algorithm to fuse models in the model pool.
+ 
+         Args:
+             modelpool (BaseModelPool | Dict[str, nn.Module]): The model pool containing the models to fuse.
+ 
+         Returns:
+             nn.Module: The fused model.
+         """
+         log.info("Fusing models using ties merging.")
+         modelpool = to_modelpool(modelpool)
+         remove_keys = self.config.get("remove_keys", [])
+         merge_func = self.config.get("merge_func", "sum")
+         scaling_factor = self.scaling_factor
+         threshold = self.threshold
+ 
+         # Load the pretrained model
+         pretrained_model = modelpool.load_model("_pretrained_")
+ 
+         # Load the state dicts of the models
+         ft_checks: List[StateDictType] = [
+             modelpool.load_model(model_name).state_dict(keep_vars=True)
+             for model_name in modelpool.model_names
+         ]
+         ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
+ 
+         # Compute the task vectors
+         flat_ft: Tensor = torch.vstack(
+             [state_dict_to_vector(check, remove_keys) for check in ft_checks]
+         )
+         flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
+         tv_flat_checks = flat_ft - flat_ptm
+ 
+         # Perform TIES Merging
+         merged_tv = ties_merging(
+             tv_flat_checks,
+             reset_thresh=threshold,
+             merge_func=merge_func,
+         )
+         merged_check = flat_ptm + scaling_factor * merged_tv
+         merged_state_dict = vector_to_state_dict(
+             merged_check, ptm_check, remove_keys=remove_keys
+         )
+ 
+         # Load the merged state dict into the pretrained model
+         pretrained_model.load_state_dict(merged_state_dict)
+         return pretrained_model
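For orientation, here is a minimal programmatic sketch of the class above. It assumes that `TiesMergingAlgorithm` is importable from `fusion_bench.method` and that `to_modelpool` accepts a plain `dict` of `nn.Module` objects keyed by model name, with the base model stored under `"_pretrained_"` (the key that `run` loads); the tiny `nn.Linear` models merely stand in for real fine-tuned checkpoints.

```python
import torch
from torch import nn

# Assumed import path; adjust to wherever TiesMergingAlgorithm lives in your install.
from fusion_bench.method import TiesMergingAlgorithm


def make_model(seed: int) -> nn.Module:
    """Toy stand-in for a pretrained/fine-tuned checkpoint."""
    torch.manual_seed(seed)
    return nn.Linear(4, 2)


models = {
    "_pretrained_": make_model(0),  # base model, loaded via modelpool.load_model("_pretrained_")
    "task_a": make_model(1),
    "task_b": make_model(2),
}

algorithm = TiesMergingAlgorithm(
    scaling_factor=0.5,  # lambda applied to the merged task vector
    threshold=20,        # keep roughly the top-20% largest-magnitude entries per task vector
    remove_keys=[],      # no keys excluded from flattening
    merge_func="sum",    # disjoint aggregation: "sum", "mean", or "max"
)
merged_model = algorithm.run(models)  # returns the base model with the merged weights loaded
```

In practice the algorithm is typically driven through fusion-bench's configuration-based CLI rather than called directly, so treat this purely as an illustration of the `run` signature.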
@@ -0,0 +1,331 @@
+ """
+ This is modified based on https://github.com/EnnengYang/AdaMerging/blob/main/src/ties_merging_utils.py
+ """
+ 
+ import copy
+ from collections import OrderedDict
+ from typing import List
+ 
+ import torch
+ from torch import Tensor, nn
+ 
+ from fusion_bench.utils.type import StateDictType
+ 
+ 
+ # Model conversion utils
+ def state_dict_to_vector(state_dict, remove_keys=[]):
+     """
+     Convert a state dictionary to a vector, removing specified keys.
+ 
+     Args:
+         state_dict (dict): The state dictionary to convert.
+         remove_keys (list): List of keys to remove from the state dictionary.
+ 
+     Returns:
+         Tensor: A vector representation of the state dictionary.
+     """
+     shared_state_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in shared_state_dict:
+             del shared_state_dict[key]
+     sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+     return nn.utils.parameters_to_vector(
+         [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+     )
+ 
+ 
+ def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+     """
+     Convert a vector back to a state dictionary, removing specified keys.
+ 
+     Args:
+         vector (Tensor): The vector to convert.
+         state_dict (dict): The reference state dictionary.
+         remove_keys (list): List of keys to remove from the state dictionary.
+ 
+     Returns:
+         dict: A state dictionary representation of the vector.
+     """
+     # create a reference dict to define the order of the vector
+     reference_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in reference_dict:
+             del reference_dict[key]
+     sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+ 
+     # create a shared state dict using the reference dict
+     nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+ 
+     # add back the encoder and decoder embedding weights.
+     if "transformer.shared.weight" in sorted_reference_dict:
+         for key in remove_keys:
+             sorted_reference_dict[key] = sorted_reference_dict[
+                 "transformer.shared.weight"
+             ]
+     return sorted_reference_dict
+ 
+ 
+ def add_ptm_to_tv(tv_dict, ptm_dict):
+     """
+     Add the values of one state dictionary to another.
+ 
+     Args:
+         tv_dict (dict): The target state dictionary.
+         ptm_dict (dict): The state dictionary to add.
+ 
+     Returns:
+         dict: The resulting state dictionary after addition.
+     """
+     assert set(tv_dict.keys()) == set(
+         ptm_dict.keys()
+     ), "Differing parameter names in models."
+     final_dict = copy.deepcopy(tv_dict)
+     for k, v in ptm_dict.items():
+         final_dict[k] = tv_dict[k] + v
+     return final_dict
+ 
+ 
+ def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
+     """
+     Check if the parameter names match across multiple checkpoints.
+ 
+     Args:
+         checkpoints (list): List of state dictionaries to check.
+ 
+     Raises:
+         ValueError: If the parameter names do not match.
+     """
+     parameter_names = set(checkpoints[0].keys())
+ 
+     if len(checkpoints) >= 2:
+         # raise ValueError("Number of models is less than 2.")
+         for checkpoint in checkpoints[1:]:
+             current_parameterNames = set(checkpoint.keys())
+             if current_parameterNames != parameter_names:
+                 raise ValueError(
+                     "Differing parameter names in models. "
+                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                 )
+ 
+ 
+ def check_state_dicts_equal(
+     state_dict1: StateDictType, state_dict2: StateDictType
+ ) -> bool:
+     """
+     Check if two state dictionaries are equal.
+ 
+     Args:
+         state_dict1 (dict): The first state dictionary.
+         state_dict2 (dict): The second state dictionary.
+ 
+     Returns:
+         bool: True if the state dictionaries are equal, False otherwise.
+     """
+     if set(state_dict1.keys()) != set(state_dict2.keys()):
+         return False
+ 
+     for key in state_dict1.keys():
+         if not torch.equal(state_dict1[key], state_dict2[key]):
+             return False
+ 
+     return True
+ 
+ 
+ # TIES MERGING UTILS
+ 
+ 
+ def topk_values_mask(M, K=0.7, return_mask=False):
+     """
+     Mask the top K values in a tensor.
+ 
+     Args:
+         M (Tensor): The input tensor.
+         K (float): The proportion of top values to keep.
+         return_mask (bool): Whether to return the mask tensor.
+ 
+     Returns:
+         tuple: The masked tensor, the mean of the mask, and optionally the mask tensor.
+     """
+     if K > 1:
+         K /= 100
+ 
+     original_shape = M.shape
+     if M.dim() == 1:
+         M = M.unsqueeze(0)
+ 
+     n, d = M.shape
+     k = int(d * K)
+     k = d - k  # Keep top k elements instead of bottom k elements
+ 
+     # Find the k-th smallest element by magnitude for each row
+     kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
+     # Create a mask tensor with True for the top k elements in each row
+     mask = M.abs() >= kth_values
+     final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
+ 
+     if return_mask:
+         return M * final_mask, final_mask.float().mean(dim=1), final_mask
+     return M * final_mask, final_mask.float().mean(dim=1)
+ 
+ 
+ def resolve_zero_signs(sign_to_mult, method="majority"):
+     """
+     Resolve zero signs in a tensor by majority or minority rule.
+ 
+     Args:
+         sign_to_mult (Tensor): The tensor with signs to resolve.
+         method (str): The method to use for resolving zero signs ("majority" or "minority").
+ 
+     Returns:
+         Tensor: The tensor with resolved signs.
+     """
+     majority_sign = torch.sign(sign_to_mult.sum())
+ 
+     if method == "majority":
+         sign_to_mult[sign_to_mult == 0] = majority_sign
+     elif method == "minority":
+         sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
+     return sign_to_mult
+ 
+ 
+ def resolve_sign(v: Tensor):
+     """
+     Resolve the sign of a tensor by majority rule.
+ 
+     Args:
+         v (Tensor): The input tensor.
+ 
+     Returns:
+         Tensor: The tensor with resolved signs.
+     """
+     sign_to_mult = torch.sign(v.sum(dim=0))
+     sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
+     return sign_to_mult
+ 
+ 
+ def disjoint_merge(v: Tensor, merge_func: str, sign_to_mult):
+     """
+     Perform disjoint merging of a tensor using a specified merge function.
+ 
+     Args:
+         v (Tensor): The input tensor.
+         merge_func (str): The merge function to use ("mean", "sum", or "max").
+         sign_to_mult (Tensor): The tensor with signs to use for merging.
+ 
+     Returns:
+         Tensor: The merged tensor.
+     """
+     merge_func = merge_func.split("-")[-1]
+ 
+     # If sign is provided then we select the corresponding entries and aggregate.
+     if sign_to_mult is not None:
+         rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
+         selected_entries = v * rows_to_keep
+     # Else we select all non-zero entries and aggregate.
+     else:
+         rows_to_keep = v != 0
+         selected_entries = v * rows_to_keep
+ 
+     if merge_func == "mean":
+         non_zero_counts = (selected_entries != 0).sum(dim=0).float()
+         disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
+             non_zero_counts, min=1
+         )
+     elif merge_func == "sum":
+         disjoint_aggs = torch.sum(selected_entries, dim=0)
+     elif merge_func == "max":
+         disjoint_aggs = selected_entries.abs().max(dim=0)[0]
+         disjoint_aggs *= sign_to_mult
+     else:
+         raise ValueError(f"Merge method {merge_func} is not defined.")
+ 
+     return disjoint_aggs
+ 
+ 
+ def ties_merging(
+     flat_task_checks,
+     reset_thresh=None,
+     merge_func="",
+ ):
+     """
+     Perform TIES merging on a tensor.
+ 
+     Args:
+         flat_task_checks (Tensor): The input tensor.
+         reset_thresh (float): The threshold for resetting values.
+         merge_func (str): The merge function to use.
+ 
+     Returns:
+         Tensor: The merged tensor.
+     """
+     all_checks = flat_task_checks.clone()
+     updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
+     print("RESOLVING SIGN")
+     final_signs = resolve_sign(updated_checks)
+     assert final_signs is not None
+ 
+     print(f"Disjoint AGGREGATION: {merge_func}")
+     merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)
+ 
+     return merged_tv
+ 
+ 
+ def disjoint_merge_split(v: Tensor, merge_func: str, sign_to_mult):
+     """
+     Perform disjoint merging of a tensor using a specified merge function and return selected entries.
+ 
+     Args:
+         v (Tensor): The input tensor.
+         merge_func (str): The merge function to use ("sum").
+         sign_to_mult (Tensor): The tensor with signs to use for merging.
+ 
+     Returns:
+         tuple: The selected entries and the merged tensor.
+     """
+     merge_func = merge_func.split("-")[-1]
+ 
+     # If sign is provided then we select the corresponding entries and aggregate.
+     if sign_to_mult is not None:
+         rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
+         selected_entries = v * rows_to_keep
+     # Else we select all non-zero entries and aggregate.
+     else:
+         rows_to_keep = v != 0
+         selected_entries = v * rows_to_keep
+ 
+     if merge_func == "sum":
+         disjoint_aggs = torch.sum(selected_entries, dim=0)
+     else:
+         raise ValueError(f"Merge method {merge_func} is not defined.")
+ 
+     return selected_entries, disjoint_aggs
+ 
+ 
+ def ties_merging_split(
+     flat_task_checks,
+     reset_thresh=None,
+     merge_func: str = "",
+ ):
+     """
+     Perform TIES merging on a tensor and return selected entries.
+ 
+     Args:
+         flat_task_checks (Tensor): The input tensor.
+         reset_thresh (float): The threshold for resetting values.
+         merge_func (str): The merge function to use.
+ 
+     Returns:
+         tuple: The selected entries and the merged tensor.
+     """
+     all_checks = flat_task_checks.clone()
+     updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
+     print("RESOLVING SIGN")
+     final_signs = resolve_sign(updated_checks)
+     assert final_signs is not None
+ 
+     print(f"Disjoint AGGREGATION: {merge_func}")
+     selected_entries, merged_tv = disjoint_merge_split(
+         updated_checks, merge_func, final_signs
+     )
+ 
+     return selected_entries, merged_tv
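To see how these utilities compose, here is a small end-to-end sketch on made-up flattened task vectors. The import path is an assumption (the utilities are addressed as `.ties_merging_utils` relative to the algorithm module above); adjust it to wherever the module lives in your install.

```python
import torch

# Assumed import path for the utilities shown above; adjust if your install differs.
from fusion_bench.method.ties_merging.ties_merging_utils import ties_merging

# Three flattened task vectors (rows) over six shared parameters (columns); made-up numbers.
task_vectors = torch.tensor(
    [
        [ 0.9, -0.1,  0.4,  0.0, -0.8,  0.2],
        [ 0.7,  0.3, -0.5,  0.1, -0.9, -0.2],
        [-0.6,  0.2,  0.3,  0.0, -0.7,  0.1],
    ]
)

# Trim each row to its largest-magnitude entries (reset_thresh plays the role of K in
# topk_values_mask), elect a per-parameter sign, then aggregate the sign-consistent entries.
merged_tv = ties_merging(task_vectors, reset_thresh=0.5, merge_func="sum")
print(merged_tv)

# "mean" averages the surviving entries instead of summing them.
merged_tv_mean = ties_merging(task_vectors, reset_thresh=0.5, merge_func="mean")
```

In `TiesMergingAlgorithm.run` above, the resulting `merged_tv` is scaled by `scaling_factor` and added back to the flattened pretrained weights before `vector_to_state_dict` restores the original parameter layout.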
@@ -0,0 +1,2 @@
+ # flake8: noqa F401
+ from .clip_task_arithmetic import TaskArithmeticWithTrustRegionForCLIP
@@ -0,0 +1,205 @@
+ """
+ Implementation of Task Arithmetic in Trust Region: A Training-Free Model Merging Approach to Navigate Knowledge Conflicts
+ https://openreview.net/forum?id=q3ztjJRQuJ
+ """
+ 
+ import logging
+ from collections import defaultdict
+ from copy import deepcopy
+ from typing import Dict, Iterable, List, Union
+ 
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+ from typing_extensions import override
+ 
+ from fusion_bench import BaseAlgorithm, BaseModelPool
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
+ from fusion_bench.mixins import CLIPClassificationMixin, SimpleProfilerMixin
+ from fusion_bench.utils import first
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+ from fusion_bench.utils.type import StateDictType
+ 
+ from .utils import state_dict_to_vector, vector_to_state_dict
+ 
+ log = logging.getLogger(__name__)
+ 
+ 
+ def trainable_state_dict(module: nn.Module) -> StateDictType:
+     """
+     Returns the state dictionary of the module containing only the trainable parameters.
+ 
+     Args:
+         module (nn.Module): The neural network module.
+ 
+     Returns:
+         Dict[str, Tensor]: A dictionary containing the names and values of the trainable parameters.
+     """
+     return {
+         name: param for name, param in module.named_parameters() if param.requires_grad
+     }
+ 
+ 
+ class TaskArithmeticWithTrustRegionForCLIP(
+     BaseAlgorithm,
+     SimpleProfilerMixin,
+     CLIPClassificationMixin,
+ ):
+     def __init__(
+         self,
+         scaling_factor: Union[float, List[float]],
+         threshold_quantile: float,
+         max_samples: int,
+         batch_size: int,
+         zero_shot: bool,
+         **kwargs,
+     ):
+         self.scaling_factor = scaling_factor
+         self.threshold_quantile = threshold_quantile
+         self.max_samples = max_samples
+         self.batch_size = batch_size
+         self.zero_shot = zero_shot
+         super().__init__(**kwargs)
+ 
+     @override
+     def run(self, modelpool: BaseModelPool):
+         self.modelpool = modelpool
+ 
+         # compute the task vectors
+         pretrained_model, task_vectors = self.compute_vanilla_task_vectors()
+         task_vectors = {
+             name: state_dict_to_vector(task_vector)
+             for name, task_vector in task_vectors.items()
+         }
+ 
+         if not self.zero_shot:
+             all_avg_abs_grads = self.compute_avg_abs_grads(pretrained_model)
+             all_avg_abs_grads = {
+                 n: state_dict_to_vector(grad) for n, grad in all_avg_abs_grads.items()
+             }
+         else:
+             # the task vector is used to estimate the gradient
+             all_avg_abs_grads = {name: tv.abs() for name, tv in task_vectors.items()}
+ 
+         # compute the trust region
+         Omega = torch.zeros_like(first(all_avg_abs_grads.values()))
+ 
+         for i in all_avg_abs_grads:
+             for j in all_avg_abs_grads:
+                 if i != j:
+                     vector1 = all_avg_abs_grads[i]
+                     vector2 = torch.abs(task_vectors[j])
+                     Omega += vector1 * vector2
+ 
+         values, indices = Omega.sort(descending=False)
+         threshold = values[
+             max(0, min(int(Omega.numel() * self.threshold_quantile), Omega.numel() - 1))
+         ]
+ 
+         mask = (Omega < threshold).bool()
+ 
+         # apply the trust-region mask to the task vectors
+         for task in task_vectors:
+             task_vectors[task] = task_vectors[task] * mask
+ 
+         task_vector_sum = sum(task_vectors.values())
+         task_vector_sum = vector_to_state_dict(
+             task_vector_sum, trainable_state_dict(pretrained_model)
+         )
+ 
+         if isinstance(self.scaling_factor, (int, float)):
+             model = pretrained_model
+             for name, param in model.named_parameters():
+                 param.data += task_vector_sum[name] * self.scaling_factor
+             return model
+         elif isinstance(self.scaling_factor, Iterable):
+             models = {}
+             for scaling_factor in self.scaling_factor:
+                 model = deepcopy(pretrained_model)
+                 # update the copied model so every scaling factor starts from the original pretrained weights
+                 for name, param in model.named_parameters():
+                     param.data += task_vector_sum[name] * scaling_factor
+                 models[scaling_factor] = model
+             return models
+         else:
+             raise ValueError(
+                 f"Incorrect type of `scaling_factor`: {type(self.scaling_factor)}. "
+                 "It should be a single real number or a list of real numbers."
+             )
+ 
+     def compute_avg_abs_grads(self, pretrained_model):
+         modelpool = self.modelpool
+ 
+         self.setup_zero_shot_classification_head()
+ 
+         pretrained_model = (
+             deepcopy(pretrained_model)
+             if pretrained_model is not None
+             else modelpool.load_pretrained_model()
+         )
+         pretrained_model = self.fabric.setup_module(pretrained_model)
+         pretrained_model.train()
+ 
+         all_avg_abs_grads: Dict[str, StateDictType] = {}
+         for train_dataset_name in (
+             pbar := tqdm(
+                 modelpool.train_dataset_names, desc="Train datasets", dynamic_ncols=True
+             )
+         ):
+             pbar.set_description(f"Train dataset: {train_dataset_name}")
+             dataset = modelpool.load_train_dataset(train_dataset_name)
+             dataset = CLIPDataset(dataset, self.clip_processor)
+             dataloader = DataLoader(dataset, shuffle=True, batch_size=self.batch_size)
+             dataloader = self.fabric.setup_dataloaders(dataloader)
+ 
+             grad: StateDictType = defaultdict(float)
+             num_samples = 0
+             for batch in dataloader:
+                 images, labels = batch
+                 batch_size = images.size(0)
+ 
+                 if num_samples + batch_size > self.max_samples:
+                     batch_size = self.max_samples - num_samples
+                     images = images[:batch_size]
+                     labels = labels[:batch_size]
+ 
+                 logits = self.compute_logits(
+                     pretrained_model, images, task=train_dataset_name
+                 )
+                 for i in range(batch_size):
+                     pretrained_model.zero_grad()
+                     loss = F.cross_entropy(logits[i], labels[i])
+                     self.fabric.backward(
+                         loss, retain_graph=True if i != batch_size - 1 else False
+                     )
+                     for name, param in pretrained_model.module.named_parameters():
+                         if param.requires_grad:
+                             grad[name] += torch.abs(param.grad).detach()
+ 
+                 num_samples += batch_size
+                 if num_samples >= self.max_samples:
+                     break
+ 
+             for name in grad:
+                 grad[name] = (grad[name] / num_samples).cpu()
+ 
+             # store the averaged absolute gradients keyed by the task (train dataset) name
+             all_avg_abs_grads[train_dataset_name] = grad
+         return all_avg_abs_grads
189
+
190
+ @torch.no_grad()
191
+ def compute_vanilla_task_vectors(self):
192
+ modelpool = self.modelpool
193
+
194
+ pretrained_model = modelpool.load_pretrained_model()
195
+ pretrained_sd = trainable_state_dict(pretrained_model)
196
+ finetuned_sds = {
197
+ name: trainable_state_dict(model)
198
+ for name, model in modelpool.named_models()
199
+ }
200
+
201
+ task_vectors = {
202
+ name: state_dict_sub(finetuned, pretrained_sd)
203
+ for name, finetuned in finetuned_sds.items()
204
+ }
205
+ return pretrained_model, task_vectors
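The trust-region selection in `run` can be illustrated on its own. The sketch below uses only PyTorch and made-up numbers (no CLIP models or fusion-bench objects); the names `threshold_quantile` and `scaling_factor` simply mirror the class attributes above. It mirrors the core computation: a conflict score Omega accumulated from cross products of per-task gradient magnitudes and task-vector magnitudes, a quantile threshold, and a mask that keeps only low-conflict coordinates before the masked task vectors are summed and scaled.

```python
import torch

# Toy per-task quantities over 8 flattened parameters (made-up numbers).
task_vectors = {
    "task_a": torch.tensor([ 0.5, -0.1,  0.3,  0.0, -0.4,  0.2,  0.1, -0.3]),
    "task_b": torch.tensor([-0.2,  0.4,  0.1, -0.3,  0.5, -0.1,  0.0,  0.2]),
}
# Stand-in for the averaged absolute gradients (the zero_shot branch uses |task vector| instead).
avg_abs_grads = {name: tv.abs() for name, tv in task_vectors.items()}

threshold_quantile = 0.5  # keep the 50% of coordinates with the lowest conflict score

# Conflict score: sum over ordered pairs (i, j), i != j, of |grad_i| * |tau_j|.
Omega = torch.zeros_like(next(iter(avg_abs_grads.values())))
for i in avg_abs_grads:
    for j in task_vectors:
        if i != j:
            Omega += avg_abs_grads[i] * task_vectors[j].abs()

# Quantile threshold and trust-region mask (low-conflict coordinates survive).
values, _ = Omega.sort()
threshold = values[max(0, min(int(Omega.numel() * threshold_quantile), Omega.numel() - 1))]
mask = Omega < threshold

# Mask, sum, and scale the task vectors, as run() does before adding them to the backbone.
scaling_factor = 0.3
merged_delta = scaling_factor * sum(tv * mask for tv in task_vectors.values())
print(mask, merged_delta)
```

Coordinates where both tasks want large, conflicting updates get a high Omega and are zeroed out, which is what keeps the merged update inside the "trust region" of each task.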