fusion-bench 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (727)
  1. fusion_bench/__init__.py +20 -0
  2. fusion_bench/__main__.py +4 -0
  3. fusion_bench/compat/__init__.py +0 -0
  4. fusion_bench/compat/method/__init__.py +109 -0
  5. fusion_bench/compat/method/base_algorithm.py +58 -0
  6. fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
  7. fusion_bench/compat/modelpool/__init__.py +116 -0
  8. fusion_bench/compat/modelpool/base_pool.py +328 -0
  9. fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
  10. fusion_bench/compat/taskpool/__init__.py +95 -0
  11. fusion_bench/compat/taskpool/base_pool.py +111 -0
  12. fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
  13. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
  14. fusion_bench/constants/__init__.py +2 -0
  15. fusion_bench/constants/paths.py +18 -0
  16. fusion_bench/dataset/__init__.py +29 -0
  17. fusion_bench/dataset/arc_agi/__init__.py +6 -0
  18. fusion_bench/dataset/arc_agi/arc.py +308 -0
  19. fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
  20. fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
  21. fusion_bench/dataset/arc_agi/messagers.py +1355 -0
  22. fusion_bench/dataset/arc_agi/np_cache.py +168 -0
  23. fusion_bench/dataset/arc_agi/preprocess.py +298 -0
  24. fusion_bench/dataset/arc_agi/representers.py +1019 -0
  25. fusion_bench/dataset/clip_dataset.py +71 -0
  26. fusion_bench/dataset/fer2013.py +12 -0
  27. fusion_bench/dataset/gpt2_glue.py +300 -0
  28. fusion_bench/dataset/gsm8k.py +60 -0
  29. fusion_bench/dataset/image_dataset.py +55 -0
  30. fusion_bench/dataset/imdb.py +11 -0
  31. fusion_bench/dataset/llama/__init__.py +1 -0
  32. fusion_bench/dataset/llama/alpaca.py +232 -0
  33. fusion_bench/dataset/llama/collate.py +120 -0
  34. fusion_bench/dataset/llama/metamathqa.py +50 -0
  35. fusion_bench/dataset/llama/openai.py +160 -0
  36. fusion_bench/dataset/llama/preference_700k.py +70 -0
  37. fusion_bench/dataset/llama/sharegpt.py +141 -0
  38. fusion_bench/dataset/llama/squad.py +125 -0
  39. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  40. fusion_bench/dataset/llama/ultrachat.py +58 -0
  41. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  42. fusion_bench/dataset/llama/wikitext.py +89 -0
  43. fusion_bench/dataset/nyuv2.py +119 -0
  44. fusion_bench/method/__init__.py +177 -0
  45. fusion_bench/method/ada_svd/__init__.py +2 -0
  46. fusion_bench/method/ada_svd/clip_vision.py +319 -0
  47. fusion_bench/method/adamerging/__init__.py +6 -0
  48. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
  49. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
  50. fusion_bench/method/adamerging/entropy_loss.py +25 -0
  51. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
  52. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
  53. fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
  54. fusion_bench/method/adamerging/llama_adamerging.py +335 -0
  55. fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
  56. fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
  57. fusion_bench/method/adamerging/utils.py +15 -0
  58. fusion_bench/method/analysis/__init__.py +2 -0
  59. fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
  60. fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
  61. fusion_bench/method/base_algorithm.py +44 -0
  62. fusion_bench/method/classification/__init__.py +3 -0
  63. fusion_bench/method/classification/clip_finetune.py +444 -0
  64. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  65. fusion_bench/method/concrete_subspace/__init__.py +6 -0
  66. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
  67. fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
  68. fusion_bench/method/dare/__init__.py +4 -0
  69. fusion_bench/method/dare/simple_average.py +31 -0
  70. fusion_bench/method/dare/task_arithmetic.py +82 -0
  71. fusion_bench/method/dare/ties_merging.py +100 -0
  72. fusion_bench/method/dare/utils.py +87 -0
  73. fusion_bench/method/dawe/__init__.py +2 -0
  74. fusion_bench/method/dawe/dawe_for_clip.py +274 -0
  75. fusion_bench/method/dawe/warppers/__init__.py +13 -0
  76. fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
  77. fusion_bench/method/depth_upscaling/__init__.py +3 -0
  78. fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
  79. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
  80. fusion_bench/method/dummy.py +35 -0
  81. fusion_bench/method/ensemble.py +98 -0
  82. fusion_bench/method/fisher_merging/__init__.py +4 -0
  83. fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
  84. fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
  85. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
  86. fusion_bench/method/linear/__init__.py +6 -0
  87. fusion_bench/method/linear/expo.py +118 -0
  88. fusion_bench/method/linear/linear_interpolation.py +60 -0
  89. fusion_bench/method/linear/llama_expo.py +229 -0
  90. fusion_bench/method/linear/simple_average_for_llama.py +54 -0
  91. fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
  92. fusion_bench/method/lm_finetune/__init__.py +3 -0
  93. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  94. fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
  95. fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
  96. fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
  97. fusion_bench/method/mixture_of_experts/__init__.py +7 -0
  98. fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
  99. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
  100. fusion_bench/method/model_recombination.py +121 -0
  101. fusion_bench/method/opcm/__init__.py +4 -0
  102. fusion_bench/method/opcm/opcm.py +277 -0
  103. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  104. fusion_bench/method/opcm/ties_merging.py +156 -0
  105. fusion_bench/method/opcm/utils.py +73 -0
  106. fusion_bench/method/opcm/weight_average.py +120 -0
  107. fusion_bench/method/pruning/__init__.py +5 -0
  108. fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
  109. fusion_bench/method/pruning/llama_random_prune.py +143 -0
  110. fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
  111. fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
  112. fusion_bench/method/pruning/prune_utils.py +165 -0
  113. fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
  114. fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
  115. fusion_bench/method/pruning/wanda_utils/data.py +135 -0
  116. fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
  117. fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
  118. fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
  119. fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
  120. fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
  121. fusion_bench/method/pwe_moe/__init__.py +5 -0
  122. fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
  123. fusion_bench/method/pwe_moe/module.py +316 -0
  124. fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
  125. fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
  126. fusion_bench/method/pwe_moe/utils.py +43 -0
  127. fusion_bench/method/rankone_moe/__init__.py +3 -0
  128. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  129. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  130. fusion_bench/method/regmean/__init__.py +4 -0
  131. fusion_bench/method/regmean/clip_regmean.py +131 -0
  132. fusion_bench/method/regmean/gpt2_regmean.py +147 -0
  133. fusion_bench/method/regmean/regmean.py +375 -0
  134. fusion_bench/method/simple_average.py +112 -0
  135. fusion_bench/method/slerp/__init__.py +2 -0
  136. fusion_bench/method/slerp/slerp.py +101 -0
  137. fusion_bench/method/slerp/slerp_utils.py +107 -0
  138. fusion_bench/method/smile_upscaling/__init__.py +3 -0
  139. fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
  140. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
  141. fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
  142. fusion_bench/method/sparse_we_moe/__init__.py +2 -0
  143. fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
  144. fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
  145. fusion_bench/method/sparselo/__init__.py +2 -0
  146. fusion_bench/method/sparselo/sparselo.py +955 -0
  147. fusion_bench/method/surgery/__init__.py +1 -0
  148. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  149. fusion_bench/method/tall_mask/__init__.py +0 -0
  150. fusion_bench/method/tall_mask/utils.py +234 -0
  151. fusion_bench/method/task_arithmetic/__init__.py +2 -0
  152. fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
  153. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  154. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  155. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  156. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  157. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
  158. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  159. fusion_bench/method/ties_merging/__init__.py +2 -0
  160. fusion_bench/method/ties_merging/ties_merging.py +117 -0
  161. fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
  162. fusion_bench/method/trust_region/__init__.py +2 -0
  163. fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
  164. fusion_bench/method/trust_region/utils.py +58 -0
  165. fusion_bench/method/we_moe/__init__.py +2 -0
  166. fusion_bench/method/we_moe/clip_we_moe.py +161 -0
  167. fusion_bench/method/we_moe/we_moe.py +247 -0
  168. fusion_bench/method/weighted_average/__init__.py +3 -0
  169. fusion_bench/method/weighted_average/llama.py +113 -0
  170. fusion_bench/method/weighted_average/weighted_average.py +102 -0
  171. fusion_bench/metrics/__init__.py +0 -0
  172. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  173. fusion_bench/metrics/nyuv2/__init__.py +11 -0
  174. fusion_bench/metrics/nyuv2/depth.py +45 -0
  175. fusion_bench/metrics/nyuv2/loss.py +31 -0
  176. fusion_bench/metrics/nyuv2/noise.py +16 -0
  177. fusion_bench/metrics/nyuv2/normal.py +48 -0
  178. fusion_bench/metrics/nyuv2/segmentation.py +43 -0
  179. fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
  180. fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
  181. fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
  182. fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
  183. fusion_bench/mixins/__init__.py +28 -0
  184. fusion_bench/mixins/clip_classification.py +252 -0
  185. fusion_bench/mixins/fabric_training.py +320 -0
  186. fusion_bench/mixins/lightning_fabric.py +174 -0
  187. fusion_bench/mixins/optim/__init__.py +0 -0
  188. fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
  189. fusion_bench/mixins/rich_live.py +21 -0
  190. fusion_bench/mixins/serialization.py +132 -0
  191. fusion_bench/mixins/simple_profiler.py +79 -0
  192. fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
  193. fusion_bench/modelpool/__init__.py +42 -0
  194. fusion_bench/modelpool/base_pool.py +268 -0
  195. fusion_bench/modelpool/causal_lm/__init__.py +2 -0
  196. fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
  197. fusion_bench/modelpool/clip_vision/__init__.py +1 -0
  198. fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
  199. fusion_bench/modelpool/huggingface_automodel.py +20 -0
  200. fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
  201. fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
  202. fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
  203. fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
  204. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  205. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  206. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  207. fusion_bench/models/__init__.py +3 -0
  208. fusion_bench/models/chat_templates/__init__.py +1 -0
  209. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  210. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  211. fusion_bench/models/hf_clip.py +199 -0
  212. fusion_bench/models/linearized/__init__.py +0 -0
  213. fusion_bench/models/linearized/linearized_model_utils.py +91 -0
  214. fusion_bench/models/linearized/vision_model.py +122 -0
  215. fusion_bench/models/llama/__init__.py +16 -0
  216. fusion_bench/models/llama/model_utils/__init__.py +0 -0
  217. fusion_bench/models/llama/model_utils/embedding.py +87 -0
  218. fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
  219. fusion_bench/models/llama/model_utils/misc.py +112 -0
  220. fusion_bench/models/llama/model_utils/mod.py +52 -0
  221. fusion_bench/models/llama/model_utils/visual.py +241 -0
  222. fusion_bench/models/llama/patcher.py +78 -0
  223. fusion_bench/models/llama/tokenizer_loader.py +153 -0
  224. fusion_bench/models/masks/__init__.py +2 -0
  225. fusion_bench/models/masks/mask_model.py +160 -0
  226. fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
  227. fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
  228. fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
  229. fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
  230. fusion_bench/models/modeling_losparse_llama/register.py +8 -0
  231. fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
  232. fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
  233. fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
  234. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
  235. fusion_bench/models/modeling_smile_mistral/register.py +8 -0
  236. fusion_bench/models/nyuv2/__init__.py +0 -0
  237. fusion_bench/models/nyuv2/aspp.py +82 -0
  238. fusion_bench/models/nyuv2/lightning_module.py +176 -0
  239. fusion_bench/models/nyuv2/resnet.py +405 -0
  240. fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
  241. fusion_bench/models/parameter_dict.py +75 -0
  242. fusion_bench/models/rankone_moe.py +410 -0
  243. fusion_bench/models/separate_io.py +105 -0
  244. fusion_bench/models/smile_moe/__init__.py +0 -0
  245. fusion_bench/models/smile_moe/linear.py +256 -0
  246. fusion_bench/models/sparse_we_moe.py +459 -0
  247. fusion_bench/models/surgery/__init__.py +1 -0
  248. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  249. fusion_bench/models/utils.py +80 -0
  250. fusion_bench/models/we_moe.py +247 -0
  251. fusion_bench/models/wrappers/__init__.py +0 -0
  252. fusion_bench/models/wrappers/ensemble.py +183 -0
  253. fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
  254. fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
  255. fusion_bench/optim/__init__.py +2 -0
  256. fusion_bench/optim/exception.py +47 -0
  257. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  258. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  259. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  260. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  261. fusion_bench/optim/mezo.py +118 -0
  262. fusion_bench/programs/__init__.py +20 -0
  263. fusion_bench/programs/base_program.py +9 -0
  264. fusion_bench/programs/fabric_fusion_program.py +299 -0
  265. fusion_bench/scripts/__init__.py +0 -0
  266. fusion_bench/scripts/cli.py +43 -0
  267. fusion_bench/scripts/clip/__init__.py +0 -0
  268. fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
  269. fusion_bench/scripts/imgui.py +218 -0
  270. fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
  271. fusion_bench/scripts/webui.py +405 -0
  272. fusion_bench/taskpool/__init__.py +39 -0
  273. fusion_bench/taskpool/base_pool.py +35 -0
  274. fusion_bench/taskpool/clip_vision/__init__.py +4 -0
  275. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  276. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
  277. fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
  278. fusion_bench/taskpool/dummy.py +58 -0
  279. fusion_bench/taskpool/gpt2_text_classification.py +149 -0
  280. fusion_bench/taskpool/llama/__init__.py +1 -0
  281. fusion_bench/taskpool/llama/reward_model.py +157 -0
  282. fusion_bench/taskpool/llama/test_generation.py +185 -0
  283. fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
  284. fusion_bench/tasks/__init__.py +2 -0
  285. fusion_bench/tasks/base_task.py +18 -0
  286. fusion_bench/tasks/classification.py +75 -0
  287. fusion_bench/tasks/clip_classification/__init__.py +183 -0
  288. fusion_bench/tasks/clip_classification/cifar10.py +33 -0
  289. fusion_bench/tasks/clip_classification/cifar100.py +146 -0
  290. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
  291. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  292. fusion_bench/tasks/clip_classification/dtd.py +60 -0
  293. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  294. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  295. fusion_bench/tasks/clip_classification/eurosat.py +18 -0
  296. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  297. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  298. fusion_bench/tasks/clip_classification/flower102.py +106 -0
  299. fusion_bench/tasks/clip_classification/food101.py +105 -0
  300. fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
  301. fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
  302. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  303. fusion_bench/tasks/clip_classification/mnist.py +5 -0
  304. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  305. fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
  306. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  307. fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
  308. fusion_bench/tasks/clip_classification/resisc45.py +68 -0
  309. fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
  310. fusion_bench/tasks/clip_classification/stl10.py +17 -0
  311. fusion_bench/tasks/clip_classification/sun397.py +404 -0
  312. fusion_bench/tasks/clip_classification/svhn.py +5 -0
  313. fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
  314. fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
  315. fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
  316. fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
  317. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
  318. fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
  319. fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
  320. fusion_bench/utils/__init__.py +14 -0
  321. fusion_bench/utils/auto.py +31 -0
  322. fusion_bench/utils/cache_utils.py +58 -0
  323. fusion_bench/utils/data.py +165 -0
  324. fusion_bench/utils/devices.py +231 -0
  325. fusion_bench/utils/dict.py +43 -0
  326. fusion_bench/utils/dtype.py +146 -0
  327. fusion_bench/utils/expr.py +90 -0
  328. fusion_bench/utils/fabric.py +17 -0
  329. fusion_bench/utils/functools.py +37 -0
  330. fusion_bench/utils/hydra_utils.py +28 -0
  331. fusion_bench/utils/instantiate.py +450 -0
  332. fusion_bench/utils/json.py +93 -0
  333. fusion_bench/utils/lazy_imports.py +74 -0
  334. fusion_bench/utils/misc.py +18 -0
  335. fusion_bench/utils/packages.py +84 -0
  336. fusion_bench/utils/parameters.py +323 -0
  337. fusion_bench/utils/path.py +22 -0
  338. fusion_bench/utils/plot/__init__.py +0 -0
  339. fusion_bench/utils/plot/color_data.py +1726 -0
  340. fusion_bench/utils/plot/token.py +52 -0
  341. fusion_bench/utils/plot/token_notebook.py +127 -0
  342. fusion_bench/utils/pylogger.py +55 -0
  343. fusion_bench/utils/rich_utils.py +201 -0
  344. fusion_bench/utils/set.py +8 -0
  345. fusion_bench/utils/state_dict_arithmetic.py +297 -0
  346. fusion_bench/utils/strenum/__init__.py +326 -0
  347. fusion_bench/utils/strenum/_name_mangler.py +127 -0
  348. fusion_bench/utils/strenum/_version.py +556 -0
  349. fusion_bench/utils/tensorboard.py +51 -0
  350. fusion_bench/utils/timer.py +49 -0
  351. fusion_bench/utils/type.py +34 -0
  352. fusion_bench-0.2.9.dist-info/LICENSE +21 -0
  353. fusion_bench-0.2.9.dist-info/METADATA +258 -0
  354. fusion_bench-0.2.9.dist-info/RECORD +727 -0
  355. fusion_bench-0.2.9.dist-info/WHEEL +5 -0
  356. fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
  357. fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
  358. fusion_bench_config/README.md +12 -0
  359. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
  360. fusion_bench_config/dataset/image_classification/README.md +6 -0
  361. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  362. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  363. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
  364. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
  365. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  366. fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
  367. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  368. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  369. fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
  370. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  371. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  372. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  373. fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
  374. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  375. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  376. fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
  377. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  378. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  379. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  380. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  381. fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
  382. fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
  383. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  384. fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
  385. fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
  386. fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
  387. fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
  388. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  389. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  390. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
  391. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
  392. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  393. fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
  394. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  395. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  396. fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
  397. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  398. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  399. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  400. fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
  401. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  402. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  403. fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
  404. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  405. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  406. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  407. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  408. fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
  409. fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
  410. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  411. fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
  412. fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
  413. fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
  414. fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
  415. fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
  416. fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
  417. fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
  418. fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
  419. fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
  420. fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
  421. fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
  422. fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
  423. fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
  424. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  425. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  426. fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
  427. fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
  428. fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
  429. fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
  430. fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
  431. fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
  432. fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
  433. fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
  434. fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
  435. fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
  436. fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
  437. fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
  438. fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
  439. fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
  440. fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
  441. fusion_bench_config/fabric/auto.yaml +16 -0
  442. fusion_bench_config/fabric/llama_ddp.yaml +18 -0
  443. fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
  444. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  445. fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
  446. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
  447. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  448. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  449. fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
  450. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  451. fusion_bench_config/fabric_model_fusion.yaml +20 -0
  452. fusion_bench_config/hydra/default.yaml +8 -0
  453. fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
  454. fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
  455. fusion_bench_config/llama_full_finetune.yaml +19 -0
  456. fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
  457. fusion_bench_config/llama_model_fusion.yaml +17 -0
  458. fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
  459. fusion_bench_config/method/adamerging/clip.yaml +23 -0
  460. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
  461. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
  462. fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
  463. fusion_bench_config/method/adamerging.yaml +23 -0
  464. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
  465. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
  466. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  467. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  468. fusion_bench_config/method/clip_finetune.yaml +26 -0
  469. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
  470. fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
  471. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
  472. fusion_bench_config/method/dare/simple_average.yaml +5 -0
  473. fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
  474. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  475. fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
  476. fusion_bench_config/method/depth_upscaling.yaml +5 -0
  477. fusion_bench_config/method/dummy.yaml +1 -0
  478. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
  479. fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
  480. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
  481. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
  482. fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
  483. fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
  484. fusion_bench_config/method/linear/expo.yaml +8 -0
  485. fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
  486. fusion_bench_config/method/linear/llama_expo.yaml +19 -0
  487. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
  488. fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
  489. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
  490. fusion_bench_config/method/linear/weighted_average.yaml +6 -0
  491. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
  492. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  493. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
  494. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
  495. fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
  496. fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
  497. fusion_bench_config/method/model_recombination.yaml +4 -0
  498. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  499. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  500. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  501. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  502. fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
  503. fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
  504. fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
  505. fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
  506. fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
  507. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  508. fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
  509. fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
  510. fusion_bench_config/method/regmean/regmean.yaml +4 -0
  511. fusion_bench_config/method/simple_average.yaml +1 -0
  512. fusion_bench_config/method/slerp/slerp.yaml +6 -0
  513. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
  514. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
  515. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
  516. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
  517. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
  518. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
  519. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  520. fusion_bench_config/method/task_arithmetic.yaml +2 -0
  521. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  522. fusion_bench_config/method/ties_merging.yaml +8 -0
  523. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
  524. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
  525. fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
  526. fusion_bench_config/model/clip-vit/README.md +38 -0
  527. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
  528. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  529. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  530. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  531. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  532. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
  533. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
  534. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  535. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
  536. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  537. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  538. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  539. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
  540. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  541. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
  542. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  543. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  544. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  545. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  546. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
  547. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
  548. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  549. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
  550. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
  551. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
  552. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  553. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  554. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  555. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  556. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
  557. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
  558. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  559. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
  560. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  561. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  562. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  563. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
  564. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  565. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
  566. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  567. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  568. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  569. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  570. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
  571. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
  572. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  573. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
  574. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
  575. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
  576. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  577. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  578. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  579. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  580. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
  581. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
  582. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  583. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
  584. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  585. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  586. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  587. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
  588. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  589. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
  590. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  591. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  592. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  593. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  594. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
  595. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
  596. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  597. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
  598. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
  599. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  600. fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
  601. fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
  602. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
  603. fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
  604. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
  605. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
  606. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
  607. fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
  608. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
  609. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
  610. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
  611. fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
  612. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
  613. fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
  614. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
  615. fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
  616. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
  617. fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
  618. fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
  619. fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
  620. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
  621. fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
  622. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
  623. fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
  624. fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
  625. fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
  626. fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
  627. fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
  628. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
  629. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
  630. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
  631. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  632. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  633. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  634. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  635. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  636. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
  637. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
  638. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
  639. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
  640. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
  641. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  642. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  643. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  644. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  645. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
  646. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
  647. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
  648. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
  649. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
  650. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
  651. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
  652. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  653. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
  654. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  655. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
  656. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
  657. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  658. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  659. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  660. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  661. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
  662. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  663. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  664. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
  665. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  666. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  667. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
  668. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
  669. fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
  670. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
  671. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
  672. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
  673. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
  674. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
  675. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  676. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  677. fusion_bench_config/modelpool/automodelpool.yaml +12 -0
  678. fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
  679. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
  680. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
  681. fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
  682. fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
  683. fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
  684. fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
  685. fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
  686. fusion_bench_config/nyuv2_config.yaml +17 -0
  687. fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
  688. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
  689. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  690. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
  691. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
  692. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
  693. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
  694. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
  695. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  696. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  697. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  698. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  699. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  700. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  701. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  702. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  703. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  704. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  705. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  706. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  707. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  708. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  709. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  710. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  711. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  712. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  713. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  714. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  715. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  716. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  717. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  718. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  719. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
  720. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
  721. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  722. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
  723. fusion_bench_config/taskpool/dummy.yaml +2 -0
  724. fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
  725. fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
  726. fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
  727. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
