fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264) hide show
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -60,4 +60,6 @@ class NYUv2TaskPool(TaskPool):
60
60
  num_workers=self.config.num_workers,
61
61
  )
62
62
  report = self.trainer.validate(model, val_loader)
63
+ if isinstance(report, list) and len(report) == 1:
64
+ report = report[0]
63
65
  return report
@@ -58,11 +58,24 @@ class CLIPTemplateFactory:
58
58
  "templates": "templates",
59
59
  },
60
60
  "nateraw/rendered-sst2": ".rendered_sst2",
61
+ "rendered-sst2": ".rendered_sst2",
61
62
  "tanganke/stl10": ".stl10",
63
+ "stl10": ".stl10",
62
64
  "dpdl-benchmark/oxford_flowers102": ".flower102",
65
+ "oxford_flowers102": ".flower102",
63
66
  "timm/oxford-iiit-pet": ".oxford_iiit_pet",
67
+ "oxford-iiit-pet": ".oxford_iiit_pet",
64
68
  "imagenet": ".imagenet",
65
69
  "tiny-imagenet": ".tiny_imagenet",
70
+ "pcam": ".pcam",
71
+ "fer2013": ".fer2013",
72
+ "emnist_mnist": ".emnist_mnist",
73
+ "emnist_letters": ".emnist_letters",
74
+ "kmnist": ".kmnist",
75
+ "food101": ".food101",
76
+ "fashion_mnist": ".fashion_mnist",
77
+ "cub-200-2011": ".cub_200_2011",
78
+ "mango-leaf-disease": ".mango_leaf_disease",
66
79
  }
67
80
 
68
81
  @staticmethod
@@ -168,48 +181,3 @@ class CLIPTemplateFactory:
168
181
 
169
182
  def get_classnames_and_templates(dataset_name: str):
170
183
  return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
