fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@ _target_: fusion_bench.programs.FabricModelFusionProgram
  _recursive_: false
  fast_dev_run: false # Run a single batch of data to test the model or method
  # Run the script without actually running the experiment, use with `print_config=true`.
- # You can also use `--cfg` or `-c` to show the configuration instead of runing.
+ # You can also use `--cfg` or `-c` to show the configuration instead of running.
  dry_run: false
  print_config: true # Print the configuration to the console
  merged_model_save_path: null # path to save the merged model, use "{log_dir}" to refer to the logger directory, for example `merged_model_save_path=\{log_dir\}/merged_model`
@@ -0,0 +1,6 @@
+ # Image Classification Dataset Configurations
+
+ This folder contains the dataset configuration for image classification tasks.
+
+ - Each dataset should have 'image' and 'label' columns.
+ - If a dataset has no test split, we will use the validation split as the test split and create the validation set from the training set.
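The dataset entries added below all follow the same shape: a named `datasets.load_dataset` (or custom loader) call plus a split. A minimal sketch of what one such entry resolves to at runtime, using the cifar10 test config from this diff as the example (how FusionBench wires the result into its pipelines is not shown here):

```python
from datasets import load_dataset

# Sketch only: resolve one of the config entries below by hand.
# Path and split are taken from the cifar10 test config in this diff.
ds = load_dataset("tanganke/cifar10", split="test")

# Per the README above, each dataset is expected to expose 'image' and 'label' columns.
sample = ds[0]
print(sorted(sample.keys()))  # expected to include 'image' and 'label'
```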
@@ -0,0 +1,20 @@
+ # The 14 task used in the paper:
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
+ # http://arxiv.org/abs/2405.07813
+ defaults:
+ # eight tasks in the task arithmetic paper
+ - sun397
+ - stanford-cars
+ - resisc45
+ - eurosat
+ - svhn
+ - gtsrb
+ - mnist
+ - dtd
+ # additional 6 tasks in the TALL mask paper
+ - oxford_flowers102
+ - pcam
+ - fer2013
+ - oxford-iiit-pet
+ - stl10
+ - cifar100
@@ -0,0 +1,28 @@
+ # The 20 task used in the paper:
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
+ # http://arxiv.org/abs/2405.07813
+ defaults:
+ # eight tasks in the task arithmetic paper
+ - sun397
+ - stanford-cars
+ - resisc45
+ - eurosat
+ - svhn
+ - gtsrb
+ - mnist
+ - dtd
+ # additional 6 tasks in the TALL mask paper (TALL 14)
+ - oxford_flowers102
+ - pcam
+ - fer2013
+ - oxford-iiit-pet
+ - stl10
+ - cifar100
+ # additional 6 tasks in the TALL mask paper (TALL 20)
+ - cifar10
+ - food101
+ - fashion_mnist
+ - emnist_letters
+ - kmnist
+ - rendered-sst2
+
@@ -1,4 +1,4 @@
- dtd:
+ cifar10:
  _target_: datasets.load_dataset
  path: tanganke/cifar10
  split: test
@@ -1,4 +1,4 @@
- dtd:
+ cifar100:
  _target_: datasets.load_dataset
  path: tanganke/cifar100
  split: test
@@ -0,0 +1,4 @@
+ cub-200-2011:
+ _target_: datasets.load_dataset
+ path: Donghyun99/CUB-200-2011
+ split: test
@@ -0,0 +1,5 @@
+ emnist_letters:
+ _target_: datasets.load_dataset
+ path: tanganke/emnist_letters
+ split: test
+
@@ -0,0 +1,4 @@
+ emnist_mnist:
+ _target_: datasets.load_dataset
+ path: tanganke/emnist_mnist
+ split: test
@@ -0,0 +1,4 @@
+ fashion_mnist:
+ _target_: datasets.load_dataset
+ path: zalando-datasets/fashion_mnist
+ split: test
@@ -0,0 +1,3 @@
+ fer2013:
+ _target_: fusion_bench.dataset.fer2013.load_fer2013
+ split: test
@@ -0,0 +1,4 @@
+ food101:
+ _target_: datasets.load_dataset
+ path: ethz/food101
+ split: validation
@@ -0,0 +1,4 @@
+ kmnist:
+ _target_: datasets.load_dataset
+ path: tanganke/kmnist
+ split: test
@@ -0,0 +1,4 @@
+ mango-leaf-disease:
+ _target_: datasets.load_dataset
+ path: AfiqN/mango-leaf-disease
+ split: test
@@ -0,0 +1,4 @@
+ oxford-iiit-pet:
+ _target_: datasets.load_dataset
+ path: timm/oxford-iiit-pet
+ split: test
@@ -0,0 +1,4 @@
+ oxford_flowers102:
+ _target_: datasets.load_dataset
+ path: dpdl-benchmark/oxford_flowers102
+ split: test
@@ -0,0 +1,4 @@
+ pcam:
+ _target_: datasets.load_dataset
+ path: 1aurent/PatchCamelyon
+ split: test
@@ -0,0 +1,4 @@
+ rendered-sst2:
+ _target_: datasets.load_dataset
+ path: nateraw/rendered-sst2
+ split: test
@@ -0,0 +1,4 @@
+ stl10:
+ _target_: datasets.load_dataset
+ path: tanganke/stl10
+ split: test
@@ -0,0 +1,20 @@
+ # The 14 task used in the paper:
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
+ # http://arxiv.org/abs/2405.07813
+ defaults:
+ # eight tasks in the task arithmetic paper
+ - sun397
+ - stanford-cars
+ - resisc45
+ - eurosat
+ - svhn
+ - gtsrb
+ - mnist
+ - dtd
+ # additional 6 tasks in the TALL mask paper
+ - oxford_flowers102
+ - pcam
+ - fer2013
+ - oxford-iiit-pet
+ - stl10
+ - cifar100
@@ -0,0 +1,28 @@
+ # The 20 task used in the paper:
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
+ # http://arxiv.org/abs/2405.07813
+ defaults:
+ # eight tasks in the task arithmetic paper
+ - sun397
+ - stanford-cars
+ - resisc45
+ - eurosat
+ - svhn
+ - gtsrb
+ - mnist
+ - dtd
+ # additional 6 tasks in the TALL mask paper (TALL 14)
+ - oxford_flowers102
+ - pcam
+ - fer2013
+ - oxford-iiit-pet
+ - stl10
+ - cifar100
+ # additional 6 tasks in the TALL mask paper (TALL 20)
+ - cifar10
+ - food101
+ - fashion_mnist
+ - emnist_letters
+ - kmnist
+ - rendered-sst2
+
@@ -1,4 +1,4 @@
- dtd:
+ cifar10:
  _target_: datasets.load_dataset
  path: tanganke/cifar10
  split: train
@@ -1,4 +1,4 @@
- dtd:
+ cifar100:
  _target_: datasets.load_dataset
  path: tanganke/cifar100
  split: train
@@ -0,0 +1,4 @@
+ cub-200-2011:
+ _target_: datasets.load_dataset
+ path: Donghyun99/CUB-200-2011
+ split: train
@@ -0,0 +1,4 @@
+ emnist_letters:
+ _target_: datasets.load_dataset
+ path: tanganke/emnist_letters
+ split: train
@@ -0,0 +1,4 @@
+ emnist_mnist:
+ _target_: datasets.load_dataset
+ path: tanganke/emnist_mnist
+ split: train
@@ -0,0 +1,4 @@
+ fashion_mnist:
+ _target_: datasets.load_dataset
+ path: zalando-datasets/fashion_mnist
+ split: train
@@ -0,0 +1,3 @@
+ fer2013:
+ _target_: fusion_bench.dataset.fer2013.load_fer2013
+ split: train
@@ -0,0 +1,4 @@
+ food101:
+ _target_: datasets.load_dataset
+ path: ethz/food101
+ split: train
@@ -0,0 +1,4 @@
+ kmnist:
+ _target_: datasets.load_dataset
+ path: tanganke/kmnist
+ split: train
@@ -0,0 +1,4 @@
+ mango-leaf-disease:
+ _target_: datasets.load_dataset
+ path: AfiqN/mango-leaf-disease
+ split: train
@@ -0,0 +1,4 @@
+ oxford-iiit-pet:
+ _target_: datasets.load_dataset
+ path: timm/oxford-iiit-pet
+ split: train
@@ -0,0 +1,4 @@
+ oxford_flowers102:
+ _target_: datasets.load_dataset
+ path: dpdl-benchmark/oxford_flowers102
+ split: train
@@ -0,0 +1,4 @@
+ pcam:
+ _target_: datasets.load_dataset
+ path: 1aurent/PatchCamelyon
+ split: train
@@ -0,0 +1,4 @@
+ rendered-sst2:
+ _target_: datasets.load_dataset
+ path: nateraw/rendered-sst2
+ split: train
@@ -0,0 +1,4 @@
+ stl10:
+ _target_: datasets.load_dataset
+ path: tanganke/stl10
+ split: train
@@ -0,0 +1,6 @@
+ alpaca-cleaned:
+ _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
+ tokenizer: ???
+ path: "yahma/alpaca-cleaned"
+ split: train
+ cache_path: null
@@ -0,0 +1,3 @@
+ ultrachat-200k:
+ _target_: fusion_bench.dataset.ultrachat.load_tokenized_ultrachat_200k
+ tokenizer: ???
@@ -0,0 +1,16 @@
+ defaults:
+ - loggers: tensorboard_logger
+ - strategy: llama_peft_fsdp
+ - _self_
+
+ _target_: lightning.Fabric
+ _recursive_: true
+ # Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
+ # The value applies per node.
+ devices: auto
+ # The hardware to run on. Possible choices are:
+ # ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
+ # for example: fabric.accelerator=cpu
+ accelerator: auto
+ # reference to the precision policy: https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
+ precision: bf16-true
@@ -0,0 +1,2 @@
+ # https://lightning.ai/docs/fabric/2.4.0/guide/loggers/wandb.html#weights-and-biases
+ _target_: wandb.integration.lightning.fabric.WandbLogger
@@ -0,0 +1,10 @@
+ # https://lightning.ai/docs/fabric/2.4.0/api/generated/lightning.fabric.strategies.DeepSpeedStrategy.html#deepspeedstrategy
+ _target_: lightning.fabric.strategies.DeepSpeedStrategy
+
+ accelerator: null
+ zero_optimization: true
+ stage: 2
+ offload_optimizer: false
+ offload_parameters: false
+ offload_params_device: "cpu"
+ offload_optimizer_device: "cpu"
@@ -0,0 +1,9 @@
+ _target_: lightning.fabric.strategies.FSDPStrategy
+ sharding_strategy: FULL_SHARD
+ state_dict_type: full # Save a single, consolidated checkpoint file
+ cpu_offload: false
+ auto_wrap_policy:
+ _target_: fusion_bench.mixins.lightning_fabric.get_size_based_auto_wrap_policy
+ activation_checkpointing_policy: ${.auto_wrap_policy}
+ # limit_all_gathers: true
+
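For orientation, a rough Python equivalent of the Fabric and FSDP settings configured above (a sketch, assuming Hydra instantiates these YAML nodes directly; the size-based auto-wrap policy from `fusion_bench.mixins.lightning_fabric` is omitted for brevity):

```python
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

# Sketch: what the llama_peft_fsdp fabric/strategy configs roughly correspond to.
strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    state_dict_type="full",  # save a single, consolidated checkpoint file
    cpu_offload=False,
)
fabric = Fabric(
    accelerator="auto",
    devices="auto",
    precision="bf16-true",
    strategy=strategy,
)
fabric.launch()  # requires CUDA devices; FSDP is not available on CPU-only machines
```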
@@ -11,7 +11,7 @@ _target_: fusion_bench.programs.FabricModelFusionProgram
  _recursive_: false
  fast_dev_run: false # Run a single batch of data to test the model or method
  # Run the script without actually running the experiment, use with `print_config=true`.
- # You can also use `--cfg` or `-c` to show the configuration instead of runing.
+ # You can also use `--cfg` or `-c` to show the configuration instead of running.
  dry_run: false
  print_config: true # Print the configuration to the console
  merged_model_save_path: null # path to save the merged model, use "{log_dir}" to refer to the logger directory, for example `merged_model_save_path=\{log_dir\}/merged_model`
@@ -0,0 +1,19 @@
+ defaults:
+ - hydra: default
+ - fabric: llama_fsdp
+ # --- Model, Method, Task ---
+ - method: lm_finetune/fullfinetune_sft.yaml
+ - modelpool: CausalLMPool/llama_alpaca_cleaned.yaml
+ - taskpool: dummy
+ - _self_
+
+ _target_: fusion_bench.programs.FabricModelFusionProgram
+ _recursive_: false
+
+ fast_dev_run: false # Run a single batch of data to test the model or method
+ # Run the script without actually running the experiment, use with `print_config=true`.
+ # You can also use `--cfg` or `-c` to show the configuration instead of running.
+ dry_run: false
+ print_config: true # Print the configuration to the console
+ report_save_path: null # path to save the result report
+ print_function_call: true # set to false if you don't want to print the details of instantiate calls
@@ -0,0 +1,47 @@
+ _target_: fusion_bench.method.BradleyTerryRewardModeling
+ _recursive_: False
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ lr: 1e-5
+ weight_decay: 0.01
+ fused: null
+
+ lr_scheduler:
+ _target_: fusion_bench.optim.lr_scheduler.CosineDecayWithWarmup
+ T_max: _T_max_ # this will be replaced by the expected number of training steps
+ init_lr: 0
+ warmup_steps: 100
+ max_lr: ${..optimizer.lr}
+ min_lr: 1e-6
+
+ dataloader_kwargs:
+ # per-gpu batch size
+ batch_size: 1
+ num_workers: 0
+ pin_memory: True
+
+ # Training hyperparameters
+ # if max_epochs=-1, max_steps will be used to determine the number of training steps
+ max_epochs: 3
+ max_steps: -1
+ max_steps_per_epoch: -1
+ accumulate_grad_batches: 1
+ lr_scheduler_interval: step
+ lr_scheduler_frequency: 1
+ # Checkpointing may be done by epoch or step, and at the end of training
+ # `checkpoint_save_interval` can be 'epoch' or 'step'
+ checkpoint_save_interval: epoch
+ checkpoint_save_frequency: 1
+ # Whether to use gradient clipping, and if so, the value and algorithm
+ gradient_clip_val: null
+ gradient_clip_algorithm: norm
+ save_optimizer_state: false
+ # save_full_model must be true when using shared FSDP
+ save_full_model: true
+ # save_ckpt_type can be 'hf' or 'lightning'
+ save_ckpt_type: lightning
+ # Path to checkpoint to load from, used for resuming training
+ ckpt_path: null
+ max_length: 4096
+ fix_token_embedding: true
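The actual training loop lives in `fusion_bench/method/lm_finetune/bradley_terry_rm.py` (added in this release, per the file list above); the hyperparameters in this config feed the standard Bradley-Terry pairwise objective, sketched here for orientation only:

```python
import torch
import torch.nn.functional as F

def bradley_terry_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Standard Bradley-Terry reward-modeling objective:
    minimize -log(sigmoid(r_chosen - r_rejected)) over preference pairs.
    Illustrative sketch only; not the fusion_bench implementation."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Example with dummy scalar rewards for a batch of two preference pairs.
loss = bradley_terry_loss(torch.tensor([1.2, 0.3]), torch.tensor([0.4, 0.5]))
```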
@@ -3,14 +3,17 @@ _recursive_: False

  optimizer:
  _target_: torch.optim.AdamW
- fused: True
+ lr: 1e-5
  weight_decay: 0.01
- lr: 5e-5
+ fused: null

  lr_scheduler:
- _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+ _target_: fusion_bench.optim.lr_scheduler.CosineDecayWithWarmup
  T_max: _T_max_ # this will be replaced by the expected number of training steps
- eta_min: 1e-6
+ init_lr: 0
+ warmup_steps: 100
+ max_lr: ${..optimizer.lr}
+ min_lr: 1e-6

  dataloader_kwargs:
  # per-gpu batch size
@@ -36,5 +39,9 @@ gradient_clip_algorithm: norm
  save_optimizer_state: false
  # save_full_model must be true when using shared FSDP
  save_full_model: true
+ # save_ckpt_type can be 'hf' or 'lightning'
+ save_ckpt_type: lightning
  # Path to checkpoint to load from, used for resuming training
  ckpt_path: null
+ max_length: 4096
+ fix_token_embedding: true
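The SFT configs now point at `fusion_bench.optim.lr_scheduler.CosineDecayWithWarmup` instead of `CosineAnnealingLR`. A sketch of the schedule shape implied by its parameters (`init_lr`, `warmup_steps`, `max_lr`, `min_lr`, `T_max`); the helper below is hypothetical and is not the library class:

```python
import math

def cosine_decay_with_warmup_lr(step: int, T_max: int, warmup_steps: int = 100,
                                init_lr: float = 0.0, max_lr: float = 1e-5,
                                min_lr: float = 1e-6) -> float:
    """Hypothetical helper mirroring the config parameters: linear warmup from
    init_lr to max_lr, then cosine decay from max_lr down to min_lr by T_max."""
    if step < warmup_steps:
        # linear warmup phase
        return init_lr + (max_lr - init_lr) * step / max(1, warmup_steps)
    # cosine decay phase over the remaining steps
    progress = min(1.0, (step - warmup_steps) / max(1, T_max - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```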
@@ -3,9 +3,9 @@ _recursive_: False

  optimizer:
  _target_: torch.optim.AdamW
- fused: True
+ lr: 1e-4
  weight_decay: 0.01
- lr: 5e-5
+ fused: null

  lr_scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
@@ -56,6 +56,8 @@ gradient_clip_algorithm: norm
  save_optimizer_state: false
  # save_full_model must be true when using shared FSDP
  save_full_model: false
+ # save_ckpt_type can be 'peft' or 'lightning'
+ save_ckpt_type: lightning
  # Path to checkpoint to load from, used for resuming training
  ckpt_path: null
  max_length: 4096
@@ -0,0 +1,27 @@
+ # this option can be "clip_task_wise_adamerging"
+ name: clip_layer_wise_adamerging_surgery
+ # this weights can be a list of float, or a string that points to a *.np, *.pt file containing the weights
+ # if weights is specified, skip the test-time adaptation training
+ weights: null
+ # learning rate
+ optimizer: adam
+ lr: 1e-3
+ init_values: 0.3
+ # if `clamp_weights` is true, the weights will be clamped to [0, 1]
+ clamp_weights: false
+ # arguments of `functional_call`
+ tie_weights: true
+ strict: false
+ # this is overrided by `fabric.devices` if launched from the `fusion_bench` CLI.
+ devices: 1
+ batch_size: 16
+ num_workers: 8
+ max_steps: 1000
+ fast_dev_run: ${fast_dev_run}
+ # the path for saving the merging weights
+ save_merging_weights: 'merging_weights.pt'
+ cache_dir: outputs
+
+ # parameters of Surgery
+ eval_iterations: 200
+ surgery_steps: 1000
@@ -0,0 +1,2 @@
+ _target_: fusion_bench.method.TaskSingularVectorMerging
+ remove_keys: null
@@ -0,0 +1,38 @@
+ This folder contains the configuration for the CLIP-ViT models (managed by `fusion_bench.modelpool.CLIPVisionModelPool`).
+
+ ## Expected Configuration
+
+ ### Detailed Configuration
+
+
+ ```yaml
+ ${name_of_model}:
+ _target_: ${function_to_load_model}
+ ... # arguments to pass to the function
+ ```
+
+ For example, to load the pre-trained CLIP-ViT-B/16 model, you can use the following configuration:
+
+ ```yaml
+ _pretrained_: # `_pretrained_` is a special key in FusionBench that indicates the model is pre-trained
+ _target_: transformers.CLIPVisionModel.from_pretrained
+ pretrained_model_name_or_path: openai/clip-vit-base-patch16
+ ```
+
+ In this case, calling `modelpool.load_model("_pretrained_")` will return a `transformers.CLIPVisionModel` instance, which is equivalent to call `transformers.CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")`.
+
+ The detailed configuration is more flexible and can be used when you need to pass additional arguments to the `from_pretrained` function or call custom functions to load and preprocess the model.
+
+ ### Simplified Configuration
+
+ ```yaml
+ ${name_of_model}: ${pretrained_model_name_or_path}
+ ```
+
+ This is a simplified configuration that is equivalent to the detailed configuration.
+
+ For example, to load the pre-trained CLIP-ViT-B/16 model, you can use the following configuration:
+
+ ```yaml
+ _pretrained_: openai/clip-vit-base-patch16
+ ```
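The equivalence this README describes can be checked outside FusionBench with Hydra's `instantiate` (a sketch; how `CLIPVisionModelPool` expands the simplified string form internally is not shown here):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Sketch: the "detailed" configuration is a plain Hydra target spec, so
# instantiating it is the same as calling from_pretrained directly.
cfg = OmegaConf.create({
    "_target_": "transformers.CLIPVisionModel.from_pretrained",
    "pretrained_model_name_or_path": "openai/clip-vit-base-patch16",
})
vision_model = instantiate(cfg)  # downloads the checkpoint on first use
```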
@@ -1,3 +1 @@
- _pretrained_:
- _target_: transformers.CLIPVisionModel.from_pretrained
- pretrained_model_name_or_path: openai/clip-vit-base-patch16
+ _pretrained_: openai/clip-vit-base-patch16
@@ -0,0 +1,22 @@
+ # The 14 task used in the paper:
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
+ # http://arxiv.org/abs/2405.07813
+ defaults:
+ # pre-trained model
+ - clip-vit-base-patch16
+ # eight tasks in the task arithmetic paper
+ - clip-vit-base-patch16_sun397
+ - clip-vit-base-patch16_stanford-cars
+ - clip-vit-base-patch16_resisc45
+ - clip-vit-base-patch16_eurosat
+ - clip-vit-base-patch16_svhn
+ - clip-vit-base-patch16_gtsrb
+ - clip-vit-base-patch16_mnist
+ - clip-vit-base-patch16_dtd
+ # additional 6 tasks in the TALL mask paper
+ - clip-vit-base-patch16_oxford_flowers102
+ - clip-vit-base-patch16_pcam
+ - clip-vit-base-patch16_fer2013
+ - clip-vit-base-patch16_oxford-iiit-pet
+ - clip-vit-base-patch16_stl10
+ - clip-vit-base-patch16_cifar100
@@ -0,0 +1,29 @@
+ # The 20 task used in the paper:
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
+ # http://arxiv.org/abs/2405.07813
+ defaults:
+ # pre-trained model
+ - clip-vit-base-patch16
+ # eight tasks in the task arithmetic paper
+ - clip-vit-base-patch16_sun397
+ - clip-vit-base-patch16_stanford-cars
+ - clip-vit-base-patch16_resisc45
+ - clip-vit-base-patch16_eurosat
+ - clip-vit-base-patch16_svhn
+ - clip-vit-base-patch16_gtsrb
+ - clip-vit-base-patch16_mnist
+ - clip-vit-base-patch16_dtd
+ # additional 6 tasks in the TALL mask paper (TALL 14)
+ - clip-vit-base-patch16_oxford_flowers102
+ - clip-vit-base-patch16_pcam
+ - clip-vit-base-patch16_fer2013
+ - clip-vit-base-patch16_oxford-iiit-pet
+ - clip-vit-base-patch16_stl10
+ - clip-vit-base-patch16_cifar100
+ # additional 6 tasks in the TALL mask paper (TALL 20)
+ - clip-vit-base-patch16_cifar10
+ - clip-vit-base-patch16_food101
+ - clip-vit-base-patch16_fashion_mnist
+ - clip-vit-base-patch16_emnist_letters
+ - clip-vit-base-patch16_kmnist
+ - clip-vit-base-patch16_rendered-sst2
@@ -0,0 +1 @@
+ cifar10: tanganke/clip-vit-base-patch16_cifar10
@@ -0,0 +1 @@
+ cifar100: tanganke/clip-vit-base-patch16_cifar100
@@ -1,3 +1 @@
- dtd:
- _target_: transformers.CLIPVisionModel.from_pretrained
- pretrained_model_name_or_path: tanganke/clip-vit-base-patch16_dtd
+ dtd: tanganke/clip-vit-base-patch16_dtd
@@ -0,0 +1 @@
+ emnist_letters: tanganke/clip-vit-base-patch16_emnist_letters
@@ -1,3 +1 @@
- eurosat:
- _target_: transformers.CLIPVisionModel.from_pretrained
- pretrained_model_name_or_path: tanganke/clip-vit-base-patch16_eurosat
+ eurosat: tanganke/clip-vit-base-patch16_eurosat