@@ -0,0 +1,640 @@
1
+ import math
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+ from fusion_bench.utils.type import StateDictType
7
+
8
+
9
def compute_svd_dict(task_vectors, config):
    """
    Decompose every task vector with a truncated SVD and lay the pieces out in disjoint slots.

    The i-th task vector is paired (via ``zip``) with the i-th name in
    ``config.DATASETS``; surplus task vectors or dataset names are ignored.
    For each entry ``key`` of ``task_vector.vector``:

    * If the tensor is 2-D and ``"text_projection"`` does not occur in ``key``,
      ``torch.linalg.svd(..., full_matrices=False)`` is applied and only the first
      ``k = int(rank * (1 / len(config.DATASETS)))`` components are kept. They are
      copied into the i-th block of width ``k`` of zero tensors shaped like the
      full factors and stored as ``{"u": ..., "s": ..., "v": ...}``. Note that
      ``"v"`` holds rows of the ``Vh`` factor returned by torch.
    * Otherwise the tensor itself (not a copy) is stored as ``{"dim1": tensor}``.

    Args:
        task_vectors (list): Objects exposing a ``vector`` mapping of key -> tensor.
        config (object): Object whose ``DATASETS`` attribute lists the dataset names.

    Returns:
        dict: ``{dataset_name: {key: component_dict}}`` as described above.
    """
    # Fraction of each matrix's rank that one dataset is allowed to keep.
    fraction = 1 / len(config.DATASETS)
    with torch.no_grad():
        result = {}
        for slot, (task_vector, dataset) in enumerate(
            zip(task_vectors, config.DATASETS)
        ):
            components = result[dataset] = {}
            print(f"Computing SVD for {dataset}...")
            for key in task_vector.vector:
                tensor = task_vector.vector[key]
                if len(tensor.shape) != 2 or "text_projection" in key:
                    # Non-matrix (or excluded) parameters are passed through untouched.
                    components[key] = {"dim1": tensor}
                    continue

                u, s, v = torch.linalg.svd(tensor, full_matrices=False)
                k = int(s.shape[0] * fraction)
                start, stop = slot * k, (slot + 1) * k

                # Place the leading k components into this dataset's own slot so
                # that summing over datasets later does not mix their subspaces.
                slotted_u = torch.zeros_like(u)
                slotted_u[:, start:stop] = u[:, :k]
                slotted_s = torch.zeros_like(s)
                slotted_s[start:stop] = s[:k]
                slotted_v = torch.zeros_like(v)
                slotted_v[start:stop, :] = v[:k, :]

                components[key] = {"u": slotted_u, "s": slotted_s, "v": slotted_v}
        return result