171
-
172
-
173
- def _load_hf_dataset(dataset_name: str):
174
- """
175
- Load a dataset from the Hugging Face datasets library based on the specified dataset name.
176
-
177
- This function handles specific preprocessing steps for certain datasets to ensure consistency in dataset format.
178
- For example, it renames columns, removes unnecessary columns, and specifies subsets for certain datasets.
179
-
180
- Expected dataset format:
181
- - The dataset should have an "image" column containing the image data.
182
- - The dataset should have a "label" column containing the class labels.
183
-
184
- Args:
185
- dataset_name (str): The name of the dataset to load. Can be one of "svhn", "cifar10", "cifar100", "timm/oxford-iiit-pet", or any other dataset name supported by the Hugging Face datasets library. By default, the datasets have two columns: "image" and "label".
186
-
187
- Returns:
188
- A dataset object loaded from the Hugging Face datasets library, with any necessary preprocessing applied.
189
- """
190
- if dataset_name == "svhn":
191
- return load_dataset(dataset_name, "cropped_digits")
192
- elif dataset_name == "cifar10":
193
- dataset = load_dataset(dataset_name)
194
- dataset = dataset.rename_columns({"img": "image"})
195
- return dataset
196
- elif dataset_name == "cifar100":
197
- dataset = load_dataset(dataset_name)
198
- dataset = dataset.remove_columns(["coarse_label"]).rename_columns(
199
- {"img": "image", "fine_label": "label"}
200
- )
201
- return dataset
202
- elif dataset_name == "timm/oxford-iiit-pet":
203
- dataset = load_dataset(dataset_name)
204
- dataset = dataset.remove_columns(["image_id", "label_cat_dog"])
205
- return dataset
206
- else:
207
- return load_dataset(dataset_name)
208
-
209
-
210
- def load_clip_dataset(dataset: str, processor):
211
- hf_dataset = _load_hf_dataset(dataset)
212
- return (
213
- CLIPDataset(hf_dataset["train"], processor),
214
- CLIPDataset(hf_dataset["test"], processor),
215
- )
@@ -1,16 +1 @@
1
- import torch
2
-
3
-
4
- class CLIPDataset(torch.utils.data.Dataset):
5
- def __init__(self, dataset, processor):
6
- self.dataset = dataset
7
- self.processor = processor
8
-
9
- def __len__(self):
10
- return len(self.dataset)
11
-
12
- def __getitem__(self, idx):
13
- item = self.dataset[idx]
14
- image = item["image"]
15
- inputs = self.processor(images=[image], return_tensors="pt")["pixel_values"][0]
16
- return inputs, item["label"]
1
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
@@ -0,0 +1,208 @@
1
+ classname_mapping = {
2
+ "0": "Black_footed_Albatross",
3
+ "1": "Laysan_Albatross",
4
+ "2": "Sooty_Albatross",
5
+ "3": "Groove_billed_Ani",
6
+ "4": "Crested_Auklet",
7
+ "5": "Least_Auklet",
8
+ "6": "Parakeet_Auklet",
9
+ "7": "Rhinoceros_Auklet",
10
+ "8": "Brewer_Blackbird",
11
+ "9": "Red_winged_Blackbird",
12
+ "10": "Rusty_Blackbird",
13
+ "11": "Yellow_headed_Blackbird",
14
+ "12": "Bobolink",
15
+ "13": "Indigo_Bunting",
16
+ "14": "Lazuli_Bunting",
17
+ "15": "Painted_Bunting",
18
+ "16": "Cardinal",
19
+ "17": "Spotted_Catbird",
20
+ "18": "Gray_Catbird",
21
+ "19": "Yellow_breasted_Chat",
22
+ "20": "Eastern_Towhee",
23
+ "21": "Chuck_will_Widow",
24
+ "22": "Brandt_Cormorant",
25
+ "23": "Red_faced_Cormorant",
26
+ "24": "Pelagic_Cormorant",
27
+ "25": "Bronzed_Cowbird",
28
+ "26": "Shiny_Cowbird",
29
+ "27": "Brown_Creeper",
30
+ "28": "American_Crow",
31
+ "29": "Fish_Crow",
32
+ "30": "Black_billed_Cuckoo",
33
+ "31": "Mangrove_Cuckoo",
34
+ "32": "Yellow_billed_Cuckoo",
35
+ "33": "Gray_crowned_Rosy_Finch",
36
+ "34": "Purple_Finch",
37
+ "35": "Northern_Flicker",
38
+ "36": "Acadian_Flycatcher",
39
+ "37": "Great_Crested_Flycatcher",
40
+ "38": "Least_Flycatcher",
41
+ "39": "Olive_sided_Flycatcher",
42
+ "40": "Scissor_tailed_Flycatcher",
43
+ "41": "Vermilion_Flycatcher",
44
+ "42": "Yellow_bellied_Flycatcher",
45
+ "43": "Frigatebird",
46
+ "44": "Northern_Fulmar",
47
+ "45": "Gadwall",
48
+ "46": "American_Goldfinch",
49
+ "47": "European_Goldfinch",
50
+ "48": "Boat_tailed_Grackle",
51
+ "49": "Eared_Grebe",
52
+ "50": "Horned_Grebe",
53
+ "51": "Pied_billed_Grebe",
54
+ "52": "Western_Grebe",
55
+ "53": "Blue_Grosbeak",
56
+ "54": "Evening_Grosbeak",
57
+ "55": "Pine_Grosbeak",
58
+ "56": "Rose_breasted_Grosbeak",
59
+ "57": "Pigeon_Guillemot",
60
+ "58": "California_Gull",
61
+ "59": "Glaucous_winged_Gull",
62
+ "60": "Heermann_Gull",
63
+ "61": "Herring_Gull",
64
+ "62": "Ivory_Gull",
65
+ "63": "Ring_billed_Gull",
66
+ "64": "Slaty_backed_Gull",
67
+ "65": "Western_Gull",
68
+ "66": "Anna_Hummingbird",
69
+ "67": "Ruby_throated_Hummingbird",
70
+ "68": "Rufous_Hummingbird",
71
+ "69": "Green_Violetear",
72
+ "70": "Long_tailed_Jaeger",
73
+ "71": "Pomarine_Jaeger",
74
+ "72": "Blue_Jay",
75
+ "73": "Florida_Jay",
76
+ "74": "Green_Jay",
77
+ "75": "Dark_eyed_Junco",
78
+ "76": "Tropical_Kingbird",
79
+ "77": "Gray_Kingbird",
80
+ "78": "Belted_Kingfisher",
81
+ "79": "Green_Kingfisher",
82
+ "80": "Pied_Kingfisher",
83
+ "81": "Ringed_Kingfisher",
84
+ "82": "White_breasted_Kingfisher",
85
+ "83": "Red_legged_Kittiwake",
86
+ "84": "Horned_Lark",
87
+ "85": "Pacific_Loon",
88
+ "86": "Mallard",
89
+ "87": "Western_Meadowlark",
90
+ "88": "Hooded_Merganser",
91
+ "89": "Red_breasted_Merganser",
92
+ "90": "Mockingbird",
93
+ "91": "Nighthawk",
94
+ "92": "Clark_Nutcracker",
95
+ "93": "White_breasted_Nuthatch",
96
+ "94": "Baltimore_Oriole",
97
+ "95": "Hooded_Oriole",
98
+ "96": "Orchard_Oriole",
99
+ "97": "Scott_Oriole",
100
+ "98": "Ovenbird",
101
+ "99": "Brown_Pelican",
102
+ "100": "White_Pelican",
103
+ "101": "Western_Wood_Pewee",
104
+ "102": "Sayornis",
105
+ "103": "American_Pipit",
106
+ "104": "Whip_poor_Will",
107
+ "105": "Horned_Puffin",
108
+ "106": "Common_Raven",
109
+ "107": "White_necked_Raven",
110
+ "108": "American_Redstart",
111
+ "109": "Geococcyx",
112
+ "110": "Loggerhead_Shrike",
113
+ "111": "Great_Grey_Shrike",
114
+ "112": "Baird_Sparrow",
115
+ "113": "Black_throated_Sparrow",
116
+ "114": "Brewer_Sparrow",
117
+ "115": "Chipping_Sparrow",
118
+ "116": "Clay_colored_Sparrow",
119
+ "117": "House_Sparrow",
120
+ "118": "Field_Sparrow",
121
+ "119": "Fox_Sparrow",
122
+ "120": "Grasshopper_Sparrow",
123
+ "121": "Harris_Sparrow",
124
+ "122": "Henslow_Sparrow",
125
+ "123": "Le_Conte_Sparrow",
126
+ "124": "Lincoln_Sparrow",
127
+ "125": "Nelson_Sharp_tailed_Sparrow",
128
+ "126": "Savannah_Sparrow",
129
+ "127": "Seaside_Sparrow",
130
+ "128": "Song_Sparrow",
131
+ "129": "Tree_Sparrow",
132
+ "130": "Vesper_Sparrow",
133
+ "131": "White_crowned_Sparrow",
134
+ "132": "White_throated_Sparrow",
135
+ "133": "Cape_Glossy_Starling",
136
+ "134": "Bank_Swallow",
137
+ "135": "Barn_Swallow",
138
+ "136": "Cliff_Swallow",
139
+ "137": "Tree_Swallow",
140
+ "138": "Scarlet_Tanager",
141
+ "139": "Summer_Tanager",
142
+ "140": "Artic_Tern",
143
+ "141": "Black_Tern",
144
+ "142": "Caspian_Tern",
145
+ "143": "Common_Tern",
146
+ "144": "Elegant_Tern",
147
+ "145": "Forsters_Tern",
148
+ "146": "Least_Tern",
149
+ "147": "Green_tailed_Towhee",
150
+ "148": "Brown_Thrasher",
151
+ "149": "Sage_Thrasher",
152
+ "150": "Black_capped_Vireo",
153
+ "151": "Blue_headed_Vireo",
154
+ "152": "Philadelphia_Vireo",
155
+ "153": "Red_eyed_Vireo",
156
+ "154": "Warbling_Vireo",
157
+ "155": "White_eyed_Vireo",
158
+ "156": "Yellow_throated_Vireo",
159
+ "157": "Bay_breasted_Warbler",
160
+ "158": "Black_and_white_Warbler",
161
+ "159": "Black_throated_Blue_Warbler",
162
+ "160": "Blue_winged_Warbler",
163
+ "161": "Canada_Warbler",
164
+ "162": "Cape_May_Warbler",
165
+ "163": "Cerulean_Warbler",
166
+ "164": "Chestnut_sided_Warbler",
167
+ "165": "Golden_winged_Warbler",
168
+ "166": "Hooded_Warbler",
169
+ "167": "Kentucky_Warbler",
170
+ "168": "Magnolia_Warbler",
171
+ "169": "Mourning_Warbler",
172
+ "170": "Myrtle_Warbler",
173
+ "171": "Nashville_Warbler",
174
+ "172": "Orange_crowned_Warbler",
175
+ "173": "Palm_Warbler",
176
+ "174": "Pine_Warbler",
177
+ "175": "Prairie_Warbler",
178
+ "176": "Prothonotary_Warbler",
179
+ "177": "Swainson_Warbler",
180
+ "178": "Tennessee_Warbler",
181
+ "179": "Wilson_Warbler",
182
+ "180": "Worm_eating_Warbler",
183
+ "181": "Yellow_Warbler",
184
+ "182": "Northern_Waterthrush",
185
+ "183": "Louisiana_Waterthrush",
186
+ "184": "Bohemian_Waxwing",
187
+ "185": "Cedar_Waxwing",
188
+ "186": "American_Three_toed_Woodpecker",
189
+ "187": "Pileated_Woodpecker",
190
+ "188": "Red_bellied_Woodpecker",
191
+ "189": "Red_cockaded_Woodpecker",
192
+ "190": "Red_headed_Woodpecker",
193
+ "191": "Downy_Woodpecker",
194
+ "192": "Bewick_Wren",
195
+ "193": "Cactus_Wren",
196
+ "194": "Carolina_Wren",
197
+ "195": "House_Wren",
198
+ "196": "Marsh_Wren",
199
+ "197": "Rock_Wren",
200
+ "198": "Winter_Wren",
201
+ "199": "Common_Yellowthroat",
202
+ }
203
+
204
+ classnames = [classname_mapping[str(i)] for i in range(200)]
205
+ templates = [
206
+ lambda c: f"a photo of a {c}.",
207
+ lambda c: f"a photo of the {c}.",
208
+ ]
@@ -0,0 +1,31 @@
1
+ classnames_mapping = {
2
+ "0": "A",
3
+ "1": "B",
4
+ "2": "C",
5
+ "3": "D",
6
+ "4": "E",
7
+ "5": "F",
8
+ "6": "G",
9
+ "7": "H",
10
+ "8": "I",
11
+ "9": "J",
12
+ "10": "K",
13
+ "11": "L",
14
+ "12": "M",
15
+ "13": "N",
16
+ "14": "O",
17
+ "15": "P",
18
+ "16": "Q",
19
+ "17": "R",
20
+ "18": "S",
21
+ "19": "T",
22
+ "20": "U",
23
+ "21": "V",
24
+ "22": "W",
25
+ "23": "X",
26
+ "24": "Y",
27
+ "25": "Z",
28
+ }
29
+
30
+ classnames = [classnames_mapping[str(i)] for i in range(26)]
31
+ templates = [lambda c: f'a photo of the digit character: "{c}".']
@@ -0,0 +1,5 @@
1
+ # https://huggingface.co/datasets/tanganke/emnist_mnist
2
+ classnames = [str(i) for i in range(10)]
3
+ templates = [
4
+ lambda c: f'a photo of the number: "{c}".',
5
+ ]
@@ -0,0 +1,18 @@
1
+ classname_mapping = {
2
+ "0": "T - shirt / top",
3
+ "1": "Trouser",
4
+ "2": "Pullover",
5
+ "3": "Dress",
6
+ "4": "Coat",
7
+ "5": "Sandal",
8
+ "6": "Shirt",
9
+ "7": "Sneaker",
10
+ "8": "Bag",
11
+ "9": "Ankle boot",
12
+ }
13
+ classnames = [classname_mapping[str(i)] for i in range(10)]
14
+
15
+ templates = [
16
+ lambda c: f"a photo of a {c}.",
17
+ lambda c: f"a photo of the {c}.",
18
+ ]
@@ -0,0 +1,18 @@
1
+ classnames = [
2
+ "angry",
3
+ "disgusted",
4
+ "fearful",
5
+ "happy",
6
+ "neutral",
7
+ "sad",
8
+ "surprised",
9
+ ]
10
+
11
+ templates = [
12
+ lambda c: f"a photo of a {c} looking face.",
13
+ lambda c: f"a photo of a face showing the emotion: {c}.",
14
+ lambda c: f"a photo of a face looking {c}.",
15
+ lambda c: f"a face that looks {c}.",
16
+ lambda c: f"they look {c}.",
17
+ lambda c: f"look at how {c} they are.",
18
+ ]
@@ -0,0 +1,105 @@
1
+ classnames = [
2
+ "apple pie",
3
+ "baby back ribs",
4
+ "baklava",
5
+ "beef carpaccio",
6
+ "beef tartare",
7
+ "beet salad",
8
+ "beignets",
9
+ "bibimbap",
10
+ "bread pudding",
11
+ "breakfast burrito",
12
+ "bruschetta",
13
+ "caesar salad",
14
+ "cannoli",
15
+ "caprese salad",
16
+ "carrot cake",
17
+ "ceviche",
18
+ "cheese plate",
19
+ "cheesecake",
20
+ "chicken curry",
21
+ "chicken quesadilla",
22
+ "chicken wings",
23
+ "chocolate cake",
24
+ "chocolate mousse",
25
+ "churros",
26
+ "clam chowder",
27
+ "club sandwich",
28
+ "crab cakes",
29
+ "creme brulee",
30
+ "croque madame",
31
+ "cup cakes",
32
+ "deviled eggs",
33
+ "donuts",
34
+ "dumplings",
35
+ "edamame",
36
+ "eggs benedict",
37
+ "escargots",
38
+ "falafel",
39
+ "filet mignon",
40
+ "fish and chips",
41
+ "foie gras",
42
+ "french fries",
43
+ "french onion soup",
44
+ "french toast",
45
+ "fried calamari",
46
+ "fried rice",
47
+ "frozen yogurt",
48
+ "garlic bread",
49
+ "gnocchi",
50
+ "greek salad",
51
+ "grilled cheese sandwich",
52
+ "grilled salmon",
53
+ "guacamole",
54
+ "gyoza",
55
+ "hamburger",
56
+ "hot and sour soup",
57
+ "hot dog",
58
+ "huevos rancheros",
59
+ "hummus",
60
+ "ice cream",
61
+ "lasagna",
62
+ "lobster bisque",
63
+ "lobster roll sandwich",
64
+ "macaroni and cheese",
65
+ "macarons",
66
+ "miso soup",
67
+ "mussels",
68
+ "nachos",
69
+ "omelette",
70
+ "onion rings",
71
+ "oysters",
72
+ "pad thai",
73
+ "paella",
74
+ "pancakes",
75
+ "panna cotta",
76
+ "peking duck",
77
+ "pho",
78
+ "pizza",
79
+ "pork chop",
80
+ "poutine",
81
+ "prime rib",
82
+ "pulled pork sandwich",
83
+ "ramen",
84
+ "ravioli",
85
+ "red velvet cake",
86
+ "risotto",
87
+ "samosa",
88
+ "sashimi",
89
+ "scallops",
90
+ "seaweed salad",
91
+ "shrimp and grits",
92
+ "spaghetti bolognese",
93
+ "spaghetti carbonara",
94
+ "spring rolls",
95
+ "steak",
96
+ "strawberry shortcake",
97
+ "sushi",
98
+ "tacos",
99
+ "takoyaki",
100
+ "tiramisu",
101
+ "tuna tartare",
102
+ "waffles",
103
+ ]
104
+
105
+ templates = [lambda c: f"a photo of {c}, a type of food."]
@@ -0,0 +1,17 @@
1
+ classnames_mapping = {
2
+ "0": "お",
3
+ "1": "き",
4
+ "2": "す",
5
+ "3": "つ",
6
+ "4": "な",
7
+ "5": "は",
8
+ "6": "ま",
9
+ "7": "や",
10
+ "8": "れ",
11
+ "9": "を",
12
+ }
13
+ classnames = [classnames_mapping[str(c)] for c in range(10)]
14
+
15
+ templates = [
16
+ lambda c: f"a photo of the character {c}.",
17
+ ]
@@ -0,0 +1,19 @@
1
+ classnames = [
2
+ "Anthracnose",
3
+ "Bacterial Canker",
4
+ "Cutting Weevil",
5
+ "Die Back",
6
+ "Gall Midge",
7
+ "Healthy",
8
+ "Powdery Mildew",
9
+ "Sooty Mould",
10
+ ]
11
+
12
+ templates = [
13
+ lambda c: f"a photo of a mango leaf with {c}.",
14
+ lambda c: f"a mango leaf showing symptoms of {c}.",
15
+ lambda c: f"a close-up photo of {c} on a mango leaf.",
16
+ lambda c: f"this mango leaf is affected by {c}.",
17
+ lambda c: f"a mango leaf disease identified as {c}.",
18
+ lambda c: f"a {c} infection on a mango leaf.",
19
+ ]
@@ -0,0 +1,5 @@
1
+ classnames = ["lymph node", "lymph node containing metastatic tumor tissue"]
2
+
3
+ templates = [
4
+ lambda c: f"this is a photo of {c}",
5
+ ]
@@ -4,3 +4,25 @@ import hydra.core.hydra_config
4
4
  def get_hydra_output_dir():
