fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -49,7 +49,7 @@ class MinNormSolver:
         return gamma, cost
 
     def _min_norm_2d(vecs, dps):
-        """
+        R"""
         Find the minimum norm solution as combination of two points
         This is correct only in 2D
         ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
@@ -85,7 +85,7 @@ class MinNormSolver:
         return sol, dps
 
     def _projection2simplex(y):
-        """
+        R"""
         Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
         """
         m = len(y)
@@ -117,7 +117,7 @@ class MinNormSolver:
         return next_point
 
     def find_min_norm_element(vecs):
-        """
+        R"""
         Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
         as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
         It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
@@ -163,7 +163,7 @@ class MinNormSolver:
         sol_vec = new_sol_vec
 
     def find_min_norm_element_FW(vecs):
-        """
+        R"""
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
@@ -41,11 +41,10 @@ from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 from fusion_bench import print_parameters
 from fusion_bench.compat.method import ModelFusionAlgorithm
 from fusion_bench.compat.modelpool import to_modelpool
-from fusion_bench.compat.modelpool.huggingface_clip_vision import (
-    HuggingFaceClipVisionPool,
-)
+from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import CLIPClassificationMixin
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.models.linearized.linearized_model_utils import LinearizedModelWraper
 from fusion_bench.utils.data import InfiniteDataLoader
@@ -92,12 +91,12 @@ class ImageClassificationFineTuningForCLIP(
     A class for fine-tuning CLIP models for image classification tasks.
     """
 
-    def run(self, modelpool: HuggingFaceClipVisionPool):
+    def run(self, modelpool: CLIPVisionModelPool):
         """
         Executes the fine-tuning process.
 
         Args:
-            modelpool (HuggingFaceClipVisionPool): The modelpool is responsible for loading the pre-trained model and training datasets.
+            modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.
 
         Returns:
             VisionModel: The fine-tuned vision model.
@@ -109,9 +108,7 @@ class ImageClassificationFineTuningForCLIP(
 
         L.seed_everything(config.seed)
 
-        task_names = [
-            dataset_config["name"] for dataset_config in modelpool.config.train_datasets
-        ]
+        task_names = modelpool.train_dataset_names
         with self.profile("setup model and optimizer"):
             processor, classifier, optimizer, lr_scheduler = self.setup_model()
 
@@ -133,7 +130,7 @@
 
         with self.profile("setup data"):
             train_datasets = [
-                modelpool.get_train_dataset(task_name, processor)
+                CLIPDataset(modelpool.load_train_dataset(task_name), processor)
                 for task_name in task_names
             ]
             train_dataloaders = [
@@ -157,6 +154,7 @@
             range(config.num_steps),
             desc=self.finetune_method,
             disable=not self.fabric.is_global_zero,
+            dynamic_ncols=True,
         ):
             optimizer.zero_grad()
             loss = 0
@@ -183,7 +181,7 @@
                 save_path = os.path.join(
                     self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
                 )
-                self.save_model(classifier, save_path, trainable_only=True)
+                self.save_model(classifier, save_path)
 
         if config.state_dict_save_path is not None:
             self.save_model(
@@ -232,9 +230,8 @@
         config = self.config
         modelpool = self.modelpool
 
-        pretrained_model_config = modelpool.get_model_config("_pretrained_")
-        clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_model_config.path)
-        processor = CLIPProcessor.from_pretrained(pretrained_model_config.path)
+        clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
+        processor = modelpool.load_processor()
 
         self.finetune_method = "full fine-tune"
         if config.use_lora or config.use_l_lora:
@@ -6,6 +6,10 @@ Reference:
 """
 
 import logging