64
+
65
+
66
def sum_svd_dict(svd_dict, config):
    """
    Merge per-dataset SVD components (in the layout built by ``compute_svd_dict``) into one vector.

    For keys whose component dict contains ``"u"``/``"s"``/``"v"``, the components
    are summed across datasets. The summed ``u`` and ``v`` are each replaced by the
    orthogonal factor ``U @ Vh`` of their own SVD, and the merged matrix is
    ``orth(sum_u) @ diag(sum_s) @ orth(sum_v)``.

    For keys that only carry ``"dim1"`` (non-matrix or excluded tensors), the
    element-wise mean across datasets is returned. The input tensors are never
    modified.

    Args:
        svd_dict (dict): ``{dataset_name: {key: {"u", "s", "v"} or {"dim1"}}}``.
            Every dataset in ``config.DATASETS`` must contain the keys of the first one.
        config (object): Object whose ``DATASETS`` attribute lists the dataset names
            to merge; the first one defines the set of keys.

    Returns:
        dict: Key -> merged tensor.
    """
    print("Summing SVD...")
    datasets = config.DATASETS
    new_vector = {}
    for key in svd_dict[datasets[0]]:
        if "u" in svd_dict[datasets[0]][key]:
            sum_u = sum([svd_dict[dataset][key]["u"] for dataset in datasets])
            sum_s = sum([svd_dict[dataset][key]["s"] for dataset in datasets])
            sum_v = sum([svd_dict[dataset][key]["v"] for dataset in datasets])

            # Orthogonalize the summed factors: U @ Vh of each factor's own SVD.
            u_u, _, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            u_v, _, v_v = torch.linalg.svd(sum_v, full_matrices=False)
            new_vector[key] = torch.linalg.multi_dot(
                (u_u, v_u, torch.diag(sum_s), u_v, v_v)
            )
        else:
            # Running mean m_i = m_{i-1} + (x_i - m_{i-1}) / i. The first tensor is
            # cloned so the in-place update does not write through into svd_dict
            # (compute_svd_dict stores the task-vector tensors there by reference).
            mean = None
            for count, dataset in enumerate(datasets, start=1):
                value = svd_dict[dataset][key]["dim1"]
                if mean is None:
                    mean = value.clone()
                else:
                    mean += (value - mean) / count
            new_vector[key] = mean
    return new_vector