5
5
  hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
6
6
  return hydra_cfg.runtime.output_dir
7
+
8
+
9
+ def config_priority_get(priority_config, general_config, key, default):
10
+ """
11
+ Retrieve a configuration value with priority.
12
+
13
+ This function retrieves the value associated with `key` from `priority_config` if it exists.
14
+ If the key is not found in `priority_config`, it retrieves the value from `general_config`.
15
+ If the key is not found in either configuration, it returns the provided `default` value.
16
+
17
+ Args:
18
+ priority_config (dict): The configuration dictionary with higher priority.
19
+ general_config (dict): The general configuration dictionary.
20
+ key (str): The key to look up in the configuration dictionaries.
21
+ default: The default value to return if the key is not found in either configuration.
22
+
23
+ Returns:
24
+ The value associated with `key` from `priority_config` or `general_config`, or the `default` value if the key is not found.
25
+ """
26
+ if key in priority_config:
27
+ return priority_config[key]
28
+ return general_config.get(key, default)
@@ -1,6 +1,6 @@
1
1
  import copy
2
2
  from collections import OrderedDict
3
- from typing import List, Mapping, Union
3
+ from typing import List, Mapping, Optional, Union
4
4
 
5
5
  import torch
6
6
  from torch import nn
@@ -43,7 +43,10 @@ def trainable_state_dict(
43
43
  return state_dict
44
44
 
45
45
 
46
- def state_dict_to_vector(state_dict, remove_keys=[]):
46
+ def state_dict_to_vector(
47
+ state_dict: StateDictType,
48
+ remove_keys: Optional[List[str]] = None,
49
+ ):
47
50
  """
48
51
  Convert a state dictionary to a vector.
49
52
 
@@ -54,6 +57,7 @@ def state_dict_to_vector(state_dict, remove_keys=[]):
54
57
  Returns:
55
58
  torch.Tensor: The converted vector.
56
59
  """
60
+ remove_keys = remove_keys if remove_keys is not None else []
57
61
  shared_state_dict = copy.deepcopy(state_dict)
58
62
  for key in remove_keys:
59
63
  if key in shared_state_dict:
@@ -64,7 +68,11 @@ def state_dict_to_vector(state_dict, remove_keys=[]):
64
68
  )