+from copy import deepcopy
+
+import torch
+from torch import nn
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.method import SimpleAverageAlgorithm
@@ -18,6 +22,41 @@ from fusion_bench.utils.state_dict_arithmetic import (
 log = logging.getLogger(__name__)
 
 
+def expo_merge(
+    sft_model: nn.Module,
+    rlhf_model: nn.Module,
+    extrapolation_factor: float,
+    inplace: bool = True,
+    enable_grad: bool = False,
+):
+    """
+    Minimal implementation of ExPO merge.
+
+    Args:
+        sft_model (nn.Module): The pretrained model (base model).
+        rlhf_model (nn.Module): The finetuned model (medium-aligned model).
+        extrapolation_factor (float): The extrapolation factor.
+        inplace (bool): Whether to perform the merge in-place. If not, the rlhf_model will be copied before merging.
+        enable_grad (bool): Whether to enable gradient computation during the merge.
+
+    Returns:
+        nn.Module: The merged model.
+    """
+
+    if not inplace:
+        rlhf_model = deepcopy(rlhf_model)
+
+    with torch.set_grad_enabled(enable_grad):
+        for (sft_name, sft_param), (rlhf_name, rlhf_param) in zip(
+            sft_model.named_parameters(), rlhf_model.named_parameters()
+        ):
+            assert sft_name == rlhf_name, f"Model mismatch: {sft_name} != {rlhf_name}"
+            rlhf_param.data = rlhf_param.data + extrapolation_factor * (
+                rlhf_param.data - sft_param.data
+            )
+    return rlhf_model
+
+
 class ExPOAlgorithm(BaseAlgorithm):
     R"""
     ExPO merge algorithm.
@@ -1,2 +1,3 @@
+from .bradley_terry_rm import BradleyTerryRewardModeling
 from .fullfinetune_sft import FullFinetuneSFT
 from .peftfinetune_sft import PeftFinetuneSFT
@@ -0,0 +1,432 @@
+R"""
+This is basically the same as fullfinetune_sft.py, but with a different loss function.
+
+The dataset contains the following fields:
+
+- chosen_input_ids: The input token ids for the winner.
+- chosen_attention_mask: The attention mask for the winner.
+- rejected_input_ids: The input token ids for the loser.
+- rejected_attention_mask: The attention mask for the loser.
+
+"""
+
+import functools
+import itertools
+import logging
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, override
+
+import lightning as L
+import omegaconf
+import torch
+from lightning.fabric.strategies.fsdp import FSDPStrategy
+from lightning.fabric.utilities import rank_zero_only
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.utils.data import ConcatDataset, DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import FabricTrainingMixin
+from fusion_bench.modelpool import SeqenceClassificationModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.dtype import get_dtype
+
+if TYPE_CHECKING:
+    from lightning.fabric.wrappers import (
+        _FabricDataLoader,
+        _FabricModule,
+        _FabricOptimizer,
+    )
+    from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
+
+log = logging.getLogger(__name__)
+
+
+class BradleyTerryRewardModeling(BaseAlgorithm, FabricTrainingMixin):
+
+    model: Union[nn.Module, "_FabricModule", "LlamaForSequenceClassification"]
+    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
+    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
+    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
+
+    def __init__(
+        self,
+        optimizer: DictConfig,
+        lr_scheduler: Optional[DictConfig],
+        dataloader_kwargs: DictConfig,
+        max_epochs: int,
+        max_steps: int = -1,
+        max_steps_per_epoch: int = -1,
+        lr_scheduler_interval: Literal["epoch", "step"] = "step",
+        lr_scheduler_frequency: int = 1,
+        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
+        checkpoint_save_frequency: int = 1,
+        accumulate_grad_batches: int = 1,
+        gradient_clip_val: Optional[float] = None,
+        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
+        save_optimizer_state: bool = False,
+        save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
+        ckpt_path: Optional[str] = None,
+        max_length: int = 6144,
+        fix_token_embedding: bool = True,
+        **kwargs,
+    ):
+        """
+        Class for reward modeling using Bradley-Terry model.
+
+        Args:
+            optimizer(DictConfig): Configuration for the optimizer.
+            lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
+            dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
+            max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
+            max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
+            max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
+            lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
+            lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
+            checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
+            checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
+            accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
+            gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
+            gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
+            save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
+            save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
+            ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
+            max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
+            fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
+        """
+        self._optimizer = optimizer
+        self._lr_scheduler = lr_scheduler
+        self.dataloader_kwargs = dataloader_kwargs
+        self.max_epochs = max_epochs
+        self.max_steps = max_steps
+        self.max_steps_per_epoch = max_steps_per_epoch
+        self.lr_scheduler_interval = lr_scheduler_interval
+        self.lr_scheduler_frequency = lr_scheduler_frequency
+        self.checkpoint_save_interval = checkpoint_save_interval
+        self.checkpoint_save_frequency = checkpoint_save_frequency
+        self.accumulate_grad_batches = accumulate_grad_batches
+        self.gradient_clip_val = gradient_clip_val
+        self.gradient_clip_algorithm = gradient_clip_algorithm
+        self.save_optimizer_state = save_optimizer_state
+        self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
+        self.ckpt_path = ckpt_path
+        self.max_length = max_length
+        self.fix_token_embedding = fix_token_embedding
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: SeqenceClassificationModelPool):
+        self.modelpool = modelpool
+        self.setup()
+        self.train(self.model, self.optimizer, self.lr_scheduler)
+        return self.model
+
+    def setup_model(self):
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = (
+                self.tokenizer.eos_token_id
+            )  #! make sure eos_token_id only show up at the end of the sequence
+
+        model = self.modelpool.load_pretrained_model()
+        self.model: "LlamaForSequenceClassification" = model
+
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        if self.fix_token_embedding:
+            self.model.model.embed_tokens.requires_grad_(False)
+
+        if self.fabric.strategy == "fsdp" or isinstance(
+            self.fabric.strategy, FSDPStrategy
+        ):
+            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
+            self.model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs={"use_reentrant": True}
+            )
+            self.use_cache = False
+        else:
+            self.use_cache = True
+        self.model_dtype = get_dtype(self.model)
+
+    def setup_data(self):
+        fabric = self.fabric
+        modelpool = self.modelpool
+        assert (
+            len(modelpool.train_dataset_names) > 0
+        ), "No training datasets found in modelpool."
+
+        train_datasets = [
+            modelpool.load_train_dataset(dataset_name)
+            for dataset_name in modelpool.train_dataset_names
+        ]
+        if len(train_datasets) > 1:
+            train_dataset = ConcatDataset(train_datasets)
+        else:
+            train_dataset = train_datasets[0]
+
+        self.train_dataset = train_dataset
+        self.train_dataloader = DataLoader(
+            train_dataset,
+            **self.dataloader_kwargs,
+            shuffle=True,
+            collate_fn=functools.partial(
+                bradley_terry_rm_collate,
+                pad_token_id=self.tokenizer.pad_token_id,
+            ),  # NOTE: different from SFT, uses bradley_terry_rm_collate
+        )
+        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)
+
+    def configure_optimizer(self):
+        # compute expected total steps
+        self.compute_expected_total_steps(self.train_dataloader)
+
+        optimizer = instantiate(self._optimizer, self.model.parameters())
+        if self._lr_scheduler is not None:
+            for key, arg in self._lr_scheduler.items():
+                if arg == "_T_max_":
+                    log.info(
+                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
+                    )
+                    self._lr_scheduler[key] = self.expected_total_steps
+            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
+                self._lr_scheduler,
+                optimizer=optimizer,
+            )
+        else:
+            lr_scheduler = None
+        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+    def setup(self):
+        fabric = self.fabric
+
+        self.setup_model()
+        self.setup_data()
+
+        optimizer = self.configure_optimizer()
+        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]
+
+        self.model, self.optimizer = fabric.setup(self.model, optimizer)
+        self.lr_scheduler = lr_scheduler
+
+    def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor]:
+        """
+        Maximize the likelihood of the winner over the loser using the Bradley-Terry model.
+
+        Args:
+            batch (Dict[str, Union[Tensor, Any]]): A dictionary containing the input token ids and attention masks for the winner and loser.
+        """
+        batch_size = batch["input_ids"].size(0)
+        assert batch_size % 2 == 0, "Batch size must be even."
+
+        outputs = self.model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+            use_cache=self.use_cache,
+        )
+
+        rewards = outputs[0]
+        chosen_reward = rewards[: batch_size // 2]
+        rejected_rewards = rewards[batch_size // 2 :]
+        loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()
+
+        return {
+            "chosen_reward": chosen_reward,
+            "rejected_reward": rejected_rewards,
+            "loss": loss,
+        }
+
+    @override
+    def train_epoch(self, *args, **kwargs):
+        fabric = self.fabric
+
+        accumulated_loss = 0
+        accumulated_chosen_reward = 0
+        accumulated_rejected_reward = 0
+        for step_idx, batch in enumerate(
+            pbar := tqdm(
+                self.train_dataloader,
+                desc="Training Batches",
+                dynamic_ncols=True,
+                leave=False,
+                disable=not fabric.is_global_zero,
+            )
+        ):
+            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
+
+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, -self.max_length :]
+                batch["attention_mask"] = batch["attention_mask"][:, -self.max_length :]
+
+            # disable gradient synchronization if accumulating gradients across steps for improved performance
+            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
+                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
+                output = self.compute_loss(batch)
+                loss = output["loss"] / self.accumulate_grad_batches
+
+                fabric.backward(loss)
+
+                accumulated_loss += loss.item()
+                accumulated_chosen_reward += output["chosen_reward"].mean().item()
+                accumulated_rejected_reward += output["rejected_reward"].mean().item()
+
+            # 1. update the model parameters if not accumulating gradients
+            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
+            # 3. save the model if interval is set to "step" and frequency is met
+            # 4. log metrics
+            # 5. increase the global step index
+            if not is_accumulating:
+                self.clip_gradients_if_needed(self.model, self.optimizer)
+
+                # run lr_scheduler at the end of the step if interval is set to "step"
+                if (
+                    self.lr_scheduler_interval == "step"
+                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
+                ):
+                    self.lr_scheduler.step()
+
+                # update the model parameters and zero the gradients
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/chosen_reward": accumulated_chosen_reward
+                    / self.accumulate_grad_batches,
+                    "train/rejected_reward": accumulated_rejected_reward
+                    / self.accumulate_grad_batches,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                metrics["train/chosen_reward-rejected_reward"] = (
+                    metrics["train/chosen_reward"] - metrics["train/rejected_reward"]
+                )
+
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)
+
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")
+
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break
+
+                self.global_step_idx += 1
+                accumulated_loss = 0
+                accumulated_chosen_reward = 0
+                accumulated_rejected_reward = 0
+
+    def save_checkpoint(
+        self,
+        path: Union[str, Path],
+        save_optimizer_state: Optional[bool] = None,
+        overwrite: bool = False,
+    ):
+        if not overwrite and os.path.exists(path):
+            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")
+
+        fabric = self.fabric
+
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
+            )
+
+            fabric.save(path, state=state, filter=filter)
+        else:
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
+
+        self._latest_saved_checkpoint_global_step = self.global_step_idx
+
+    def load_checkpoint(self, path: Union[str, Path]):
+        fabric = self.fabric
+
+        state = {"model": self.model}
+
+        # save the optimizer and lr_scheduler state if needed
+        if self.save_optimizer_state:
+            state.update(
+                {
+                    "optimizer": self.optimizer,
+                    "lr_scheduler": self.lr_scheduler,
+                }
+            )
+
+        fabric.load(path, state)
+
+
+def load_checkpoint(
+    fabric: L.Fabric,
+    ckpt_path: Union[str, Path],
+    model: Union[nn.Module, "LlamaForSequenceClassification"],
+    strict: bool = True,
+    **state_components,
+):
+    """
+    Load a checkpoint into a model.
+    """
+    state = {"model": model}
+    state.update(state_components)
+    fabric.load(ckpt_path, state=state, strict=strict)
+
+
+if __name__ == "__main__":
+    # convert a checkpoint to hf format
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-model-path", type=str)
+    parser.add_argument("--ckpt-path", type=str)
+    parser.add_argument("--output-path", type=str)
+
+    args = parser.parse_args()
+
+    fabric = L.Fabric(devices=1, strategy="fsdp")
+    fabric.launch()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
+    tokenizer.save_pretrained(args.output_path)
+
+    model = AutoModelForSequenceClassification.from_pretrained(
+        args.base_model_path, num_labels=1, torch_dtype=torch.bfloat16
+    )
+    model = fabric.setup_module(model)
+    load_checkpoint(fabric, args.ckpt_path, model=model, strict=True)
+    model.save_pretrained(args.output_path)