115
+
116
+
117
+ ###############
118
+ ##### LOSSLESS Orthogonalization
119
def compute_and_sum_svd_mem_reduction_lossless(
    task_vectors: List[StateDictType],
    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Merge task vectors by concatenating their full SVD factors and re-orthogonalizing them with a second SVD.

    For every key of ``task_vectors[0]``:

    1. If the tensor is 2-D and ``"text_projection"`` does not occur in the key,
       each task's matrix is decomposed with
       ``torch.linalg.svd(..., full_matrices=False)``. All of its components are
       kept (no truncation) and written side by side into ``sum_u`` (columns),
       ``sum_s`` and ``sum_v`` (rows), with task ``i`` occupying the i-th block.
       The concatenated ``U`` and ``V`` factors are then replaced by the
       orthogonal factor ``U @ Vh`` of their own SVDs, and the merged layer is
       ``orth(sum_u) @ diag(sum_s) @ orth(sum_v)``.
    2. Otherwise the element-wise mean across tasks is computed.

    Matrices whose dtype is neither float32 nor float64 (e.g. half precision) are
    decomposed in float32, since the SVD does not support half precision. Every
    merged tensor is returned with the dtype and on the device of
    ``task_vectors[0][key]``.

    Args:
        task_vectors (List[StateDictType]): Task vectors (mappings of parameter
            name -> tensor). Every task vector must contain the keys of the first
            one, with matching shapes.
        accelerator (torch.device): Device on which the computation runs.

    Returns:
        dict: Parameter name -> merged tensor.
    """
    # Be careful with ViT-L on 20 tasks: the concatenated factors may not fit in
    # GPU memory or in 64 GB of RAM (try without the last layer).
    print("Computing SVD...")
    num_tasks = len(task_vectors)
    accelerator_device = torch.device(accelerator)
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0]:
            reference = task_vectors[0][key]
            original_device = reference.device
            original_dtype = reference.dtype

            if len(reference.shape) == 2 and "text_projection" not in key:
                work_dtype = (
                    original_dtype
                    if original_dtype in (torch.float32, torch.float64)
                    else torch.float32
                )
                for i, task_vector in enumerate(task_vectors):
                    vec = task_vector[key].to(device=accelerator, dtype=work_dtype)
                    u, s, v = torch.linalg.svd(vec, full_matrices=False)

                    if i == 0:
                        print(f"Computed SVD for {key}...")
                        rank = s.shape[0]
                        # Accumulators share the working dtype so float64 inputs
                        # are not silently downcast to the default float32.
                        sum_u = torch.zeros(
                            u.shape[0],
                            rank * num_tasks,
                            device=accelerator,
                            dtype=work_dtype,
                        )
                        sum_s = torch.zeros(
                            rank * num_tasks, device=accelerator, dtype=work_dtype
                        )
                        sum_v = torch.zeros(
                            rank * num_tasks,
                            v.shape[1],
                            device=accelerator,
                            dtype=work_dtype,
                        )

                    # Task i occupies the i-th block of `rank` components.
                    block = slice(i * rank, (i + 1) * rank)
                    sum_u[:, block] = u
                    sum_s[block] = s
                    sum_v[block, :] = v

                # Orthogonalize the concatenated factors: U @ Vh of each one's own SVD.
                u_u, _, v_u = torch.linalg.svd(sum_u, full_matrices=False)
                u_v, _, v_v = torch.linalg.svd(sum_v, full_matrices=False)
                merged = torch.linalg.multi_dot(
                    (u_u, v_u, torch.diag(sum_s), u_v, v_v)
                )
            else:
                # Running mean m_i = m_{i-1} + (x_i - m_{i-1}) / i.
                for i, task_vector in enumerate(task_vectors):
                    vec = task_vector[key].to(accelerator)
                    if i == 0:
                        merged = vec.clone()
                    else:
                        merged += (vec - merged) / (i + 1)

            new_vector[key] = merged.to(
                device=original_device, dtype=original_dtype, non_blocking=True
            )

    # Non-blocking CUDA -> CPU copies may still be in flight; wait for them
    # before the caller can read the returned tensors.
    if accelerator_device.type == "cuda":
        torch.cuda.synchronize(accelerator_device)
    return new_vector
203
+
204
+
205
+ ###############
206
+ ##### LOSSLESS EIGENDECOMP
207
def compute_and_sum_svd_mem_reduction_lossless_eigen(
    task_vectors: List[StateDictType],
    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Merge task vectors by concatenating their full SVD factors and whitening them via eigendecomposition.

    For every key of ``task_vectors[0]``:

    1. If the tensor is 2-D and ``"text_projection"`` does not occur in the key,
       each task's matrix is decomposed with
       ``torch.linalg.svd(..., full_matrices=False)``. All of its components are
       kept and written side by side into ``sum_u`` (columns), ``sum_s`` and
       ``sum_v`` (rows), with task ``i`` occupying the i-th block. The components
       are reordered by ascending singular value. The merged layer is
       ``sum_u @ G_u^{-1/2} @ diag(sum_s) @ G_v^{-1/2} @ sum_v``, with
       ``G_u = sum_u^T sum_u`` and ``G_v = sum_v sum_v^T``. The inverse square
       roots are computed with ``torch.linalg.eigh``.
    2. Otherwise the element-wise mean across tasks is computed.

    Matrices whose dtype is neither float32 nor float64 (e.g. half precision) are
    processed in float32, since the SVD does not support half precision. Every
    merged tensor is returned with the dtype and on the device of
    ``task_vectors[0][key]``.

    Args:
        task_vectors (List[StateDictType]): Task vectors (mappings of parameter
            name -> tensor). Every task vector must contain the keys of the first
            one, with matching shapes.
        accelerator (torch.device): Device on which the computation runs.

    Returns:
        dict: Parameter name -> merged tensor.
    """
    # Be careful with ViT-L on 20 tasks: the concatenated factors may not fit in
    # GPU memory or in 64 GB of RAM (try without the last layer).
    print("Computing SVD...")
    num_tasks = len(task_vectors)
    accelerator_device = torch.device(accelerator)
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0]:
            reference = task_vectors[0][key]
            original_device = reference.device
            original_dtype = reference.dtype

            if len(reference.shape) == 2 and "text_projection" not in key:
                work_dtype = (
                    original_dtype
                    if original_dtype in (torch.float32, torch.float64)
                    else torch.float32
                )
                for i, task_vector in enumerate(task_vectors):
                    vec = task_vector[key].to(device=accelerator, dtype=work_dtype)
                    u, s, v = torch.linalg.svd(vec, full_matrices=False)

                    if i == 0:
                        print(f"Computed SVD for {key}...")
                        rank = s.shape[0]
                        # Accumulators share the working dtype so float64 inputs
                        # are not silently downcast to the default float32.
                        sum_u = torch.zeros(
                            u.shape[0],
                            rank * num_tasks,
                            device=accelerator,
                            dtype=work_dtype,
                        )
                        sum_s = torch.zeros(
                            rank * num_tasks, device=accelerator, dtype=work_dtype
                        )
                        sum_v = torch.zeros(
                            rank * num_tasks,
                            v.shape[1],
                            device=accelerator,
                            dtype=work_dtype,
                        )

                    # Task i occupies the i-th block of `rank` components.
                    block = slice(i * rank, (i + 1) * rank)
                    sum_u[:, block] = u
                    sum_s[block] = s
                    sum_v[block, :] = v

                # Apply the same permutation (ascending singular values) to all factors.
                sum_s, indices = torch.sort(sum_s, stable=True)

                sum_u = torch.index_select(sum_u, 1, indices)
                # u_orth ~= (sum_u^T sum_u)^{-1/2}; abs() and +1e-12 guard against
                # slightly negative or zero eigenvalues.
                # NOTE(review): when rank * num_tasks exceeds the number of rows the
                # Gram matrix is singular; its null-space directions are scaled up
                # here and only cancelled by the product with sum_u up to rounding.
                l_u, q_u = torch.linalg.eigh(sum_u.mT @ sum_u)
                u_orth = (
                    q_u
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_u)) + 1e-12))
                    @ q_u.mT
                )

                sum_v = torch.index_select(sum_v, 0, indices)
                # v_orth ~= (sum_v sum_v^T)^{-1/2}, with the same safeguards.
                l_v, q_v = torch.linalg.eigh(sum_v @ sum_v.mT)
                v_orth = (
                    q_v
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_v)) + 1e-12))
                    @ q_v.mT
                )

                merged = torch.linalg.multi_dot(
                    (sum_u, u_orth, torch.diag(sum_s), v_orth, sum_v)
                )
            else:
                # Running mean m_i = m_{i-1} + (x_i - m_{i-1}) / i.
                for i, task_vector in enumerate(task_vectors):
                    vec = task_vector[key].to(accelerator)
                    if i == 0:
                        merged = vec.clone()
                    else:
                        merged += (vec - merged) / (i + 1)

            new_vector[key] = merged.to(
                device=original_device, dtype=original_dtype, non_blocking=True
            )

    # Non-blocking CUDA -> CPU copies may still be in flight; wait for them
    # before the caller can read the returned tensors.
    if accelerator_device.type == "cuda":
        torch.cuda.synchronize(accelerator_device)
    return new_vector
