fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,47 @@
+class NoSparseGradientError(Exception):
+    """Raised when the gradient is a sparse gradient.
+
+    :param optimizer_name: str. optimizer name.
+    :param note: str. special conditions to note (default '').
+    """
+
+    def __init__(self, optimizer_name: str, note: str = ""):
+        self.note: str = " " if not note else f" w/ {note} "
+        self.message: str = (
+            f"[-] {optimizer_name}{self.note}does not support sparse gradient."
+        )
+        super().__init__(self.message)
+
+
+class ZeroParameterSizeError(Exception):
+    """Raised when the parameter size is 0."""
+
+    def __init__(self):
+        self.message: str = "[-] parameter size is 0"
+        super().__init__(self.message)
+
+
+class NoClosureError(Exception):
+    """Raised when there's no closure function."""
+
+    def __init__(self, optimizer_name: str, note: str = ""):
+        self.message: str = f"[-] {optimizer_name} requires closure.{note}"
+        super().__init__(self.message)
+
+
+class NegativeLRError(Exception):
+    """Raised when the learning rate is negative."""
+
+    def __init__(self, lr: float, lr_type: str = ""):
+        self.note: str = lr_type if lr_type else "learning rate"
+        self.message: str = f"[-] {self.note} must be positive. ({lr} > 0)"
+        super().__init__(self.message)
+
+
+class NegativeStepError(Exception):
+    """Raised when the step count is negative."""
+
+    def __init__(self, num_steps: int, step_type: str = ""):
+        self.note: str = step_type if step_type else "step"
+        self.message: str = f"[-] {self.note} must be positive. ({num_steps} > 0)"
+        super().__init__(self.message)
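
For orientation, a minimal sketch of how these validation errors surface in practice, using the `LinearWarmupScheduler` added below in this same diff (the values are made up):

    import torch

    from fusion_bench.optim.exception import NegativeLRError
    from fusion_bench.optim.lr_scheduler.linear_warmup import LinearWarmupScheduler

    # dummy optimizer so the scheduler has a parameter group to manage
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.0)

    try:
        # a negative max_lr fails validate_parameters() before any training starts
        LinearWarmupScheduler(optimizer, T_max=100, max_lr=-0.1)
    except NegativeLRError as e:
        print(e)  # [-] max_lr must be positive. (-0.1 > 0)
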
@@ -0,0 +1 @@
+from .linear_warmup import *
@@ -0,0 +1,222 @@
+"""
+Modified from pytorch_optimizer: https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/lr_scheduler/linear_warmup.py
+"""
+
+import math
+from abc import ABC, abstractmethod
+from typing import List
+
+import numpy as np
+import torch
+
+from fusion_bench.optim.exception import NegativeLRError, NegativeStepError
+
+__all__ = [
+    "BaseLinearWarmupScheduler",
+    "LinearWarmupScheduler",
+    "CosineDecayWithWarmup",
+    "PolySchedulerWithWarmup",
+]
+
+
+class BaseLinearWarmupScheduler(ABC):
+    r"""BaseLinearWarmupScheduler class.
+
+    LR schedulers derived from this class prepend a linear warmup phase to their decay schedule.
+
+    Args:
+        optimizer (torch.optim.Optimizer): Optimizer. The scheduler sets the learning rate of all trainable parameters in the optimizer.
+        T_max (int): Total number of training steps.
+        max_lr (float): Maximum learning rate.
+        min_lr (float): Minimum learning rate.
+        init_lr (float): Initial learning rate.
+        warmup_steps (int): Number of warm-up steps.
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        T_max: int,
+        max_lr: float,
+        min_lr: float = 0.0,
+        init_lr: float = 0.0,
+        warmup_steps: int = 0,
+    ):
+        """
+        Initialize the BaseLinearWarmupScheduler.
+
+        Args:
+            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate schedule.
+            T_max (int): Total number of training steps.
+            max_lr (float): Maximum learning rate.
+            min_lr (float): Minimum learning rate.
+            init_lr (float): Initial learning rate.
+            warmup_steps (int): Number of steps for the warm-up phase.
+        """
+        self.optimizer = optimizer
+        self.total_steps = T_max
+        self.max_lr = max_lr
+        self.min_lr = min_lr
+        self.init_lr = init_lr
+        self.warmup_steps = warmup_steps
+
+        self.step_t: int = 0
+        self.base_lrs: List[float] = []
+
+        # record the current value in self.last_lr to match the API of torch.optim.lr_scheduler
+        self.last_lr: List[float] = [init_lr]
+
+        self.validate_parameters()
+
+        self._init_lr()
+
+    def validate_parameters(self):
+        """
+        Validate the parameters to ensure they are non-negative.
+
+        Raises:
+            NegativeLRError: If any of the learning rates are negative.
+            NegativeStepError: If any of the step values are negative.
+        """
+        if self.min_lr < 0:
+            raise NegativeLRError(self.min_lr, "min_lr")
+
+        if self.max_lr < 0:
+            raise NegativeLRError(self.max_lr, "max_lr")
+
+        if self.init_lr < 0:
+            raise NegativeLRError(self.init_lr, "init_lr")
+
+        if self.total_steps < 0:
+            raise NegativeStepError(self.total_steps, "T_max")
+
+        if self.warmup_steps < 0:
+            raise NegativeStepError(self.warmup_steps, "warmup_steps")
+
+    def _init_lr(self):
+        """
+        Initialize the learning rate for each parameter group in the optimizer.
+        """
+        self.base_lrs = []
+        for param_group in self.optimizer.param_groups:
+            param_group["lr"] = self.min_lr
+            self.base_lrs.append(self.min_lr)
+
+    def step(self):
+        """
+        Update the learning rate for the current step.
+
+        Returns:
+            float: The updated learning rate.
+        """
+        if self.step_t < self.warmup_steps:
+            value = (
+                self.init_lr
+                + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps
+            )
+        elif self.step_t == self.warmup_steps:
+            value = self.max_lr
+        else:
+            value = self._step()
+
+        self.step_t += 1
+
+        if self.optimizer is not None:
+            for param_group in self.optimizer.param_groups:
+                param_group["lr"] = value
+
+        self.last_lr = [value]
+
+        return value
+
+    @abstractmethod
+    def _step(self) -> float:  # pragma: no cover
+        """
+        Abstract method to calculate the learning rate for the current step.
+
+        Returns:
+            float: The calculated learning rate.
+        """
+        raise NotImplementedError
+
+    def get_lr(self) -> float:
+        """
+        Get the current learning rate.
+
+        Returns:
+            float: The current learning rate.
+        """
+        return self.last_lr[0]
+
+
+class LinearWarmupScheduler(BaseLinearWarmupScheduler):
+    r"""Linear LR Scheduler w/ linear warmup."""
+
+    def _step(self) -> float:
+        """
+        Calculate the learning rate for the current step using a linear decay.
+
+        Returns:
+            float: The calculated learning rate.
+        """
+        return self.max_lr + (self.min_lr - self.max_lr) * (
+            self.step_t - self.warmup_steps
+        ) / (self.total_steps - self.warmup_steps)
+
+
+class CosineDecayWithWarmup(BaseLinearWarmupScheduler):
+    r"""Cosine LR Scheduler w/ linear warmup."""
+
+    def _step(self) -> float:
+        """
+        Calculate the learning rate for the current step using a cosine decay.
+
+        Returns:
+            float: The calculated learning rate.
+        """
+        phase: float = (
+            (self.step_t - self.warmup_steps)
+            / (self.total_steps - self.warmup_steps)
+            * math.pi
+        )
+        return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0
+
+
+class PolySchedulerWithWarmup(BaseLinearWarmupScheduler):
+    r"""Poly LR Scheduler w/ linear warmup.
+
+    Args:
+        poly_order (float): Order of the polynomial used for the decay.
+    """
+
+    def __init__(self, optimizer, poly_order: float = 0.5, **kwargs):
+        """
+        Initialize the PolySchedulerWithWarmup.
+
+        Args:
+            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate schedule.
+            poly_order (float): Order of the polynomial for the learning rate decay.
+            kwargs: Additional arguments for the base class.
+
+        Raises:
+            ValueError: If poly_order is not positive.
+        """
+        self.poly_order = poly_order
+
+        if poly_order <= 0:
+            raise ValueError(f"[-] poly_order must be positive. {poly_order}")
+
+        super().__init__(optimizer, **kwargs)
+
+    def _step(self) -> float:
+        """
+        Calculate the learning rate for the current step using a polynomial decay.
+
+        Returns:
+            float: The calculated learning rate.
+        """
+        return (
+            self.min_lr
+            + (self.max_lr - self.min_lr)
+            * (self.step_t - self.warmup_steps) ** self.poly_order
+        )
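
As a quick sanity check on the schedule above, a short runnable sketch: with `warmup_steps=5` the learning rate climbs linearly from `init_lr` to `max_lr`, after which `_step()` decays it linearly toward `min_lr` over the remaining steps.

    import torch

    from fusion_bench.optim.lr_scheduler.linear_warmup import LinearWarmupScheduler

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.0)
    scheduler = LinearWarmupScheduler(
        optimizer, T_max=10, max_lr=1.0, min_lr=0.0, init_lr=0.0, warmup_steps=5
    )

    # unlike torch's built-in schedulers, step() returns the new learning rate
    lrs = [round(scheduler.step(), 1) for _ in range(10)]
    print(lrs)  # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]
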
@@ -0,0 +1 @@
+from .visualization import *
@@ -0,0 +1,119 @@
+"""
+This module provides utilities for visualizing learning rate schedulers.
+
+Functions:
+    simulate_scheduler(lr_scheduler, steps): Simulates the learning rate scheduler for a given number of steps.
+    plot_lr_schedulers(lr_schedulers, steps, titles): Plots the learning rates of one or more schedulers over a number of steps.
+"""
+
+from typing import TYPE_CHECKING, List, Union
+
+import matplotlib.pyplot as plt
+import torch
+
+if TYPE_CHECKING:
+    from torch.optim.lr_scheduler import LRScheduler
+
+__all__ = ["simulate_scheduler", "plot_lr_schedulers"]
+
+
+def simulate_scheduler(lr_scheduler, steps: int):
+    """
+    Simulates the learning rate scheduler for a given number of steps.
+
+    Args:
+        lr_scheduler: A scheduler whose `step()` returns the updated learning rate (e.g., the schedulers in `linear_warmup`).
+        steps (int): The number of steps to simulate.
+
+    Returns:
+        List[float]: A list of learning rates for each step.
+    """
+    lrs = []
+    for _ in range(steps):
+        lr = lr_scheduler.step()
+        lrs.append(lr)
+    return lrs
+
+
+def plot_lr_schedulers(
+    lr_schedulers: Union["LRScheduler", List["LRScheduler"]],
+    steps: int,
+    titles: Union[str, List[str]],
+    show_plot: bool = True,
+):
+    """
+    Plots the learning rates of one or more schedulers over a number of steps.
+
+    Args:
+        lr_schedulers (Union[LRScheduler, List[LRScheduler]]): One or more learning rate scheduler objects.
+        steps (int): The number of steps to simulate.
+        titles (Union[str, List[str]]): Titles for the plots.
+
+    Returns:
+        fig, axes: The matplotlib figure and axes objects.
+    """
+    # Handle a single scheduler
+    if isinstance(lr_schedulers, torch.optim.lr_scheduler.LRScheduler):
+        lr_schedulers = [lr_schedulers]
+    if isinstance(titles, str):
+        titles = [titles]
+
+    fig, axs = plt.subplots(len(lr_schedulers), 1, figsize=(5, 3 * len(lr_schedulers)))
+    if len(lr_schedulers) == 1:
+        axs = [axs]
+
+    for i, (scheduler, title) in enumerate(zip(lr_schedulers, titles)):
+        lrs = simulate_scheduler(scheduler, steps)
+        axs[i].plot(lrs, label=title)
+        axs[i].set_title(title)
+        axs[i].set_xlabel("Steps")
+        axs[i].set_ylabel("Learning Rate")
+        axs[i].legend()
+        axs[i].grid(True)
+
+    plt.tight_layout()
+    if show_plot:
+        plt.show()
+    return fig, axs
+
+
+# Example usage
+if __name__ == "__main__":
+    from fusion_bench.optim.lr_scheduler.linear_warmup import (
+        CosineDecayWithWarmup,
+        LinearWarmupScheduler,
+        PolySchedulerWithWarmup,
+    )
+
+    # Dummy optimizer
+    optimizer = torch.optim.SGD(
+        [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))], lr=0.1
+    )
+
+    # Define the schedulers (the base class expects `T_max`, not `t_max`)
+    linear_scheduler = LinearWarmupScheduler(
+        optimizer, T_max=100, max_lr=0.1, min_lr=0.01, init_lr=0.0, warmup_steps=10
+    )
+    cosine_scheduler = CosineDecayWithWarmup(
+        optimizer, T_max=100, max_lr=0.1, min_lr=0.01, init_lr=0.0, warmup_steps=10
+    )
+    poly_scheduler = PolySchedulerWithWarmup(
+        optimizer,
+        T_max=100,
+        max_lr=0.1,
+        min_lr=0.01,
+        init_lr=0.0,
+        warmup_steps=40,
+        poly_order=2.0,
+    )
+
+    # Plot the learning rates
+    plot_lr_schedulers(
+        [linear_scheduler, cosine_scheduler, poly_scheduler],
+        steps=100,
+        titles=[
+            "Linear Warmup",
+            "Cosine Decay with Warmup",
+            "Poly Scheduler with Warmup",
+        ],
+    )
@@ -5,8 +5,6 @@ import numpy as np
 import torch
 from torch.optim.optimizer import Optimizer
 
-from fusion_bench.utils import timeit_context
-
 log = logging.getLogger(__name__)
 
 
@@ -185,10 +185,13 @@ class FabricModelFusionProgram(
             report = taskpool.evaluate(merged_model)
             return report
         elif isinstance(merged_model, Dict):
-            model = merged_model.pop("model")
-            report: dict = taskpool.evaluate(model)
-            report.update(merged_model)
-            print(report)
+            report = {}
+            for key, item in merged_model.items():
+                if isinstance(item, nn.Module):
+                    report[key] = taskpool.evaluate(item)
+                else:
+                    # metadata
+                    report[key] = item
             return report
         elif isinstance(merged_model, Iterable):
             return [
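
The rewritten `Dict` branch changes the contract of `evaluate_merged_model`: instead of popping a single `"model"` entry, it evaluates every `nn.Module` value under its own key and passes all other values through as metadata. A runnable sketch of that behavior (`DummyTaskPool` and its metric are made up for illustration):

    import torch.nn as nn

    class DummyTaskPool:
        def evaluate(self, model):
            return {"accuracy": 1.0}  # placeholder metric

    merged_model = {
        "merged": nn.Linear(4, 4),  # nn.Module values are evaluated
        "scaling_factor": 0.3,      # everything else is copied into the report
    }

    taskpool, report = DummyTaskPool(), {}
    for key, item in merged_model.items():
        report[key] = taskpool.evaluate(item) if isinstance(item, nn.Module) else item
    print(report)  # {'merged': {'accuracy': 1.0}, 'scaling_factor': 0.3}
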
@@ -236,7 +239,11 @@ class FabricModelFusionProgram(
         self.save_merged_model(merged_model)
         if self.taskpool is not None:
             report = self.evaluate_merged_model(self.taskpool, merged_model)
-            print_json(report, print_type=False)
+            try:
+                print_json(report, print_type=False)
+            except Exception as e:
+                log.warning(f"Failed to pretty print the report: {e}")
+                print(report)
             if self.report_save_path is not None:
                 # save report (Dict) to a file
                 # if the directory of `save_report` does not exist, create it
@@ -3,7 +3,17 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast  # noqa: F401
+from typing import (  # noqa: F401
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
 from omegaconf import DictConfig
@@ -25,6 +35,10 @@ from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.parameters import count_parameters
 
+if TYPE_CHECKING:
+    from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+# disable tokenizers parallelism by default to avoid deadlocks
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 log = logging.getLogger(__name__)
@@ -198,14 +212,16 @@ class CLIPVisionModelTaskPool(
         classifier: HFCLIPClassifier,
         test_loader: DataLoader,
         num_classes: int,
+        task_name: str = None,
     ):
         """
-        Evaluate the classifier on the test dataset.
+        Evaluate the classifier on the test dataset (single-task evaluation).
 
         Args:
             classifier (HFCLIPClassifier): The classifier to evaluate.
             test_loader (DataLoader): The data loader for the test dataset.
             num_classes (int): The number of classes in the classification task.
+            task_name (str): The name of the task.
 
         Returns:
             Dict[str, float]: A dictionary containing the accuracy and loss of the classifier on the test dataset.
@@ -228,7 +244,12 @@ class CLIPVisionModelTaskPool(
             )
         ):
             inputs, targets = batch
-            outputs = classifier(inputs, return_image_embeds=True, return_dict=True)
+            outputs = classifier(
+                inputs,
+                return_image_embeds=True,
+                return_dict=True,
+                task_name=task_name,
+            )
             logits: Tensor = outputs["logits"]
 
             loss = F.cross_entropy(logits, targets)
@@ -246,12 +267,18 @@ class CLIPVisionModelTaskPool(
         results = {"accuracy": acc, "loss": loss}
         return results
 
-    def evaluate(self, model: Union[CLIPVisionModel, CLIPVisionTransformer], name=None):
+    def evaluate(
+        self,
+        model: Union[CLIPVisionModel, CLIPVisionTransformer],
+        name=None,
+        **kwargs,
+    ):
         """
         Evaluate the model on the image classification task.
 
         Args:
             model (Union[CLIPVisionModel, CLIPVisionTransformer]): The model to evaluate.
+            name (Optional[str]): The name of the model. This will be logged into the report if not None.
 
         Returns:
             Dict[str, Any]: A dictionary containing the evaluation results for each task.
@@ -261,8 +288,17 @@ class CLIPVisionModelTaskPool(
 
         report = {}
         # CLIPVisionModel works the same as CLIPVisionTransformer, so we can use it directly
-        self.clip_model.vision_model = model
-        classifier = HFCLIPClassifier(self.clip_model, processor=self.processor)
+        if hasattr(model, "is_surgery_model") and model.is_surgery_model:
+            log.info("running evaluation on a surgery model.")
+            model: "SurgeryModelWrapper" = model
+            self.clip_model.vision_model = model
+        else:
+            # replace the vision encoder with the model
+            self.clip_model.vision_model = model
+        classifier = HFCLIPClassifier(
+            self.clip_model,
+            processor=self.processor,
+        )
         classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
         # collect basic model information
         training_params, all_params = count_parameters(model)
@@ -285,6 +321,7 @@ class CLIPVisionModelTaskPool(
                 classifier,
                 test_dataloader,
                 num_classes=len(classnames),
+                task_name=task_name,
             )
             report[task_name] = result
         self.on_task_evaluation_end()
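
Note that the surgery branch above is duck-typed: any module exposing a truthy `is_surgery_model` attribute is installed as the vision tower as-is. A minimal illustrative sketch of a module that would take that path (not the actual `SurgeryModelWrapper`):

    import torch.nn as nn

    class MinimalSurgeryModel(nn.Module):
        is_surgery_model = True  # flag checked by CLIPVisionModelTaskPool.evaluate

        def __init__(self, vision_model: nn.Module):
            super().__init__()
            self.vision_model = vision_model

        def forward(self, *args, **kwargs):
            # a real wrapper would apply representation surgery to the features here
            return self.vision_model(*args, **kwargs)
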
@@ -0,0 +1,157 @@
+"""
+The dataset contains the following fields:
+
+- chosen_input_ids: The input token ids for the winner.
+- chosen_attention_mask: The attention mask for the winner.
+- rejected_input_ids: The input token ids for the loser.
+- rejected_attention_mask: The attention mask for the loser.
+"""
+
+import functools
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+
+import lightning as L
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import Subset
+from tqdm.auto import tqdm
+
+from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.taskpool import BaseTaskPool
+from fusion_bench.utils import instantiate
+
+if TYPE_CHECKING:
+    from transformers import LlamaForSequenceClassification
+
+
+def evaluate_batch(model: "LlamaForSequenceClassification", batch):
+    batch_size = batch["input_ids"].size(0)
+    assert batch_size % 2 == 0, "Batch size must be even."
+
+    outputs = model(
+        input_ids=batch["input_ids"],
+        attention_mask=batch["attention_mask"],
+    )
+
+    rewards = outputs[0]
+    chosen_reward = rewards[: batch_size // 2]
+    rejected_rewards = rewards[batch_size // 2 :]
+
+    loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()
+    correct = (chosen_reward > rejected_rewards).sum().item()
+    total = batch_size // 2
+
+    return {
+        "loss": loss.item(),
+        "correct": correct,
+        "total": total,
+    }
+
+
+def evaluate_dataloader(model: "LlamaForSequenceClassification", dataloader):
+    """
+    Compute the loss and accuracy of the reward model on the given dataloader.
+
+    Args:
+        model: The reward model.
+        dataloader: The dataloader for the dataset.
+
+    Returns:
+        Dict[str, float]: The running loss, number of correct predictions, total number of pairs, and accuracy.
+    """
+    metrics = {
+        "loss": 0.0,
+        "correct": 0,
+        "total": 0,
+    }
+    with torch.no_grad():
+        for batch in (pbar := tqdm(dataloader)):
+            batch_result = evaluate_batch(model, batch)
+            new_total = metrics["total"] + batch_result["total"]
+            metrics["loss"] = (
+                metrics["loss"] * metrics["total"] / new_total
+                + batch_result["loss"] * batch_result["total"] / new_total
+            )
+            metrics["correct"] += batch_result["correct"]
+            metrics["total"] += batch_result["total"]
+            pbar.set_postfix(metrics)
+
+    metrics["accuracy"] = metrics["correct"] / metrics["total"]
+    return metrics
+
+
+class RewardModelEvaluationTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+):
+    def __init__(
+        self,
+        test_datasets: List[DictConfig],
+        dataloader_kwargs: DictConfig,
+        tokenizer: Optional[DictConfig],
+        max_num_samples: int = -1,
+        seed: int = 0,
+        **kwargs,
+    ):
+        self.seed = seed
+        L.seed_everything(seed)
+        self._test_datasets = test_datasets
+        self.dataloader_kwargs = dataloader_kwargs
+        self._tokenizer = tokenizer
+        self.max_num_samples = max_num_samples
+        super().__init__(**kwargs)
+
+    def setup(self):
+        if self._tokenizer is None:
+            # try to load the tokenizer from the model pool
+            tokenizer = self._program.modelpool.load_tokenizer()
+        else:
+            tokenizer = instantiate(self._tokenizer)
+        self.tokenizer = tokenizer
+
+        test_datasets = {
+            dataset_name: instantiate(self._test_datasets[dataset_name])
+            for dataset_name in self._test_datasets
+        }
+        if self.max_num_samples > 0:
+            test_datasets = {
+                dataset_name: Subset(
+                    test_dataset,
+                    np.random.permutation(len(test_dataset))[: self.max_num_samples],
+                )
+                for dataset_name, test_dataset in test_datasets.items()
+            }
+        test_dataloaders = {
+            dataset_name: torch.utils.data.DataLoader(
+                test_dataset,
+                collate_fn=functools.partial(
+                    bradley_terry_rm_collate,
+                    pad_token_id=tokenizer.pad_token_id,
+                ),
+                **self.dataloader_kwargs,
+            )
+            for dataset_name, test_dataset in test_datasets.items()
+        }
+
+        self.test_dataloaders = {
+            dataset_name: self.fabric.setup_dataloaders(test_dataloader)
+            for dataset_name, test_dataloader in test_dataloaders.items()
+        }
+
+    @torch.no_grad()
+    def evaluate(self, model: "LlamaForSequenceClassification"):
+        self.setup()
+
+        model = self.fabric.setup_module(model)
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        model.eval()
+        report = {}
+        for dataset_name, test_dataloader in self.test_dataloaders.items():
+            report[dataset_name] = evaluate_dataloader(model, test_dataloader)
+
+        print(report)
+        return report
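
To make the reward-model metric concrete, here is the Bradley-Terry loss from `evaluate_batch` computed by hand on a toy batch of two preference pairs (the reward values are made up):

    import torch

    # a batch of 4 sequences: first half chosen responses, second half rejected ones
    rewards = torch.tensor([[1.2], [0.4], [0.3], [0.9]])
    chosen, rejected = rewards[:2], rewards[2:]

    # negative log-likelihood of ranking each chosen response above its rejected one
    loss = -torch.log(torch.sigmoid(chosen - rejected)).mean()
    print(round(loss.item(), 4))  # ~0.6576

    # only the first pair (1.2 > 0.3) is ranked correctly
    print((chosen > rejected).float().mean().item())  # 0.5
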