fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264) hide show
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+
3
+ pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+
11
+ tokenizer:
12
+ _target_: transformers.AutoTokenizer.from_pretrained
13
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
14
+
15
+ train_datasets:
16
+ alpaca-cleaned:
17
+ _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
18
+ tokenizer: ${...tokenizer}
19
+ path: "yahma/alpaca-cleaned"
20
+ split: train
21
+ cache_path: null
@@ -0,0 +1,21 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+
3
+ pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+
11
+ tokenizer:
12
+ _target_: transformers.AutoTokenizer.from_pretrained
13
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
14
+
15
+ train_datasets:
16
+ codealpaca:
17
+ _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
18
+ tokenizer: ${...tokenizer}
19
+ path: sahil2801/CodeAlpaca-20k
20
+ split: train
21
+ cache_path: null
@@ -0,0 +1,19 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+
3
+ pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+
11
+ tokenizer:
12
+ _target_: transformers.AutoTokenizer.from_pretrained
13
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
14
+
15
+ train_datasets:
16
+ metamathqa:
17
+ _target_: fusion_bench.dataset.llama.metamathqa.load_tokenized_metamathqa
18
+ tokenizer: ${...tokenizer}
19
+ cache_path: null
@@ -0,0 +1,18 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+
3
+ pretrained_model_name_or_path: meta-llama/Llama-3-1B-Instruct
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+
11
+ tokenizer:
12
+ _target_: transformers.AutoTokenizer.from_pretrained
13
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
14
+
15
+ train_datasets:
16
+ ultrachat-200k:
17
+ _target_: fusion_bench.dataset.llama.ultrachat.load_tokenized_ultrachat_200k
18
+ tokenizer: ${...tokenizer}
@@ -0,0 +1,23 @@
1
+ _target_: fusion_bench.modelpool.SeqenceClassificationModelPool
2
+
3
+ pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: fusion_bench.modelpool.seq_classification_lm.create_reward_model_from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+ use_flash_attention_2: true
11
+
12
+ tokenizer:
13
+ _target_: transformers.AutoTokenizer.from_pretrained
14
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
15
+ pad_token: <|end_of_text|> # do not use eos token (<|eos_id|>) as padding token because it is used as the end of each content
16
+
17
+ train_datasets:
18
+ preference_700k:
19
+ _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_rlhf
20
+ tokenizer: ${...tokenizer}
21
+ path: hendrydong/preference_700K
22
+ split: train
23
+ cache_path: null
@@ -0,0 +1,14 @@
1
+ _target_: fusion_bench.modelpool.SeqenceClassificationModelPool
2
+
3
+ pretrained_model_name_or_path: fusion-bench/Llama-3.2-1B-Instruct_Bradly-Terry-RM_Preference-700k
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: transformers.AutoModelForSequenceClassification.from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+
11
+ tokenizer:
12
+ _target_: transformers.AutoTokenizer.from_pretrained
13
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
14
+ pad_token: <|end_of_text|> # do not use eos token (<|eos_id|>) as padding token because it is used as the end of each content
@@ -1,13 +1,17 @@
1
1
  defaults:
2
2
  - hydra: default
3
+ - fabric: auto
3
4
  - modelpool: nyuv2_modelpool
4
5
  - method: simple_average
5
6
  - taskpool: nyuv2_taskpool
6
7
  - _self_
8
+
9
+ _target_: fusion_bench.programs.FabricModelFusionProgram
10
+ _recursive_: false
11
+
7
12
  fast_dev_run: false # Run a single batch of data to test the model or method
8
13
  use_lightning: true # Use the fabric to run the experiment
9
14
  print_config: true # Print the configuration to the console
10
15
  save_report: false # path to save the result report
11
- fabric: null
12
16
  trainer:
13
17
  devices: 1