308
+
309
+
310
+ ###############
311
+ #### TSV Merge Orthogonalization
312
@torch.no_grad()
def compute_and_sum_svd_mem_reduction(
    task_vectors: List[StateDictType],
    exclude_keys: Optional[List[str]] = None,
    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> StateDictType:
    """
    Merge task vectors with TSV-Merge, orthogonalizing the stacked singular vectors via SVD.

    For every 2D tensor whose key is not in ``exclude_keys``:

    1. Each task matrix is decomposed with an SVD, and only its top
       ``k = int(rank / len(task_vectors))`` singular triplets are kept.
    2. The truncated factors of all tasks are written side by side into
       ``sum_u`` (columns), ``sum_s`` and ``sum_v`` (rows); task ``i`` occupies
       slot ``[i * k, (i + 1) * k)``.
    3. ``sum_u`` and ``sum_v`` are replaced by their orthogonal polar factors
       (``U @ Vh`` from a second SVD), and the merged matrix
       ``U_orth @ diag(sum_s) @ V_orth`` is assembled.

    Every other tensor (not 2D, or listed in ``exclude_keys``) is merged with a
    running arithmetic mean over the task vectors.

    2D tensors that are neither float32 nor float64 are upcast to float32 for the
    SVD, because the decomposition is not available in half precision.

    Args:
        task_vectors: Task-vector state dicts. Each must contain the keys of
            ``task_vectors[0]`` with matching shapes.
        exclude_keys: Keys that are always averaged instead of SVD-merged.
            ``None`` means no key is excluded.
        accelerator: Device on which the computation runs.

    Returns:
        A new state dict holding the merged tensors. Each tensor is moved back to
        the device and dtype of the corresponding tensor in ``task_vectors[0]``.
    """
    if exclude_keys is None:
        exclude_keys = []
    sv_reduction = 1 / len(task_vectors)

    new_vector = {}
    for key in task_vectors[0]:
        reference = task_vectors[0][key]
        original_device = reference.device
        original_dtype = reference.dtype
        # Decide the merge strategy once per key, instead of re-reading the inner
        # loop variable after the loop has finished.
        use_svd = reference.ndim == 2 and key not in exclude_keys

        for i, task_vector in enumerate(task_vectors):
            vec = task_vector[key].to(accelerator)

            if not use_svd:
                # Incremental (running) mean over the task vectors.
                if i == 0:
                    new_vector[key] = vec.clone()
                else:
                    new_vector[key] += (vec - new_vector[key]) / (i + 1)
                continue

            # At present, the SVD is not supported for half precision.
            if original_dtype not in (torch.float32, torch.float64):
                vec = vec.to(dtype=torch.float32)

            u, s, v = torch.linalg.svd(vec, full_matrices=False)

            if i == 0:
                print(f"Computed SVD for {key}...")
                sum_u = torch.zeros_like(u)
                sum_s = torch.zeros_like(s)
                sum_v = torch.zeros_like(v)
                reduced_index_s = int(s.shape[0] * sv_reduction)

            # Place the top-k columns of u / rows of v into this task's slot.
            start, end = i * reduced_index_s, (i + 1) * reduced_index_s
            sum_u[:, start:end] = u[:, :reduced_index_s]
            sum_s[start:end] = s[:reduced_index_s]
            sum_v[start:end, :] = v[:reduced_index_s, :]

        if use_svd:
            # Orthogonal polar factors (U @ Vh) of the concatenated singular vectors.
            u_u, _, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            u_v, _, v_v = torch.linalg.svd(sum_v, full_matrices=False)
            new_vector[key] = torch.linalg.multi_dot(
                (u_u, v_u, torch.diag(sum_s), u_v, v_v)
            )

        new_vector[key] = new_vector[key].to(
            device=original_device, dtype=original_dtype, non_blocking=True
        )
    return new_vector
398
+
399
+
400
+ ###############
401
+ #### TSV Merge Eigendecomp
402
def compute_and_sum_svd_mem_reduction_2(
    task_vectors: List[StateDictType],
    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Merge task vectors with TSV-Merge, orthogonalizing via an eigendecomposition.

    For every 2D tensor whose key does not contain ``"text_projection"``, each task
    matrix keeps its top ``k = int(rank / len(task_vectors))`` singular triplets.
    These are written side by side into ``sum_u`` / ``sum_s`` / ``sum_v``, with
    task ``i`` occupying slot ``[i * k, (i + 1) * k)``.

    The concatenated triplets are reordered by ascending singular value. The
    singular vectors are then symmetrically orthogonalized:
    ``sum_u @ (sum_u^T sum_u)^(-1/2)`` and ``(sum_v sum_v^T)^(-1/2) @ sum_v``.
    Each inverse square root is computed from ``torch.linalg.eigh`` rather than
    from a second SVD.

    All other tensors are merged with a running arithmetic mean.

    2D tensors that are neither float32 nor float64 are upcast to float32 for the
    decompositions, which are not available in half precision. The merged result
    is cast back to the original dtype.

    Args:
        task_vectors: Task-vector state dicts. Each must contain the keys of
            ``task_vectors[0]`` with matching shapes.
        accelerator: Device on which the computation runs.

    Returns:
        dict: A new state dict with the merged tensors. Each tensor is moved back
        to the device and dtype of the corresponding tensor in ``task_vectors[0]``.
    """
    sv_reduction = 1 / len(task_vectors)

    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0]:
            reference = task_vectors[0][key]
            original_device = reference.device
            original_dtype = reference.dtype
            # Decide once per key instead of re-reading the leaked loop variable.
            use_svd = reference.ndim == 2 and "text_projection" not in key

            for i, task_vector in enumerate(task_vectors):
                vec = task_vector[key].to(accelerator)

                if not use_svd:
                    # Incremental (running) mean over the task vectors.
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)
                    continue

                # svd / eigh are not implemented for half precision.
                if original_dtype not in (torch.float32, torch.float64):
                    vec = vec.to(dtype=torch.float32)

                u, s, v = torch.linalg.svd(vec, full_matrices=False)

                if i == 0:
                    print(f"Computed SVD for {key}...")
                    sum_u = torch.zeros_like(u)
                    sum_s = torch.zeros_like(s)
                    sum_v = torch.zeros_like(v)
                    reduced_index_s = int(s.shape[0] * sv_reduction)

                # Place the top-k columns of u / rows of v into this task's slot.
                start, end = i * reduced_index_s, (i + 1) * reduced_index_s
                sum_u[:, start:end] = u[:, :reduced_index_s]
                sum_s[start:end] = s[:reduced_index_s]
                sum_v[start:end, :] = v[:reduced_index_s, :]

            if use_svd:
                # Reorder triplets by ascending singular value, applying the same
                # permutation to the columns of sum_u and the rows of sum_v.
                sum_s, indices = torch.sort(sum_s, stable=True)
                sum_u = torch.index_select(sum_u, 1, indices)
                sum_v = torch.index_select(sum_v, 0, indices)

                # Inverse square roots of the Gram matrices. abs() keeps sqrt real
                # if round-off yields slightly negative eigenvalues, and 1e-12
                # prevents division by zero.
                l_u, q_u = torch.linalg.eigh(sum_u.mT @ sum_u)
                u_orth = (
                    q_u
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_u)) + 1e-12))
                    @ q_u.mT
                )

                l_v, q_v = torch.linalg.eigh(sum_v @ sum_v.mT)
                v_orth = (
                    q_v
                    @ torch.diag(1.0 / (torch.sqrt(torch.abs(l_v)) + 1e-12))
                    @ q_v.mT
                )

                new_vector[key] = torch.linalg.multi_dot(
                    (sum_u, u_orth, torch.diag(sum_s), v_orth, sum_v)
                )

            new_vector[key] = new_vector[key].to(
                device=original_device, dtype=original_dtype, non_blocking=True
            )

    return new_vector
491
+
492
+
493
+ ###############
494
+ #### Rank Reduction TV
495
def compute_and_sum_svd_mem_reduction_rank_reduction(
    task_vectors: List[StateDictType],
    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Sum rank-reduced task vectors (no orthogonalization).

    For every 2D tensor whose key does not contain ``"text_projection"``, each task
    matrix is truncated to its top ``k = int(rank / len(task_vectors))`` singular
    triplets. The truncated matrices are then summed:
    ``sum_i U_i[:, :k] @ diag(S_i[:k]) @ Vh_i[:k, :]``. This equals
    ``sum_u @ diag(sum_s) @ sum_v`` for the side-by-side concatenation of the
    truncated factors.

    All other tensors are merged with a running arithmetic mean.

    2D tensors that are neither float32 nor float64 are upcast to float32 for the
    SVD, which is not available in half precision. The result is cast back to the
    original dtype.

    Args:
        task_vectors: Task-vector state dicts. Each must contain the keys of
            ``task_vectors[0]`` with matching shapes.
        accelerator: Device on which the computation runs.

    Returns:
        dict: A new state dict with the merged tensors. Each tensor is moved back
        to the device and dtype of the corresponding tensor in ``task_vectors[0]``.
    """
    sv_reduction = 1 / len(task_vectors)
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0]:
            reference = task_vectors[0][key]
            original_device = reference.device
            original_dtype = reference.dtype
            # Decide once per key instead of re-reading the leaked loop variable.
            use_svd = reference.ndim == 2 and "text_projection" not in key

            for i, task_vector in enumerate(task_vectors):
                vec = task_vector[key].to(accelerator)

                if not use_svd:
                    # Incremental (running) mean over the task vectors.
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)
                    continue

                # The SVD is not implemented for half precision.
                if original_dtype not in (torch.float32, torch.float64):
                    vec = vec.to(dtype=torch.float32)

                u, s, v = torch.linalg.svd(vec, full_matrices=False)

                if i == 0:
                    print(f"Computed SVD for {key}...")
                    sum_u = torch.zeros_like(u)
                    sum_s = torch.zeros_like(s)
                    sum_v = torch.zeros_like(v)
                    reduced_index_s = int(s.shape[0] * sv_reduction)

                # Place the top-k columns of u / rows of v into this task's slot.
                start, end = i * reduced_index_s, (i + 1) * reduced_index_s
                sum_u[:, start:end] = u[:, :reduced_index_s]
                sum_s[start:end] = s[:reduced_index_s]
                sum_v[start:end, :] = v[:reduced_index_s, :]

            if use_svd:
                new_vector[key] = torch.linalg.multi_dot(
                    (sum_u, torch.diag(sum_s), sum_v)
                )

            new_vector[key] = new_vector[key].to(
                device=original_device, dtype=original_dtype, non_blocking=True
            )
    return new_vector
