fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/clip_vision/modelpool.py

@@ -1,8 +1,11 @@
  import logging
  from copy import deepcopy
- from typing import Optional
+ from typing import Optional, Union

+ from datasets import load_dataset
  from omegaconf import DictConfig, open_dict
+ from torch import nn
+ from torch.utils.data import Dataset
  from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
  from typing_extensions import override

@@ -36,17 +39,29 @@ class CLIPVisionModelPool(BaseModelPool):

      def load_processor(self, *args, **kwargs) -> CLIPProcessor:
          assert self._processor is not None, "Processor is not defined in the config"
-         processor = instantiate(self._processor, *args, **kwargs)
+         if isinstance(self._processor, str):
+             log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
+             processor = CLIPProcessor.from_pretrained(self._processor)
+         else:
+             processor = instantiate(self._processor, *args, **kwargs)
          return processor

      def load_clip_model(self, model_name: str, *args, **kwargs) -> CLIPModel:
          model_config = self._models[model_name]
-         assert isinstance(model_config, DictConfig), "Model config must be a DictConfig"
-         model_config = deepcopy(model_config)
-         with open_dict(model_config):
-             model_config._target_ = "transformers.CLIPModel.from_pretrained"
-         clip_model = instantiate(model_config, *args, **kwargs)
-         return clip_model
+
+         if isinstance(model_config, str):
+             log.info(f"Loading `transformers.CLIPModel`: {model_config}")
+             clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
+             return clip_model
+         else:
+             assert isinstance(
+                 model_config, DictConfig
+             ), "Model config must be a DictConfig"
+             model_config = deepcopy(model_config)
+             with open_dict(model_config):
+                 model_config._target_ = "transformers.CLIPModel.from_pretrained"
+             clip_model = instantiate(model_config, *args, **kwargs)
+             return clip_model

      @override
      def save_model(self, model: CLIPVisionModel, path: str):
@@ -59,3 +74,72 @@ class CLIPVisionModelPool(BaseModelPool):
          """
          with timeit_context(f'Saving clip vision model to "{path}"'):
              model.save_pretrained(path)
+
+     def load_model(
+         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+     ) -> CLIPVisionModel:
+         """
+         This method is used to load a CLIPVisionModel from the model pool.
+
+         Example configuration could be:
+
+         ```yaml
+         models:
+           cifar10: tanganke/clip-vit-base-patch32_cifar10
+           sun397: tanganke/clip-vit-base-patch32_sun397
+           stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
+         ```
+
+         Args:
+             model_name_or_config (Union[str, DictConfig]): The name of the model or the model configuration.
+
+         Returns:
+             CLIPVisionModel: The loaded CLIPVisionModel.
+         """
+         if (
+             isinstance(model_name_or_config, str)
+             and model_name_or_config in self._models
+         ):
+             model = self._models[model_name_or_config]
+             if isinstance(model, str):
+                 log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
+                 return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
+             if isinstance(model, nn.Module):
+                 log.info(f"Returning existing model: {model}")
+                 return model
+
+         # If the model is not a string, we use the default load_model method
+         return super().load_model(model_name_or_config, *args, **kwargs)
+
+     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+         dataset_config = self._train_datasets[dataset_name]
+         if isinstance(dataset_config, str):
+             log.info(
+                 f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+             )
+             dataset = load_dataset(dataset_config, split="train")
+         else:
+             dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
+         return dataset
+
+     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+         dataset_config = self._val_datasets[dataset_name]
+         if isinstance(dataset_config, str):
+             log.info(
+                 f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+             )
+             dataset = load_dataset(dataset_config, split="validation")
+         else:
+             dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
+         return dataset
+
+     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+         dataset_config = self._test_datasets[dataset_name]
+         if isinstance(dataset_config, str):
+             log.info(
+                 f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+             )
+             dataset = load_dataset(dataset_config, split="test")
+         else:
+             dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
+         return dataset
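With these changes, plain string entries in a `CLIPVisionModelPool` config are treated as Hugging Face Hub identifiers for models, processors, and datasets. A minimal sketch of the new shorthand, assuming the pool can be constructed directly with keyword arguments mirroring the attributes used above (`models`, `processor`, `train_datasets`); in normal use these values come from a Hydra YAML config, and the dataset id is illustrative:

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.clip_vision.modelpool import CLIPVisionModelPool

# Hypothetical direct construction; normally this mapping comes from a Hydra config.
pool = CLIPVisionModelPool(
    models=DictConfig(
        {
            "cifar10": "tanganke/clip-vit-base-patch32_cifar10",
            "sun397": "tanganke/clip-vit-base-patch32_sun397",
        }
    ),
    processor="openai/clip-vit-base-patch32",  # string -> CLIPProcessor.from_pretrained
    train_datasets=DictConfig({"cifar10": "cifar10"}),  # assumed: any id accepted by datasets.load_dataset
)

vision_model = pool.load_model("cifar10")       # CLIPVisionModel.from_pretrained(...)
processor = pool.load_processor()               # CLIPProcessor.from_pretrained(...)
train_set = pool.load_train_dataset("cifar10")  # load_dataset("cifar10", split="train")
```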
fusion_bench/modelpool/seq_classification_lm/__init__.py

@@ -0,0 +1,2 @@
+ from .reward_model import create_reward_model_from_pretrained
+ from .seq_classification_lm import SeqenceClassificationModelPool
fusion_bench/modelpool/seq_classification_lm/reward_model.py

@@ -0,0 +1,15 @@
+ from transformers import AutoModelForSequenceClassification
+
+
+ def create_reward_model_from_pretrained(pretrained_model_name_or_path: str, **kwargs):
+     """
+     Create a reward model for reward modeling (RLHF).
+
+     Args:
+         pretrained_model_name_or_path (str): The name or path of the pretrained model.
+         **kwargs: Additional keyword arguments passed to the model class.
+     """
+     model = AutoModelForSequenceClassification.from_pretrained(
+         pretrained_model_name_or_path, num_labels=1, **kwargs
+     )
+     return model
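`create_reward_model_from_pretrained` is a thin wrapper around `AutoModelForSequenceClassification` with `num_labels=1`, the usual single-scalar head for Bradley-Terry style reward modeling (see the new `bradley_terry_rm` method in the file list above). A hedged usage sketch; the checkpoint name is illustrative:

```python
import torch

from fusion_bench.modelpool.seq_classification_lm.reward_model import (
    create_reward_model_from_pretrained,
)

# Illustrative base checkpoint; any model supported by
# AutoModelForSequenceClassification can be used here.
reward_model = create_reward_model_from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    torch_dtype=torch.bfloat16,
)
print(reward_model.config.num_labels)  # 1: a single scalar reward per sequence
```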
fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py

@@ -0,0 +1,98 @@
+ import logging
+ import os
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union, cast  # noqa: F401
+
+ from omegaconf import DictConfig, flag_override
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+ from typing_extensions import override
+
+ from fusion_bench.modelpool import BaseModelPool
+ from fusion_bench.utils import instantiate
+ from fusion_bench.utils.dtype import parse_dtype
+
+ if TYPE_CHECKING:
+     from transformers import LlamaForSequenceClassification
+
+ log = logging.getLogger(__name__)
+
+
+ class SeqenceClassificationModelPool(BaseModelPool):
+
+     def __init__(
+         self,
+         models,
+         *,
+         tokenizer: Optional[DictConfig],
+         model_kwargs: Optional[DictConfig] = None,
+         **kwargs,
+     ):
+         super().__init__(models, **kwargs)
+         # process `model_kwargs`
+         self._tokenizer = tokenizer
+         self._model_kwargs = model_kwargs
+         if self._model_kwargs is None:
+             self._model_kwargs = DictConfig({})
+         with flag_override(self._model_kwargs, "allow_objects", True):
+             if hasattr(self._model_kwargs, "torch_dtype"):
+                 self._model_kwargs.torch_dtype = parse_dtype(
+                     self._model_kwargs.torch_dtype
+                 )
+
+     @override
+     def load_model(
+         self,
+         model_name_or_config: str | DictConfig,
+         *args,
+         **kwargs,
+     ) -> Union[PreTrainedModel, "LlamaForSequenceClassification"]:
+         model_kwargs = deepcopy(self._model_kwargs)
+         model_kwargs.update(kwargs)
+         if isinstance(model_name_or_config, str):
+             log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
+         return super().load_model(model_name_or_config, *args, **model_kwargs)
+
+     def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
+         assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+         log.info("Loading tokenizer.", stacklevel=2)
+         tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+         return tokenizer
+
+     @override
+     def save_model(
+         self,
+         model: PreTrainedModel,
+         path: str,
+         push_to_hub: bool = False,
+         model_dtype: Optional[str] = None,
+         save_tokenizer: bool = False,
+         tokenizer_kwargs=None,
+         **kwargs,
+     ):
+         """
+         Save the model to the specified path.
+
+         Args:
+             model (PreTrainedModel): The model to be saved.
+             path (str): The path where the model will be saved.
+             push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
+             save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
+             **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
+         """
+         path = os.path.expanduser(path)
+         if save_tokenizer:
+             if tokenizer_kwargs is None:
+                 tokenizer_kwargs = {}
+             # load the tokenizer
+             tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+             tokenizer.save_pretrained(
+                 path,
+                 push_to_hub=push_to_hub,
+             )
+         if model_dtype is not None:
+             model.to(dtype=parse_dtype(model_dtype))
+         model.save_pretrained(
+             path,
+             push_to_hub=push_to_hub,
+             **kwargs,
+         )
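A rough sketch of how this pool might be wired together, based only on the class definition above. The `_pretrained_` key, the `_target_` layout, and the field names are assumptions modeled on other fusion_bench pools, not the contents of the shipped YAML files such as `single_reward_model.yaml`:

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.seq_classification_lm import SeqenceClassificationModelPool

# Hypothetical configuration; in practice it is loaded from a Hydra YAML file.
pool = SeqenceClassificationModelPool(
    models=DictConfig(
        {
            "_pretrained_": {
                "_target_": "fusion_bench.modelpool.seq_classification_lm.create_reward_model_from_pretrained",
                "pretrained_model_name_or_path": "Qwen/Qwen2.5-0.5B",
            }
        }
    ),
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "Qwen/Qwen2.5-0.5B",
        }
    ),
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),  # parsed via parse_dtype above
)

reward_model = pool.load_model("_pretrained_")
tokenizer = pool.load_tokenizer()
pool.save_model(reward_model, "outputs/reward_model", save_tokenizer=True)
```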
fusion_bench/models/chat_templates/__init__.py

@@ -0,0 +1 @@
+ from .load_tokenizer import chat_template_mapping, load_tokenizer_with_chat_template
fusion_bench/models/chat_templates/llama_3_Instruct.py

@@ -0,0 +1 @@
+ CHAT_TEMPLATE = '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\'+ message[\'content\'] | trim + \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n'
fusion_bench/models/chat_templates/load_tokenizer.py

@@ -0,0 +1,43 @@
+ import logging
+
+ from transformers import AutoTokenizer
+
+ from .llama_3_Instruct import CHAT_TEMPLATE as LLAMA_3_INSTRUCT_CHAT_TEMPLATE
+
+ chat_template_mapping = {"llama_3_instruct": LLAMA_3_INSTRUCT_CHAT_TEMPLATE}
+
+ log = logging.getLogger(__name__)
+
+
+ def load_tokenizer_with_chat_template(
+     pretrained_model_name_or_path: str,
+     model_family: str,
+     overwrite_chat_template: bool = True,
+     **kwargs,
+ ):
+     """
+     Load the tokenizer for Llama 3 model.
+
+     Args:
+         pretrained_model_name_or_path (str): The name or path of the pretrained model.
+         model_family (str): The model family.
+         **kwargs: Additional keyword arguments passed to the tokenizer class.
+     """
+     assert (
+         model_family in chat_template_mapping
+     ), f"Model family {model_family} not found. Available model families: {chat_template_mapping.keys()}"
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         **kwargs,
+     )
+
+     if tokenizer.chat_template is None:
+         tokenizer.chat_template = chat_template_mapping[model_family]
+     else:
+         if overwrite_chat_template:
+             log.warning("Overwriting the chat template with the default chat template.")
+             tokenizer.chat_template = chat_template_mapping[model_family]
+         else:
+             log.warning("Chat template already exists. Skipping overwriting.")
+     return tokenizer
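Usage is straightforward; `llama_3_instruct` is currently the only key in `chat_template_mapping`, so any other `model_family` trips the assertion. The checkpoint name below is illustrative:

```python
from fusion_bench.models.chat_templates import load_tokenizer_with_chat_template

# Illustrative (gated) checkpoint; any tokenizer compatible with the Llama 3 template works.
tokenizer = load_tokenizer_with_chat_template(
    pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
    model_family="llama_3_instruct",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
```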
fusion_bench/models/hf_clip.py

@@ -1,4 +1,5 @@
- from typing import Callable, Iterable, List  # noqa: F401
+ import logging
+ from typing import TYPE_CHECKING, Callable, Iterable, List  # noqa: F401

  import torch
  from torch import Tensor, nn
@@ -7,6 +8,11 @@ from transformers.models.clip.modeling_clip import BaseModelOutputWithPooling

  from fusion_bench.utils.devices import get_device

+ if TYPE_CHECKING:
+     from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+ log = logging.getLogger(__name__)
+
  default_templates = [
      lambda c: f"a photo of a {c}",
  ]
@@ -33,6 +39,7 @@ class HFCLIPClassifier(nn.Module):
          self,
          clip_model: CLIPModel,
          processor: CLIPProcessor,
+         extra_module=None,
      ):
          """
          Initialize the HFCLIPClassifier.
@@ -56,6 +63,8 @@ class HFCLIPClassifier(nn.Module):
              persistent=False,
          )

+         self.extra_module = extra_module
+
      @property
      def text_model(self):
          """Get the text model component of CLIP."""
@@ -111,7 +120,13 @@ class HFCLIPClassifier(nn.Module):

          self.zeroshot_weights = zeroshot_weights

-     def forward(self, images, return_image_embeds=False, return_dict=False):
+     def forward(
+         self,
+         images: Tensor,
+         return_image_embeds=False,
+         return_dict=False,
+         task_name=None,
+     ):
          """
          Perform forward pass for zero-shot image classification.

@@ -120,6 +135,9 @@ class HFCLIPClassifier(nn.Module):

          Args:
              images (Tensor): Input images to classify.
+             return_image_embeds (bool): Whether to return the image embeddings.
+             return_dict (bool): Whether to return a dictionary with logits and image embeddings.
+             task_name (Optional[str]): The name of the task.

          Returns:
              Tensor: Classification logits for each input image.
@@ -131,16 +149,22 @@ class HFCLIPClassifier(nn.Module):
              raise ValueError("Must set classification task before forward pass")
          text_embeds = self.zeroshot_weights

-         image_embeds = self.vision_model(images)
-         if isinstance(image_embeds, Tensor):
-             pass
-         elif isinstance(image_embeds, BaseModelOutputWithPooling):
-             image_embeds = image_embeds[1]
-             image_embeds = self.clip_model.visual_projection(image_embeds)
-
+         image_embeds = self.get_image_features(images)
          # normalize embeddings
          image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

+         if (
+             hasattr(self.vision_model, "is_surgery_model")
+             and self.vision_model.is_surgery_model
+         ):
+             # Dealing with the surgery model, for more details, please refer to:
+             # (ICML 2024) Yang, et.al. Representation Surgery for Multi-Task Model Merging
+             # https://arxiv.org/abs/2402.02705
+             self.vision_model: "SurgeryModelWrapper" = self.vision_model
+             image_embeds, _, _ = self.vision_model.compute_surgery_features(
+                 image_embeds, dataset_name=task_name
+             )
+
          # cosine similarity
          logit_scale = self.clip_model.logit_scale.exp()
          logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
@@ -156,3 +180,20 @@ class HFCLIPClassifier(nn.Module):
              return logits_per_image, image_embeds
          else:
              return logits_per_image
+
+     def get_image_features(self, images: Tensor) -> Tensor:
+         """
+         Compute the image embeddings.
+
+         Returns:
+             image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
+                 applying the projection layer to the pooled output of [`CLIPVisionModel`].
+         """
+
+         image_embeds = self.vision_model(images)
+         if isinstance(image_embeds, Tensor):
+             pass
+         elif isinstance(image_embeds, BaseModelOutputWithPooling):
+             image_embeds = image_embeds[1]
+             image_embeds = self.clip_model.visual_projection(image_embeds)
+         return image_embeds
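The forward signature change is backward compatible: `task_name` defaults to `None` and is only consulted when the vision tower is a `SurgeryModelWrapper`. A minimal sketch of zero-shot classification with the new argument; `set_classification_task` is the existing helper that builds the zero-shot text weights (its name is assumed here, it is not part of this diff):

```python
import torch
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.models.hf_clip import HFCLIPClassifier

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

classifier = HFCLIPClassifier(clip_model, processor)
classifier.set_classification_task(["airplane", "automobile", "bird"])  # assumed helper

images = torch.randn(2, 3, 224, 224)  # already-preprocessed pixel values
logits = classifier(images, task_name="cifar10")  # task_name only matters for surgery models
```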
fusion_bench/models/surgery/__init__.py

@@ -0,0 +1 @@
+ from .surgerymodelwrapper import SurgeryModelWrapper
fusion_bench/models/surgery/surgerymodelwrapper.py

@@ -0,0 +1,158 @@
+ import math
+ from typing import TYPE_CHECKING, Callable, Generic, List, Union
+
+ import torch
+ from torch import nn
+ from transformers.models.clip.modeling_clip import (
+     CLIPVisionModel,
+     CLIPVisionTransformer,
+ )
+
+ from fusion_bench.utils.type import TorchModelType
+
+
+ def regularize_name(name: str):
+     name = name.replace("-", "_")
+     name = name.replace(".", "_")
+     return name
+
+
+ class SurgeryModelWrapper(torch.nn.Module, Generic[TorchModelType]):
+
+     is_surgery_model = True
+     """A flag to indicate that this is a surgery model."""
+
+     def __init__(
+         self,
+         model: TorchModelType,
+         test_datasets: List[str],
+         projection_dim: int = 512,
+         hidden_dim: int = 16,
+     ):
+         super(SurgeryModelWrapper, self).__init__()
+         self.model = model
+         self.model.requires_grad_(False)
+
+         self.test_datasets = test_datasets
+         self.non_linear_func = torch.nn.ReLU()
+
+         self.projection_dim = projection_dim
+         self.hidden_dim = hidden_dim
+
+         for dataset_name in test_datasets:
+             self.add_surgery_module(dataset_name)
+
+     def add_surgery_module(self, dataset_name: str):
+         """
+         Add a surgery module for a given dataset.
+
+         Args:
+             dataset_name (str): The name of the dataset.
+         """
+         dataset_name = regularize_name(dataset_name)
+
+         down_proj = torch.nn.Linear(self.projection_dim, self.hidden_dim, bias=False)
+         up_proj = torch.nn.Linear(self.hidden_dim, self.projection_dim, bias=False)
+
+         torch.nn.init.kaiming_uniform_(down_proj.weight, a=math.sqrt(5))
+         torch.nn.init.zeros_(up_proj.weight)
+
+         self.add_module(
+             "feature_mapping_to_head_down_proj_{}".format(dataset_name), down_proj
+         )
+         self.add_module(
+             "feature_mapping_to_head_up_proj_{}".format(dataset_name), up_proj
+         )
+
+     def collect_trainable_params(self):
+         trainable_params = []
+
+         # surgery parameter
+         for dataset_name in self.test_datasets:
+             dataset_name = regularize_name(dataset_name)
+             down_proj = getattr(
+                 self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+             )
+             up_proj = getattr(
+                 self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+             )
+             trainable_params.append(down_proj.weight)
+             trainable_params.append(up_proj.weight)
+         return trainable_params
+
+     def collect_surgery_module(self):
+         surgery_module = {}
+
+         # surgery parameter
+         for dataset_name in self.test_datasets:
+             dataset_name = regularize_name(dataset_name)
+             down_proj = getattr(
+                 self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+             )
+             up_proj = getattr(
+                 self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+             )
+             surgery_module[
+                 "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+             ] = down_proj
+             surgery_module[
+                 "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+             ] = up_proj
+
+         surgery_module["non_linear_func"] = self.non_linear_func
+
+         return surgery_module
+
+     def compute_surgery_features(
+         self,
+         compute_features_fn: Union[
+             torch.Tensor, Callable[[TorchModelType], torch.Tensor]
+         ],
+         dataset_name: str,
+     ):
+         """
+         Compute the surgery features.
+
+         Args:
+             compute_features_fn (Union[torch.Tensor, Callable[[nn.Module], torch.Tensor]]): A function that computes the features or a tensor that represents the features.
+             dataset_name (str): The name of the dataset.
+
+         Returns:
+             feature (torch.Tensor): The surgery features.
+             feature0 (torch.Tensor): The original features.
+             feature_sub (torch.Tensor): feature0 - feature.
+         """
+         dataset_name = regularize_name(dataset_name)
+
+         if isinstance(compute_features_fn, torch.Tensor):
+             feature = compute_features_fn
+         elif callable(compute_features_fn):
+             feature = compute_features_fn(self.model)
+         else:
+             raise ValueError(
+                 "compute_features_fn must be a tensor or a callable, but got {}".format(
+                     type(compute_features_fn)
+                 )
+             )
+
+         feature0 = feature
+
+         # feature bias
+         down_proj = getattr(
+             self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+         )
+         up_proj = getattr(
+             self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+         )
+         feature_sub = down_proj(feature)
+         feature_sub = self.non_linear_func(feature_sub)
+         feature_sub = up_proj(feature_sub)
+
+         # surgery feature
+         feature = feature0 - feature_sub
+
+         return feature, feature0, feature_sub
+
+     def forward(self, *args, **kwargs):
+         """The wrappered model should just forward like normal."""
+         return self.model(*args, **kwargs)
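`SurgeryModelWrapper` implements the per-dataset "representation surgery" adapters of Yang et al. (ICML 2024): a frozen merged backbone plus, for each dataset, a low-rank `down_proj -> ReLU -> up_proj` bias that is subtracted from the features. A minimal sketch of training only the adapters, assuming already-projected CLIP features of dimension `projection_dim`; model and dataset names are illustrative:

```python
import torch
from transformers import CLIPVisionModel

from fusion_bench.models.surgery import SurgeryModelWrapper

# Wrap a (merged) vision encoder; only the surgery adapters remain trainable.
merged_vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
wrapper = SurgeryModelWrapper(
    merged_vision_model,
    test_datasets=["cifar10", "sun397"],
    projection_dim=512,  # CLIP ViT-B/32 projection size
    hidden_dim=16,
)
optimizer = torch.optim.Adam(wrapper.collect_trainable_params(), lr=1e-3)

# Given already-projected image features for one dataset, the wrapper returns the
# corrected features, the original features, and the subtracted bias term.
features = torch.randn(8, 512)
surgery_feature, feature0, feature_sub = wrapper.compute_surgery_features(
    features, dataset_name="cifar10"
)
```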
fusion_bench/models/utils.py

@@ -1,5 +1,6 @@
  from typing import List

+ import torch
  from torch import nn


@@ -70,3 +71,10 @@ def find_layers_with_type(
          if isinstance(submodule, tuple(layer_types)):
              res[name] = submodule
      return res
+
+
+ def disable_dropout(model: torch.nn.Module):
+     """Disable dropout in a model."""
+     for module in model.modules():
+         if isinstance(module, torch.nn.Dropout):
+             module.p = 0
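`disable_dropout` simply zeroes the dropout probability of every `nn.Dropout` module in place, which is commonly done for policy and reference models in preference-based finetuning so that log-probabilities are deterministic. For example (illustrative checkpoint):

```python
from transformers import AutoModelForCausalLM

from fusion_bench.models.utils import disable_dropout

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # illustrative checkpoint
disable_dropout(model)  # sets p = 0 on every nn.Dropout submodule
```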
fusion_bench/models/wrappers/layer_wise_fusion.py

@@ -1,13 +1,22 @@
  import functools
  import logging
  from copy import deepcopy
- from typing import Any, Callable, Dict, Iterator, List, Optional  # noqa: F401
+ from typing import (  # noqa: F401
+     Any,
+     Callable,
+     Dict,
+     Generic,
+     Iterator,
+     List,
+     Optional,
+     TypeVar,
+ )

  import torch
  from torch import Tensor, nn
  from torch.func import functional_call

- from fusion_bench.utils.type import StateDictType
+ from fusion_bench.utils.type import StateDictType, TorchModelType

  __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]

@@ -132,14 +141,14 @@ def fuse_weights(
      }


- class LayerWiseMergedModel(nn.Module):
+ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
      _merged_state_dict: StateDictType = None

      def __init__(
          self,
          layer_wise_weight: Tensor,
-         pretrained_model: nn.Module,
-         finetuned_models: List[nn.Module],
+         pretrained_model: TorchModelType,
+         finetuned_models: List[TorchModelType],
          clamp_weights: bool = True,
          tie_weights: bool = False,
          strict: bool = True,
fusion_bench/models/wrappers/task_wise_fusion.py

@@ -16,13 +16,13 @@ outputs = merged_model(inputs)

  import functools
  import logging
- from typing import Any, Callable, Dict, Iterator, List, Optional  # noqa: F401
+ from typing import Any, Callable, Dict, Generic, Iterator, List, Optional  # noqa: F401

  import torch
  from torch import Tensor, nn
  from torch.func import functional_call

- from fusion_bench.utils.type import StateDictType
+ from fusion_bench.utils.type import StateDictType, TorchModelType

  log = logging.getLogger(__name__)

@@ -157,14 +157,14 @@ def fuse_weights(
      }


- class TaskWiseMergedModel(nn.Module):
+ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
      _merged_state_dict: StateDictType = None

      def __init__(
          self,
          task_wise_weight: Tensor,
-         pretrained_model: nn.Module,
-         finetuned_models: List[nn.Module],
+         pretrained_model: TorchModelType,
+         finetuned_models: List[TorchModelType],
          clamp_weights: bool = True,
          tie_weights: bool = False,
          strict: bool = True,
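Both wrappers are now generic in the wrapped model type, so downstream code can tell type checkers what the merged model unwraps to. A small sketch; `merge_and_unload` is an existing wrapper method assumed here, not part of this diff:

```python
from transformers import CLIPVisionModel

from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel


def finalize_merge(merged: LayerWiseMergedModel[CLIPVisionModel]) -> CLIPVisionModel:
    # The Generic[TorchModelType] parameter lets the checker infer that the
    # unloaded model is a CLIPVisionModel rather than a bare nn.Module.
    return merged.merge_and_unload()
```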
fusion_bench/optim/__init__.py

@@ -0,0 +1,2 @@
+ from . import exception, lr_scheduler
+ from .mezo import MeZO