@@ -0,0 +1,27 @@
1
+ type: clip_vit_classification
2
+ name: clip-vit-robustness_clean
3
+ # corrption can be one of:
4
+ # contrast, gaussian_noise, impulse_noise, jpeg_compression, motion_blur, pixelate, spatter
5
+ corruption: ${corruption}
6
+ dataset_type: huggingface_image_classification
7
+ tasks:
8
+ - name: stanford_cars
9
+ dataset:
10
+ name: tanganke/stanford_cars
11
+ split: ${taskpool.corruption}
12
+ - name: eurosat
13
+ dataset:
14
+ name: tanganke/eurosat
15
+ split: ${taskpool.corruption}
16
+ - name: resisc45
17
+ dataset:
18
+ name: tanganke/resisc45
19
+ split: ${taskpool.corruption}
20
+ - name: gtsrb
21
+ dataset:
22
+ name: tanganke/gtsrb
23
+ split: ${taskpool.corruption}
24
+ clip_model: openai/clip-vit-base-patch32
25
+ batch_size: 128
26
+ num_workers: 16
27
+ fast_dev_run: ${fast_dev_run}
@@ -0,0 +1,19 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets:
4
+ # eight tasks in the task arithmetic paper
5
+ - sun397
6
+ - stanford-cars
7
+ - resisc45
8
+ - eurosat
9
+ - svhn
10
+ - gtsrb
11
+ - mnist
12
+ - dtd
13
+ # additional 6 tasks in the TALL mask paper (TALL 14)
14
+ - oxford_flowers102
15
+ - pcam
16
+ - fer2013
17
+ - oxford-iiit-pet
18
+ - stl10
19
+ - cifar100
@@ -0,0 +1,26 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets:
4
+ # eight tasks in the task arithmetic paper
5
+ - sun397
6
+ - stanford-cars
7
+ - resisc45
8
+ - eurosat
9
+ - svhn
10
+ - gtsrb
11
+ - mnist
12
+ - dtd
13
+ # additional 6 tasks in the TALL mask paper (TALL 14)
14
+ - oxford_flowers102
15
+ - pcam
16
+ - fer2013
17
+ - oxford-iiit-pet
18
+ - stl10
19
+ - cifar100
20
+ # additional 6 tasks in the TALL mask paper (TALL 20)
21
+ - cifar10
22
+ - food101
23
+ - fashion_mnist
24
+ - emnist_letters
25
+ - kmnist
26
+ - rendered-sst2
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: cifar10
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: cifar100
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: dtd
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: emnist_letters
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: eurosat
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: fashion_mnist
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: fer2013
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: food101
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: gtsrb
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: kmnist
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: mnist
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: oxford-iiit-pet
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: oxford_flowers102
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/val@test_datasets: oxford_flowers102
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: pcam
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: rendered-sst2
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: resisc45
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: stanford-cars
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: stl10
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: sun397
@@ -0,0 +1,3 @@
1
+ defaults:
2
+ - CLIPVisionModelTaskPool@: _template
3
+ - /dataset/image_classification/test@test_datasets: svhn
@@ -0,0 +1,18 @@
1
+ _target_: fusion_bench.taskpool.llama.reward_model.RewardModelEvaluationTaskPool
2
+
3
+ test_datasets:
4
+ preference_700k:
5
+ _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_rlhf
6
+ tokenizer: ${...tokenizer}
7
+ path: hendrydong/preference_700K
8
+ split: train
9
+ cache_path: null
10
+
11
+ dataloader_kwargs:
12
+ shuffle: False
13
+ batch_size: 16
14
+
15
+ tokenizer: ${..modelpool.tokenizer}
16
+
17
+ max_num_samples: 1000
18
+ seed: 42
@@ -1,26 +0,0 @@
1
- defaults:
2
- - example_config
3
- - override method: weighted_average_for_llama
4
- - override modelpool: llama_for_causallm
5
- - _self_
6
- modelpool:
7
- models:
8
- # the pre-trained model (base model) is optional
9
- # if not provided, the first model will be used as the base model
10
- - name: _pretrained_
11
- path: meta-llama/Meta-Llama-3-8B
12
- - name: expert_1
13
- path: meta-llama/Meta-Llama-3-8B
14
- - name: expert_2
15
- path: meta-llama/Meta-Llama-3-8B-Instruct
16
- method:
17
- normalize: true # if true, the weights will be normalized before merging
18
- weights: # List of weights for each model
19
- - 0.5
20
- - 0.5
21
- # if true, only the backbone of the model will be merged and the head will be kept as the pre-trained model (if the pre-trained model is provided, otherwise the head of the first model will be used)
22
- # if false, the whole model will be merged
23
- backbone_only: true
24
- merged_model_save_path: null
25
- save_tokenizer: true
26
- push_to_hub: false