fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/clip_classification.py
@@ -2,7 +2,17 @@ import functools
 import logging
 import os
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast  # noqa: F401
+from typing import (  # noqa: F401
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
 from omegaconf import DictConfig
@@ -18,10 +28,12 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils.data import InfiniteDataLoader
 
-log = logging.getLogger(__name__)
+if TYPE_CHECKING:
+    from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-TensorOrModule = TypeVar("TensorOrModule", torch.Tensor, torch.nn.Module, Any)
+log = logging.getLogger(__name__)
 
+# disable tokenizers parallelism by default to avoid deadlocks
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -120,13 +132,13 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
         # get cache directory
         if self.modelpool.has_pretrained:
-            model_name = self.modelpool.get_model_config(
-                "_pretrained_"
-            ).pretrained_model_name_or_path
+            model_name = self.modelpool.get_model_config("_pretrained_")
+            if not isinstance(model_name, str):
+                model_name = model_name.pretrained_model_name_or_path
         else:
-            model_name = self.modelpool.get_model_config(
-                self.modelpool.model_names[0]
-            ).pretrained_model_name_or_path
+            model_name = self.modelpool.get_model_config(self.modelpool.model_names[0])
+            if not isinstance(model_name, str):
+                model_name = model_name.pretrained_model_name_or_path
         cache_dir = os.path.join(
             self.zeroshot_weights_cache_dir,
             os.path.normpath(model_name.split("/")[-1]),
@@ -175,13 +187,30 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     def compute_logits(
         self,
-        module: Union[nn.Module, CLIPVisionModel],
+        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
         images: torch.Tensor,
         task: str,
+        image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """
+        Compute the logits of the images for a given task.
+
+        Args:
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
+            images (torch.Tensor): The images to compute the logits.
+            task (str): The task to compute the logits.
+            image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.
+
+        Returns:
+            torch.Tensor: The logits of the images.
+        """
         text_embeds = self.zeroshot_weights[task]
 
-        image_embeds = module(images)[1]
+        if image_embeds is None:
+            image_embeds = module(images)[1]
+        assert isinstance(
+            image_embeds, torch.Tensor
+        ), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
         image_embeds = self.visual_projection(image_embeds)
 
         # normalize embeddings
@@ -194,3 +223,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
         logits_per_image = logits_per_text.t()
 
         return logits_per_image
+
+    def compute_features(
+        self,
+        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
+        images: torch.Tensor,
+        normalize: bool = True,
+    ) -> torch.Tensor:
+        """
+        Extracts image features using CLIP's vision encoder and visual projection.
+
+        Args:
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
+            images (torch.Tensor): Input image batch to process.
+            normalize (bool): Whether to normalize the image embeddings.
+
+        Returns:
+            torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
+        """
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        if normalize:
+            image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        return image_embeds
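
The new `image_embeds` argument lets callers run the vision tower once and score the same pooled features against several task heads. A minimal usage sketch, not part of the diff, written as if inside a `CLIPClassificationMixin` subclass with its zero-shot heads already set up; `vision_model`, `images`, and the task names are illustrative:

    # compute_logits applies self.visual_projection itself, so the precomputed
    # embeddings must be the *unprojected* pooled output, i.e. module(images)[1]
    pooled = vision_model(images)[1]
    for task in ("cifar10", "svhn"):  # illustrative task names
        logits = self.compute_logits(vision_model, images, task, image_embeds=pooled)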
fusion_bench/mixins/fabric_training.py (new file)
@@ -0,0 +1,320 @@
+import itertools
+import logging
+import os
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Literal, Union
+
+import torch
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+
+from .lightning_fabric import LightningFabricMixin
+
+if TYPE_CHECKING:
+    from lightning.fabric.wrappers import (
+        _FabricDataLoader,
+        _FabricModule,
+        _FabricOptimizer,
+    )
+
+log = logging.getLogger(__name__)
+
+
+class FabricTrainingMixin(LightningFabricMixin):
+    """
+    This is a general purpose mixin for training a model with PyTorch Lightning.
+    """
+
+    _latest_saved_checkpoint_global_step: int = -1
+    """The global step index of the latest saved checkpoint."""
+    _expected_total_steps: int = None
+    """The expected total number of steps of the entire training."""
+    is_training: bool
+    """Whether the training is in progress. If set to False, the training will stop."""
+    epoch_idx: int
+    """The epoch index, which is the number of epochs completed."""
+    global_step_idx: int
+    """The global step index, which is the number of parameter update steps."""
+    max_epochs: int
+    """Max number of epochs of the entire training."""
+    max_steps: int
+    """Max number of parameter update steps of the entire training."""
+    max_steps_per_epoch: int
+    """Max number of parameter update steps per epoch."""
+    gradient_clip_algorithm: Literal["value", "norm"]
+    """The algorithm to clip gradients. Available options: 'value', 'norm'."""
+    gradient_clip_val: float
+    """The value to clip gradients. If None, no clipping is applied."""
+    accumulate_grad_batches: int
+    """The number of gradient accumulation steps. The effective global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`."""
+    lr_scheduler_interval: Literal["step", "epoch"]
+    """The interval to run the learning rate scheduler. Available options: 'step', 'epoch'."""
+    lr_scheduler_frequency: int
+    """The frequency to run the learning rate scheduler."""
+    checkpoint_save_interval: Literal["step", "epoch"]
+    """The interval to save the model checkpoint. Available options: 'step', 'epoch'."""
+    checkpoint_save_frequency: int
+    """The frequency to save the model checkpoint."""
+
+    def clip_gradients_if_needed(self, model, optimizer):
+        """
+        Clips gradients if the gradient clipping value is set.
+
+        Args:
+            model (nn.Module): The model whose gradients need to be clipped.
+            optimizer (torch.optim.Optimizer): The optimizer used for training.
+        """
+        fabric = self.fabric
+
+        if self.gradient_clip_val is not None:
+            if self.gradient_clip_algorithm == "value":
+                fabric.clip_gradients(model, optimizer, clip_val=self.gradient_clip_val)
+            elif self.gradient_clip_algorithm == "norm":
+                fabric.clip_gradients(model, optimizer, max_norm=self.gradient_clip_val)
+            else:
+                raise ValueError(
+                    f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
+                )
+
+    def compute_expected_total_steps(
+        self, train_dataloader: torch.utils.data.DataLoader
+    ):
+        """
+        Computes the expected total number of steps for the entire training.
+
+        Args:
+            train_dataloader (torch.utils.data.DataLoader): The dataloader for the training data.
+        """
+        # compute expected total steps
+        self._expected_total_steps = []
+        if self.max_steps > 0:
+            self._expected_total_steps.append(self.max_steps)
+        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
+            self._expected_total_steps.append(
+                self.max_steps_per_epoch * self.max_epochs
+            )
+        if self.max_epochs > 0:
+            self._expected_total_steps.append(
+                len(train_dataloader) * self.max_epochs // self.accumulate_grad_batches
+            )
+        self._expected_total_steps = min(self._expected_total_steps)
+        log.info(f"Expected total steps: {self._expected_total_steps}")
+
+    @property
+    def expected_total_steps(self):
+        """
+        The expected total number of steps of the entire training. You need to run `compute_expected_total_steps` method to compute this value before accessing it.
+
+        Raises:
+            ValueError: If the expected total steps have not been computed.
+        """
+        if self._expected_total_steps is None:
+            raise ValueError(
+                "The expected total steps have not been computed. Run `compute_expected_total_steps` method."
+            )
+        else:
+            return self._expected_total_steps
+
+    def conditional_checkpoint_save(
+        self,
+        stage: Literal["end_of_step", "end_of_epoch", "end_of_training"],
+        *args,
+        **kwargs,
+    ):
+        """
+        Conditionally saves a checkpoint based on the current training stage.
+
+        Args:
+            stage (Literal["end_of_step", "end_of_epoch", "end_of_training"]): The current stage of training.
+        """
+        if stage == "end_of_step":
+            if (
+                self.checkpoint_save_interval == "step"
+                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
+            ):
+                save_path = os.path.join(
+                    self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
+                )
+                self.save_checkpoint(save_path, *args, **kwargs)
+        elif stage == "end_of_epoch":
+            if (
+                self.checkpoint_save_interval == "epoch"
+                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
+            ):
+                save_path = os.path.join(
+                    self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
+                )
+                self.save_checkpoint(save_path, *args, **kwargs)
+        elif stage == "end_of_training":
+            # if the checkpoint has not been saved yet, save it
+            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
+                save_path = os.path.join(
+                    self.log_dir,
+                    "checkpoints",
+                    f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
+                )
+                self.save_checkpoint(save_path, *args, **kwargs)
+                try:
+                    os.symlink(
+                        src=save_path,
+                        dst=os.path.join(
+                            self.log_dir, "checkpoints", "latest_model.ckpt"
+                        ),
+                        target_is_directory=os.path.isdir(save_path),
+                    )
+                except Exception as e:
+                    log.error(f"Failed to create symlink: {e}")
+        else:
+            raise ValueError(
+                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
+            )
+
+    @abstractmethod
+    def save_checkpoint(self, path, **kwargs):
+        """
+        Saves a checkpoint of the model.
+
+        Args:
+            path (str): The path where the checkpoint will be saved.
+
+        Raises:
+            NotImplementedError: If the method is not implemented.
+        """
+        raise NotImplementedError("save_checkpoint method is not implemented")
+
+    def train(
+        self,
+        model: Union[nn.Module, "_FabricModule"],
+        optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
+        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+    ):
+        """
+        Trains the model.
+
+        The global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`.
+
+        Args:
+            model (Union[nn.Module, "_FabricModule"]): The model to be trained.
+            optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
+            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
+        """
+        fabric = self.fabric
+        self.is_training = True
+        # number of parameter update iterations, not the number of batches
+        self.global_step_idx = 0
+        model.train()
+        optimizer.zero_grad()
+        for epoch_idx in tqdm(
+            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
+            "Training Epoch",
+            dynamic_ncols=True,
+            leave=False,
+            disable=not fabric.is_global_zero,
+        ):
+            self.epoch_idx = epoch_idx
+            self.train_epoch(model, optimizer, lr_scheduler)
+            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
+            if (
+                self.lr_scheduler_interval == "epoch"
+                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
+            ):
+                lr_scheduler.step()
+
+            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
+            self.conditional_checkpoint_save(stage="end_of_epoch")
+
+            if not self.is_training:
+                break
+
+        optimizer.zero_grad()
+        # save the model at the end of training
+        self.conditional_checkpoint_save(stage="end_of_training")
+        return model
+
+    @abstractmethod
+    def train_epoch(
+        self,
+        model: Union[nn.Module, "_FabricModule"],
+        optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
+        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+    ):
+        """
+        Trains the model for one epoch.
+
+        Args:
+            model (Union[nn.Module, "_FabricModule"]): The model to be trained.
+            optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
+            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
+
+        Raises:
+            NotImplementedError: If the method is not implemented.
+        """
+        raise NotImplementedError(
+            "Copy this as a template and implement your own train_epoch method"
+        )
+        fabric = self.fabric
+
+        accumulated_loss = 0
+        for step_idx, batch in enumerate(
+            pbar := tqdm(
+                self.train_dataloader,
+                desc="Training Batches",
+                dynamic_ncols=True,
+                leave=False,
+                disable=not fabric.is_global_zero,
+            )
+        ):
+            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
+
+            # disable gradient synchronization if accumulating gradients across steps for improved performance
+            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
+                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
+                output = self.compute_loss(batch)
+                loss = output["loss"] / self.accumulate_grad_batches
+
+                fabric.backward(loss)
+                accumulated_loss += loss.item()
+
+            # 1. update the model parameters if not accumulating gradients
+            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
+            # 3. save the model if interval is set to "step" and frequency is met
+            # 4. log metrics
+            # 5. increase the global step index and reset the accumulated metrics
+            if not is_accumulating:
+                self.clip_gradients_if_needed(model, optimizer)
+
+                # run lr_scheduler at the end of the step if interval is set to "step"
+                if (
+                    self.lr_scheduler_interval == "step"
+                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
+                ):
+                    lr_scheduler.step()
+
+                # update the model parameters and zero the gradients
+                optimizer.step()
+                optimizer.zero_grad()
+
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/lr": optimizer.param_groups[0]["lr"],
+                }
+
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)
+
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")
+
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break
+
+                self.global_step_idx += 1
+                accumulated_loss = 0
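
`FabricTrainingMixin` leaves `save_checkpoint` and `train_epoch` abstract; the code after the `raise NotImplementedError(...)` in `train_epoch` is intentionally unreachable template code to copy from. A minimal sketch of a concrete subclass, not code from the package, assuming a model whose forward returns an object with a `.loss` and illustrative hyperparameter values:

    from torch import nn

    from fusion_bench.mixins.fabric_training import FabricTrainingMixin


    class ToyTrainer(FabricTrainingMixin):
        # hyperparameters the mixin reads as attributes; values are illustrative
        max_epochs = 1
        max_steps = -1
        max_steps_per_epoch = -1
        accumulate_grad_batches = 1
        gradient_clip_val = 1.0
        gradient_clip_algorithm = "norm"
        lr_scheduler_interval = "step"
        lr_scheduler_frequency = 1
        checkpoint_save_interval = "epoch"
        checkpoint_save_frequency = 1

        def __init__(self, model: nn.Module, train_dataloader):
            self.model = model
            self.train_dataloader = train_dataloader

        def save_checkpoint(self, path, **kwargs):
            # Fabric's save handles distributed strategies transparently
            self.fabric.save(path, {"model": self.model})

        def train_epoch(self, model, optimizer, lr_scheduler):
            # simplified epoch loop; see the unreachable template above for the
            # full version with gradient accumulation and metric logging
            for batch in self.train_dataloader:
                loss = model(**batch).loss
                self.fabric.backward(loss)
                self.clip_gradients_if_needed(model, optimizer)
                optimizer.step()
                optimizer.zero_grad()
                if self.lr_scheduler_interval == "step":
                    lr_scheduler.step()
                self.global_step_idx += 1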
fusion_bench/mixins/lightning_fabric.py
@@ -1,9 +1,11 @@
+import functools
 import logging
 import os
 from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
 
 import lightning as L
 import torch
+from lightning.fabric.connector import _is_using_cli
 from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
@@ -13,6 +15,7 @@ from fusion_bench.utils.instantiate import instantiate
 
 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
+    from lightning.fabric.strategies import FSDPStrategy
 
 log = logging.getLogger(__name__)
 
@@ -32,6 +35,13 @@ def get_policy(*args: str) -> set:
     return {import_object(arg) for arg in args}
 
 
+def get_size_based_auto_wrap_policy(*args, **kwargs):
+    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+
+    policy = functools.partial(size_based_auto_wrap_policy, *args, **kwargs)
+    return policy
+
+
 class LightningFabricMixin:
     """
     A mixin class for integrating Lightning Fabric into a project.
@@ -70,7 +80,8 @@ class LightningFabricMixin:
                 self._fabric_instance = L.Fabric()
             else:
                 self._fabric_instance = instantiate(config.fabric)
-            self._fabric_instance.launch()
+            if not _is_using_cli():  # if not using cli, launch the fabric
+                self._fabric_instance.launch()
         # Set the log directory in config if it is not already set
         if (
             self.log_dir is not None
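
`get_size_based_auto_wrap_policy` simply binds keyword arguments onto PyTorch's `size_based_auto_wrap_policy` so that a fully-configured policy can be referenced from a Hydra config such as the new `fabric/strategy/llama_peft_fsdp.yaml`. A rough Python equivalent of what such a config resolves to; the parameter threshold and device count are assumptions, not values from the diff:

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy

    from fusion_bench.mixins.lightning_fabric import get_size_based_auto_wrap_policy

    # wrap every submodule holding more than ~100M parameters in its own FSDP unit
    policy = get_size_based_auto_wrap_policy(min_num_params=100_000_000)
    fabric = Fabric(strategy=FSDPStrategy(auto_wrap_policy=policy), devices=2)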
fusion_bench/modelpool/__init__.py
@@ -16,6 +16,7 @@ _import_structure = {
         "HuggingFaceGPT2ClassificationPool",
         "GPT2ForSequenceClassificationPool",
     ],
+    "seq_classification_lm": ["SeqenceClassificationModelPool"],
 }
 
 
@@ -31,6 +32,7 @@ if TYPE_CHECKING:
     from .nyuv2_modelpool import NYUv2ModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
     from .seq2seq_lm import Seq2SeqLMPool
+    from .seq_classification_lm import SeqenceClassificationModelPool
 
 else:
     sys.modules[__name__] = LazyImporter(
fusion_bench/modelpool/base_pool.py
@@ -147,7 +147,6 @@ class BaseModelPool(BaseYAMLSerializableModel):
             DictConfig: The configuration for the specified model.
         """
         model_config = self._models[model_name]
-        assert isinstance(model_config, DictConfig), "Model config must be a DictConfig"
         if return_copy:
             model_config = deepcopy(model_config)
         return model_config
fusion_bench/modelpool/causal_lm/__init__.py
@@ -1,2 +1,2 @@
 # flake8: noqa F401
-from .causal_lm import CausalLMBackbonePool, CausalLMPool
+from .causal_lm import CausalLMBackbonePool, CausalLMPool, load_peft_causal_lm
fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -3,6 +3,7 @@ import os
 from copy import deepcopy
 from typing import Any, Optional, TypeAlias, Union, cast  # noqa: F401
 
+import peft
 from omegaconf import DictConfig, flag_override
 from torch import nn
 from torch.nn.modules import Module
@@ -23,28 +24,6 @@ log = logging.getLogger(__name__)
 CausalLM: TypeAlias = Union[LlamaForCausalLM, MistralForCausalLM, Any]
 
 
-def config_priority_get(priority_config, general_config, key, default):
-    """
-    Retrieve a configuration value with priority.
-
-    This function retrieves the value associated with `key` from `priority_config` if it exists.
-    If the key is not found in `priority_config`, it retrieves the value from `general_config`.
-    If the key is not found in either configuration, it returns the provided `default` value.
-
-    Args:
-        priority_config (dict): The configuration dictionary with higher priority.
-        general_config (dict): The general configuration dictionary.
-        key (str): The key to look up in the configuration dictionaries.
-        default: The default value to return if the key is not found in either configuration.
-
-    Returns:
-        The value associated with `key` from `priority_config` or `general_config`, or the `default` value if the key is not found.
-    """
-    if key in priority_config:
-        return priority_config[key]
-    return general_config.get(key, default)
-
-
 class CausalLMPool(BaseModelPool):
     _config_mapping = BaseModelPool._config_mapping | {
         "_tokenizer": "tokenizer",
@@ -138,3 +117,23 @@ class CausalLMBackbonePool(CausalLMPool):
             model_name_or_config, *args, **kwargs
         )
         return model.model.layers
+
+
+def load_peft_causal_lm(
+    base_model_path: str,
+    peft_model_path: str,
+    torch_dtype: str = "bfloat16",
+    is_trainable: bool = True,
+    merge_and_unload: bool = False,
+):
+    base_model = LlamaForCausalLM.from_pretrained(
+        base_model_path, torch_dtype=torch_dtype
+    )
+    model = peft.PeftModel.from_pretrained(
+        base_model,
+        peft_model_path,
+        is_trainable=is_trainable,
+    )
+    if merge_and_unload:
+        model = model.merge_and_unload()
+    return model
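
A usage sketch for the new `load_peft_causal_lm` helper; the model paths below are placeholders, not values from the diff:

    from fusion_bench.modelpool.causal_lm import load_peft_causal_lm

    # load a LoRA adapter on top of its Llama base model and fold the adapter
    # weights into the base weights for inference
    model = load_peft_causal_lm(
        base_model_path="meta-llama/Llama-2-7b-hf",  # placeholder path
        peft_model_path="path/to/lora_adapter",      # placeholder path
        torch_dtype="bfloat16",
        is_trainable=False,
        merge_and_unload=True,
    )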