562
+
563
+
564
def compute_and_sum_svd_mem_reduction_dummy(
    task_vectors: List[StateDictType],
    accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Dummy baseline: the first task's top singular triplets plus random directions.

    For every 2D tensor whose key does not contain ``"text_projection"``, only the
    first task vector is decomposed. Its top ``k = int(rank / len(task_vectors))``
    singular triplets fill slot 0 of ``sum_u`` / ``sum_s`` / ``sum_v``.

    Every further slot ``i`` is filled as follows:
    - ``sum_u`` gets standard-normal random columns, each normalized to unit L2 norm.
    - ``sum_v`` gets standard-normal random rows, each normalized to unit L2 norm.
    - ``sum_s`` reuses the first task's top-``k`` singular values.

    The random vectors are normalized but not orthogonalized. The matrices of
    tasks ``1..n-1`` are not used for these keys. The result is
    ``sum_u @ diag(sum_s) @ sum_v``.

    All other tensors are merged with a running arithmetic mean.

    The output depends on the global torch RNG state through ``torch.randn_like``.

    Args:
        task_vectors: Task-vector state dicts. Each must contain the keys of
            ``task_vectors[0]`` with matching shapes.
        accelerator: Device on which the computation runs.

    Returns:
        dict: A new state dict with the merged tensors. Each tensor is moved back
        to the device and dtype of the corresponding tensor in ``task_vectors[0]``.
    """
    sv_reduction = 1 / len(task_vectors)
    print("Computing SVD...")
    with torch.no_grad():
        new_vector = {}
        for key in task_vectors[0]:
            reference = task_vectors[0][key]
            original_device = reference.device
            original_dtype = reference.dtype
            use_svd = reference.ndim == 2 and "text_projection" not in key

            for i, task_vector in enumerate(task_vectors):
                if not use_svd:
                    vec = task_vector[key].to(accelerator)
                    # Incremental (running) mean over the task vectors.
                    if i == 0:
                        new_vector[key] = vec.clone()
                    else:
                        new_vector[key] += (vec - new_vector[key]) / (i + 1)
                    continue

                if i == 0:
                    # Only the first task's matrix is actually decomposed.
                    vec = task_vector[key].to(accelerator)
                    # The SVD is not implemented for half precision.
                    if original_dtype not in (torch.float32, torch.float64):
                        vec = vec.to(dtype=torch.float32)
                    u, s, v = torch.linalg.svd(vec, full_matrices=False)
                    reduced_index_s = int(s.shape[0] * sv_reduction)

                    print(f"Computed SVD for {key}...")
                    sum_u = torch.zeros_like(u)
                    sum_s = torch.zeros_like(s)
                    sum_v = torch.zeros_like(v)
                else:
                    # Random unit-norm columns (u) and rows (v); these are NOT
                    # orthogonalized against the vectors already placed.
                    print("dummy")
                    u = torch.nn.functional.normalize(
                        torch.randn_like(sum_u), p=2, dim=-2
                    )
                    v = torch.nn.functional.normalize(
                        torch.randn_like(sum_v), p=2, dim=-1
                    )

                # `s` still holds the first task's singular values for every slot.
                start, end = i * reduced_index_s, (i + 1) * reduced_index_s
                sum_u[:, start:end] = u[:, :reduced_index_s]
                sum_s[start:end] = s[:reduced_index_s]
                sum_v[start:end, :] = v[:reduced_index_s, :]

            if use_svd:
                new_vector[key] = torch.linalg.multi_dot(
                    (sum_u, torch.diag(sum_s), sum_v)
                )

            new_vector[key] = new_vector[key].to(
                device=original_device, dtype=original_dtype, non_blocking=True
            )
    return new_vector
@@ -0,0 +1,7 @@
1
+ from fusion_bench.method.ties_merging.ties_merging_utils import (
2
+ check_parameterNamesMatch,
3
+ check_state_dicts_equal,
4
+ )
5
+ from fusion_bench.utils import state_dict_to_vector, vector_to_state_dict
6
+
7
+ from . import TSVC_utils, TSVM_utils
@@ -0,0 +1,2 @@
1
+ # flake8: noqa F401
2
+ from .ties_merging import TiesMergingAlgorithm