65
69
 
66
70
 
67
- def vector_to_state_dict(vector, state_dict, remove_keys=[]):
71
+ def vector_to_state_dict(
72
+ vector: torch.Tensor,
73
+ state_dict: StateDictType,
74
+ remove_keys: Optional[List[str]] = None,
75
+ ):
68
76
  """
69
77
  Convert a vector to a state dictionary.
70
78
 
@@ -76,6 +84,7 @@ def vector_to_state_dict(vector, state_dict, remove_keys=[]):
76
84
  Returns:
77
85
  dict: The converted state dictionary.
78
86
  """
87
+ remove_keys = remove_keys if remove_keys is not None else []
79
88
  # create a reference dict to define the order of the vector
80
89
  reference_dict = copy.deepcopy(state_dict)
81
90
  for key in remove_keys:
File without changes
@@ -0,0 +1,52 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import seaborn as sns
4
+
5
+
6
+ def visualize_model_inputs(input_ids, attention_mask, labels, tokenizer=None):
7
+ """
8
+ Visualize model inputs: attention mask, labels and input_ids
9
+
10
+ Parameters:
11
+ -----------
12
+ attention_mask: numpy array or tensor
13
+ The attention mask array
14
+ labels: numpy array or tensor
15
+ The labels array
16
+ input_ids: numpy array or tensor
17
+ The input ids array
18
+ tokenizer: optional
19
+ The tokenizer object to decode input_ids
20
+ """
21
+
22
+ # Convert inputs to numpy if they're tensors
23
+ attention_mask = np.array(attention_mask)
24
+ labels = np.array(labels)
25
+ input_ids = np.array(input_ids)
26
+
27
+ # Create figure with 3 subplots
28
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 10))
29
+
30
+ # Plot attention mask
31
+ sns.heatmap(attention_mask.reshape(1, -1), ax=ax1, cmap="Blues", cbar=True)
32
+ ax1.set_title("**Attention Mask**")
33
+ ax1.set_ylabel("Sequence")
34
+
35
+ # Plot labels
36
+ sns.heatmap(labels.reshape(1, -1), ax=ax2, cmap="Reds", cbar=True)
37
+ ax2.set_title("**Labels**")
38
+ ax2.set_ylabel("Sequence")
39
+
40
+ # Plot input_ids
41
+ sns.heatmap(input_ids.reshape(1, -1), ax=ax3, cmap="Greens", cbar=True)
42
+ ax3.set_title("**Input IDs**")
43
+ ax3.set_ylabel("Sequence")
44
+
45
+ # If tokenizer is provided, add decoded tokens as x-axis labels
46
+ if tokenizer:
47
+ decoded_tokens = [tokenizer.decode(token_id) for token_id in input_ids]
48
+ ax3.set_xticks(np.arange(len(decoded_tokens)) + 0.5)
49
+ ax3.set_xticklabels(decoded_tokens, rotation=45, ha="right")
50
+
51
+ plt.tight_layout()
52
+ return fig