fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (264) hide show
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/clip_dataset.py +3 -0
  8. fusion_bench/dataset/fer2013.py +12 -0
  9. fusion_bench/dataset/llama/__init__.py +1 -0
  10. fusion_bench/dataset/llama/alpaca.py +93 -3
  11. fusion_bench/dataset/llama/collate.py +62 -2
  12. fusion_bench/dataset/llama/metamathqa.py +50 -0
  13. fusion_bench/dataset/llama/preference_700k.py +70 -0
  14. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  15. fusion_bench/dataset/llama/ultrachat.py +58 -0
  16. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  17. fusion_bench/method/__init__.py +3 -1
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  19. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  20. fusion_bench/method/classification/clip_finetune.py +10 -13
  21. fusion_bench/method/linear/expo.py +39 -0
  22. fusion_bench/method/lm_finetune/__init__.py +1 -0
  23. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  24. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  25. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  26. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  27. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  28. fusion_bench/method/surgery/__init__.py +1 -0
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  30. fusion_bench/method/tall_mask/__init__.py +0 -0
  31. fusion_bench/method/tall_mask/utils.py +234 -0
  32. fusion_bench/method/task_singular_vector/TSVC.py +16 -0
  33. fusion_bench/method/task_singular_vector/TSVM.py +63 -0
  34. fusion_bench/method/task_singular_vector/__init__.py +9 -0
  35. fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
  36. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
  37. fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
  38. fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
  39. fusion_bench/mixins/__init__.py +2 -0
  40. fusion_bench/mixins/clip_classification.py +64 -11
  41. fusion_bench/mixins/fabric_training.py +320 -0
  42. fusion_bench/mixins/lightning_fabric.py +12 -1
  43. fusion_bench/modelpool/__init__.py +2 -0
  44. fusion_bench/modelpool/base_pool.py +0 -1
  45. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  46. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  47. fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
  48. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  49. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  50. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  51. fusion_bench/models/chat_templates/__init__.py +1 -0
  52. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  53. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  54. fusion_bench/models/hf_clip.py +50 -9
  55. fusion_bench/models/surgery/__init__.py +1 -0
  56. fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
  57. fusion_bench/models/utils.py +8 -0
  58. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  59. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  60. fusion_bench/optim/__init__.py +2 -0
  61. fusion_bench/optim/exception.py +47 -0
  62. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  63. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  64. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  65. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  66. fusion_bench/optim/mezo.py +0 -2
  67. fusion_bench/programs/fabric_fusion_program.py +12 -5
  68. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  69. fusion_bench/taskpool/llama/reward_model.py +157 -0
  70. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  71. fusion_bench/tasks/clip_classification/__init__.py +13 -45
  72. fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
  73. fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
  74. fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
  75. fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
  76. fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
  77. fusion_bench/tasks/clip_classification/fer2013.py +18 -0
  78. fusion_bench/tasks/clip_classification/food101.py +105 -0
  79. fusion_bench/tasks/clip_classification/kmnist.py +17 -0
  80. fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
  81. fusion_bench/tasks/clip_classification/pcam.py +5 -0
  82. fusion_bench/utils/hydra_utils.py +22 -0
  83. fusion_bench/utils/parameters.py +12 -3
  84. fusion_bench/utils/plot/__init__.py +0 -0
  85. fusion_bench/utils/plot/token.py +52 -0
  86. fusion_bench/utils/plot/token_notebook.py +127 -0
  87. fusion_bench/utils/type.py +14 -3
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
  89. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
  90. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  91. fusion_bench_config/dataset/image_classification/README.md +6 -0
  92. fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
  93. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
  94. fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
  95. fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
  96. fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
  97. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
  98. fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
  99. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
  100. fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
  101. fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
  102. fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
  103. fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
  104. fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
  105. fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
  106. fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
  107. fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
  108. fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
  109. fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
  110. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
  111. fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
  112. fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
  113. fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
  114. fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
  115. fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
  116. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
  117. fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
  118. fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
  119. fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
  120. fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
  121. fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
  122. fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
  123. fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
  124. fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
  125. fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
  126. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  127. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  128. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  129. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  130. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  131. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  132. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  133. fusion_bench_config/llama_full_finetune.yaml +19 -0
  134. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  135. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  136. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  137. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
  139. fusion_bench_config/model/clip-vit/README.md +38 -0
  140. fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
  141. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
  143. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
  144. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
  145. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
  146. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
  147. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
  148. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
  149. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
  150. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
  151. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
  153. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
  154. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
  155. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
  156. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
  157. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
  158. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
  159. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
  160. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
  161. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
  162. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
  163. fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
  164. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
  165. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
  166. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
  167. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
  168. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
  169. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
  170. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
  171. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
  172. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
  173. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
  174. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
  175. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
  176. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
  177. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
  178. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
  179. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
  180. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
  181. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
  182. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
  183. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
  184. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
  185. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
  186. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
  187. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
  188. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
  189. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
  190. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
  191. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
  192. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
  193. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
  194. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
  195. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
  196. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
  197. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
  198. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
  199. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
  200. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
  201. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
  202. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
  203. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
  204. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
  205. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
  206. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
  207. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
  208. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
  209. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
  210. fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
  211. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
  212. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
  213. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
  214. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
  215. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
  216. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
  217. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
  218. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
  219. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
  220. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
  221. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
  222. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
  223. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
  224. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
  225. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
  226. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
  227. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
  228. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  229. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  230. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  231. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  232. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  233. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  234. fusion_bench_config/nyuv2_config.yaml +5 -1
  235. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
  236. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
  237. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
  238. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
  239. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
  240. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
  241. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
  242. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
  243. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
  244. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
  245. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
  246. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
  247. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
  248. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
  249. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
  250. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
  251. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
  252. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
  253. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
  254. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
  255. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
  256. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
  257. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
  258. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
  259. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  260. fusion_bench_config/llama_weighted_average.yaml +0 -26
  261. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
  262. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
  263. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
  264. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import functools
1
2
  import itertools
2
3
  import logging
3
4
  import os
@@ -18,7 +19,7 @@ from typing_extensions import TYPE_CHECKING, override
18
19
 
19
20
  from fusion_bench import BaseAlgorithm, BaseModelPool
20
21
  from fusion_bench.dataset.llama.collate import padded_collate_sft
21
- from fusion_bench.mixins import LightningFabricMixin
22
+ from fusion_bench.mixins import FabricTrainingMixin
22
23
  from fusion_bench.modelpool import CausalLMPool
23
24
  from fusion_bench.utils import instantiate
24
25
  from fusion_bench.utils.dtype import get_dtype
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
34
35
  log = logging.getLogger(__name__)
35
36
 
36
37
 
37
- class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
38
+ class FullFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):
38
39
 
39
40
  model: Union[nn.Module, "_FabricModule", "LlamaForCausalLM"]
40
41
  optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
@@ -59,7 +60,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
59
60
  gradient_clip_algorithm: Literal["value", "norm"] = "norm",
60
61
  save_optimizer_state: bool = False,
61
62
  save_full_model: bool = False,
63
+ save_ckpt_type: Literal["lightning", "hf"] = "lightning",
62
64
  ckpt_path: Optional[str] = None,
65
+ max_length: int = 6144,
66
+ fix_token_embedding: bool = True,
63
67
  **kwargs,
64
68
  ):
65
69
  """
@@ -81,7 +85,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
81
85
  gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
82
86
  save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
83
87
  save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
88
+ save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
84
89
  ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
90
+ max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
91
+ fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
85
92
  """
86
93
  self._optimizer = optimizer
87
94
  self._lr_scheduler = lr_scheduler
@@ -98,18 +105,28 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
98
105
  self.gradient_clip_algorithm = gradient_clip_algorithm
99
106
  self.save_optimizer_state = save_optimizer_state
100
107
  self.save_full_model = save_full_model
108
+ self.save_ckpt_type = save_ckpt_type
101
109
  self.ckpt_path = ckpt_path
110
+ self.max_length = max_length
111
+ self.fix_token_embedding = fix_token_embedding
102
112
  super().__init__(**kwargs)
103
113
 
104
114
  def run(self, modelpool: CausalLMPool):
105
115
  self.modelpool = modelpool
106
116
  self.setup()
107
- self.train()
117
+ self.train(self.model, self.optimizer, self.lr_scheduler)
108
118
  return self.model
109
119
 
110
120
  def setup_model(self):
121
+ self.tokenizer = self.modelpool.load_tokenizer()
122
+ if self.tokenizer.pad_token_id is None:
123
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
124
+
111
125
  model = self.modelpool.load_pretrained_model()
112
- self.model = model
126
+ self.model: "LlamaForCausalLM" = model
127
+
128
+ if self.fix_token_embedding:
129
+ self.model.model.embed_tokens.requires_grad_(False)
113
130
 
114
131
  if self.fabric.strategy == "fsdp" or isinstance(
115
132
  self.fabric.strategy, FSDPStrategy
@@ -125,17 +142,7 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
125
142
 
126
143
  def configure_optimizer(self):
127
144
  # compute expected total steps
128
- self.expected_total_steps = []
129
- if self.max_steps > 0:
130
- self.expected_total_steps.append(self.max_steps)
131
- if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
132
- self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
133
- if self.max_epochs > 0:
134
- self.expected_total_steps.append(
135
- len(self.train_dataloader) * self.max_epochs
136
- )
137
- self.expected_total_steps = min(self.expected_total_steps)
138
- log.info(f"Expected total steps: {self.expected_total_steps}")
145
+ self.compute_expected_total_steps(self.train_dataloader)
139
146
 
140
147
  optimizer = instantiate(self._optimizer, self.model.parameters())
141
148
  if self._lr_scheduler is not None:
@@ -174,7 +181,9 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
174
181
  train_dataset,
175
182
  **self.dataloader_kwargs,
176
183
  shuffle=True,
177
- collate_fn=padded_collate_sft,
184
+ collate_fn=functools.partial(
185
+ padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
186
+ ),
178
187
  )
179
188
  self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)
180
189
 
@@ -190,25 +199,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
190
199
  self.model, self.optimizer = fabric.setup(self.model, optimizer)
191
200
  self.lr_scheduler = lr_scheduler
192
201
 
193
- def _clip_gradients_if_needed(self):
202
+ @override
203
+ def train_epoch(self, *args, **kwargs):
194
204
  fabric = self.fabric
195
205
 
196
- if self.gradient_clip_val is not None:
197
- if self.gradient_clip_algorithm == "value":
198
- fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
199
- elif self.gradient_clip_algorithm == "norm":
200
- fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
201
- else:
202
- raise ValueError(
203
- f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
204
- )
205
-
206
- def train_epoch(self):
207
- fabric = self.fabric
206
+ accumulated_loss = 0
208
207
  for step_idx, batch in enumerate(
209
208
  pbar := tqdm(
210
209
  self.train_dataloader,
211
- desc="Training Steps",
210
+ desc="Training Batches",
212
211
  dynamic_ncols=True,
213
212
  leave=False,
214
213
  disable=not fabric.is_global_zero,
@@ -216,6 +215,14 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
216
215
  ):
217
216
  is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
218
217
 
218
+ if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
219
+ log.warning(
220
+ f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
221
+ )
222
+ batch["input_ids"] = batch["input_ids"][:, : self.max_length]
223
+ batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
224
+ batch["labels"] = batch["labels"][:, : self.max_length]
225
+
219
226
  # disable gradient synchronization if accumulating gradients across steps for improved performance
220
227
  with fabric.no_backward_sync(self.model, enabled=is_accumulating):
221
228
  # use_cache=True is not compatible with gradient checkpointing, so we disable it here
@@ -225,20 +232,13 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
225
232
  labels=batch["labels"],
226
233
  use_cache=self.use_cache,
227
234
  )
228
- loss = output["loss"]
235
+ loss = output["loss"] / self.accumulate_grad_batches
229
236
 
230
237
  fabric.backward(loss)
231
-
232
- metrics = {
233
- "train/loss": loss.item(),
234
- "train/epoch_idx": self.epoch_idx,
235
- "train/lr": self.optimizer.param_groups[0]["lr"],
236
- }
237
- fabric.log_dict(metrics, step=self.global_step_idx)
238
- pbar.set_postfix(metrics)
238
+ accumulated_loss += loss.item()
239
239
 
240
240
  if not is_accumulating:
241
- self._clip_gradients_if_needed()
241
+ self.clip_gradients_if_needed(self.model, self.optimizer)
242
242
 
243
243
  # run lr_scheduler at the end of the step if interval is set to "step"
244
244
  if (
@@ -251,105 +251,30 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
251
251
  self.optimizer.step()
252
252
  self.optimizer.zero_grad()
253
253
 
254
- # save the model at the end of the step if interval is set to "step" and frequency is met
255
- self._try_save_checkpoint(stage="end_of_step")
254
+ metrics = {
255
+ "train/loss": accumulated_loss,
256
+ "train/epoch_idx": self.epoch_idx,
257
+ "train/lr": self.optimizer.param_groups[0]["lr"],
258
+ }
259
+ fabric.log_dict(metrics, step=self.global_step_idx)
260
+ pbar.set_postfix(metrics)
256
261
 
257
- # break if max_steps_per_epoch is set, and exit epoch
258
- if (
259
- self.max_steps_per_epoch > 0
260
- and step_idx + 1 >= self.max_steps_per_epoch
261
- ):
262
- break
263
- # break if max_steps is set, and exit training
264
- if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
265
- self.is_training = False
266
- break
262
+ # save the model at the end of the step if interval is set to "step" and frequency is met
263
+ self.conditional_checkpoint_save(stage="end_of_step")
267
264
 
268
- self.global_step_idx += 1
265
+ # break if max_steps_per_epoch is set, and exit epoch
266
+ if (
267
+ self.max_steps_per_epoch > 0
268
+ and step_idx + 1 >= self.max_steps_per_epoch
269
+ ):
270
+ break
271
+ # break if max_steps is set, and exit training
272
+ if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
273
+ self.is_training = False
274
+ break
269
275
 
270
- def train(self):
271
- fabric = self.fabric
272
- self.is_training = True
273
- self.global_step_idx = 0
274
- self.model.train()
275
- for epoch_idx in tqdm(
276
- range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
277
- "Training Epoch",
278
- dynamic_ncols=True,
279
- leave=False,
280
- disable=not fabric.is_global_zero,
281
- ):
282
- self.epoch_idx = epoch_idx
283
- self.train_epoch()
284
- # run lr_scheduler at the end of the epoch if interval is set to "epoch"
285
- if (
286
- self.lr_scheduler_interval == "epoch"
287
- and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
288
- ):
289
- self.lr_scheduler.step()
290
-
291
- # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
292
- self._try_save_checkpoint(stage="end_of_epoch")
293
-
294
- if not self.is_training:
295
- break
296
-
297
- # save the model at the end of training
298
- self._try_save_checkpoint(stage="end_of_training")
299
-
300
- def _try_save_checkpoint(
301
- self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
302
- ):
303
- if stage == "end_of_step":
304
- if (
305
- self.checkpoint_save_interval == "step"
306
- and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
307
- ):
308
- self.save_checkpoint(
309
- os.path.join(
310
- self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
311
- )
312
- )
313
- elif stage == "end_of_epoch":
314
- if (
315
- self.checkpoint_save_interval == "epoch"
316
- and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
317
- ):
318
- self.save_checkpoint(
319
- os.path.join(
320
- self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
321
- )
322
- )
323
- elif stage == "end_of_training":
324
- # if the checkpoint has not been saved yet, save it
325
- if self.global_step_idx > self._latest_saved_checkpoint_global_step:
326
- self.save_checkpoint(
327
- os.path.join(
328
- self.log_dir,
329
- "checkpoints",
330
- f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
331
- )
332
- )
333
- try:
334
- os.symlink(
335
- os.path.join(
336
- self.log_dir,
337
- "checkpoints",
338
- "latest_model.ckpt",
339
- ),
340
- dst := os.path.join(
341
- self.log_dir,
342
- "checkpoints",
343
- f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
344
- ),
345
- target_is_directory=os.path.isdir(dst),
346
- )
347
- except Exception as e:
348
- log.error(f"Failed to create symlink: {e}")
349
- else:
350
- raise ValueError(
351
- f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
352
- )
276
+ self.global_step_idx += 1
277
+ accumulated_loss = 0
353
278
 
354
279
  def save_checkpoint(
355
280
  self,
@@ -361,31 +286,36 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
361
286
  return log.warning(f"Checkpoint already exists at {path}. Skipping save.")
362
287
 
363
288
  fabric = self.fabric
364
- state = {"model": self.model}
365
289
 
366
- # save the optimizer and lr_scheduler state if needed
367
- if self.save_optimizer_state and save_optimizer_state is not False:
368
- state.update(
369
- {
370
- "optimizer": self.optimizer,
371
- "lr_scheduler": self.lr_scheduler,
372
- "global_step_idx": self.global_step_idx,
373
- "epoch_idx": self.epoch_idx,
374
- }
290
+ if self.save_ckpt_type == "lightning":
291
+ state = {"model": self.model}
292
+
293
+ # save the optimizer and lr_scheduler state if needed
294
+ if self.save_optimizer_state and save_optimizer_state is not False:
295
+ state.update(
296
+ {
297
+ "optimizer": self.optimizer,
298
+ "lr_scheduler": self.lr_scheduler,
299
+ "global_step_idx": self.global_step_idx,
300
+ "epoch_idx": self.epoch_idx,
301
+ }
302
+ )
303
+
304
+ trainable_param_names = set(
305
+ name
306
+ for name, param in self.model.state_dict(keep_vars=True).items()
307
+ if param.requires_grad
308
+ )
309
+ filter = (
310
+ None
311
+ if self.save_full_model
312
+ else {"model": lambda k, p: k in trainable_param_names}
375
313
  )
376
314
 
377
- trainable_param_names = set(
378
- name
379
- for name, param in self.model.state_dict(keep_vars=True).items()
380
- if param.requires_grad
381
- )
382
- filter = (
383
- None
384
- if self.save_full_model
385
- else {"model": lambda k, p: k in trainable_param_names}
386
- )
315
+ fabric.save(path, state=state, filter=filter)
316
+ else:
317
+ self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
387
318
 
388
- fabric.save(path, state=state, filter=filter)
389
319
  self._latest_saved_checkpoint_global_step = self.global_step_idx
390
320
 
391
321
  def load_checkpoint(self, path: Union[str, Path]):
@@ -425,9 +355,9 @@ if __name__ == "__main__":
425
355
  import argparse
426
356
 
427
357
  parser = argparse.ArgumentParser()
428
- parser.add_argument("--base_model_path", type=str)
429
- parser.add_argument("--ckpt_path", type=str)
430
- parser.add_argument("--output_path", type=str)
358
+ parser.add_argument("--base-model-path", type=str)
359
+ parser.add_argument("--ckpt-path", type=str)
360
+ parser.add_argument("--output-path", type=str)
431
361
 
432
362
  args = parser.parse_args()
433
363
 
@@ -1,3 +1,4 @@
1
+ import functools
1
2
  import itertools
2
3
  import logging
3
4
  import os
@@ -11,7 +12,7 @@ import torch
11
12
  from lightning.fabric.strategies.fsdp import FSDPStrategy
12
13
  from lightning.fabric.utilities import rank_zero_only
13
14
  from omegaconf import DictConfig, OmegaConf
14
- from peft import PeftModel, get_peft_config, get_peft_model
15
+ from peft import LoraConfig, PeftModel, get_peft_config, get_peft_model
15
16
  from torch import nn
16
17
  from torch.utils.data import ConcatDataset, DataLoader
17
18
  from tqdm.auto import tqdm
@@ -19,7 +20,7 @@ from typing_extensions import TYPE_CHECKING, override
19
20
 
20
21
  from fusion_bench import BaseAlgorithm, BaseModelPool
21
22
  from fusion_bench.dataset.llama.collate import padded_collate_sft
22
- from fusion_bench.mixins import LightningFabricMixin
23
+ from fusion_bench.mixins import FabricTrainingMixin
23
24
  from fusion_bench.modelpool import CausalLMPool
24
25
  from fusion_bench.utils import instantiate
25
26
  from fusion_bench.utils.dtype import get_dtype
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
35
36
  log = logging.getLogger(__name__)
36
37
 
37
38
 
38
- class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
39
+ class PeftFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):
39
40
 
40
41
  model: Union[
41
42
  nn.Module, "_FabricModule", "LlamaForCausalLM", PeftModel, peft.LoraModel
@@ -67,7 +68,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
67
68
  save_full_model: bool = False,
68
69
  save_ckpt_type: Literal["lightning", "peft"] = "peft",
69
70
  ckpt_path: Optional[str] = None,
70
- max_length: int = 6150,
71
+ max_length: int = 6144,
71
72
  **kwargs,
72
73
  ):
73
74
  """
@@ -121,17 +122,23 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
121
122
  def run(self, modelpool: CausalLMPool):
122
123
  self.modelpool = modelpool
123
124
  self.setup()
124
- self.train()
125
+ self.train(self.model, self.optimizer, self.lr_scheduler)
125
126
 
126
127
  if self.merge_and_unload:
127
128
  self.model = self.model.merge_and_unload()
128
129
  return self.model
129
130
 
130
131
  def setup_model(self):
132
+ # https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py
133
+ self.tokenizer = self.modelpool.load_tokenizer()
134
+ if self.tokenizer.pad_token_id is None:
135
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
136
+
131
137
  model = self.modelpool.load_pretrained_model()
132
138
 
133
139
  # get the PEFT model
134
140
  peft_config = instantiate(self._peft_config, _convert_="all")
141
+ peft_config.save_pretrained(os.path.join(self.log_dir, "peft_config"))
135
142
  peft_model = get_peft_model(model, peft_config, self.adapter_name)
136
143
  peft_model.print_trainable_parameters()
137
144
 
@@ -149,20 +156,11 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
149
156
  self.use_cache = True
150
157
 
151
158
  self.model_dtype = get_dtype(self.model)
159
+ self.model = self.model.to(dtype=self.model_dtype)
152
160
 
153
161
  def configure_optimizer(self):
154
162
  # compute expected total steps
155
- self.expected_total_steps = []
156
- if self.max_steps > 0:
157
- self.expected_total_steps.append(self.max_steps)
158
- if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
159
- self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
160
- if self.max_epochs > 0:
161
- self.expected_total_steps.append(
162
- len(self.train_dataloader) * self.max_epochs
163
- )
164
- self.expected_total_steps = min(self.expected_total_steps)
165
- log.info(f"Expected total steps: {self.expected_total_steps}")
163
+ self.compute_expected_total_steps(self.train_dataloader)
166
164
 
167
165
  optimizer = instantiate(self._optimizer, self.model.parameters())
168
166
  if self._lr_scheduler is not None:
@@ -201,7 +199,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
201
199
  train_dataset,
202
200
  **self.dataloader_kwargs,
203
201
  shuffle=True,
204
- collate_fn=padded_collate_sft,
202
+ collate_fn=functools.partial(
203
+ padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
204
+ ),
205
205
  )
206
206
  self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)
207
207
 
@@ -214,28 +214,19 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
214
214
  optimizer = self.configure_optimizer()
215
215
  optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]
216
216
 
217
- self.model, self.optimizer = fabric.setup(self.model, optimizer)
217
+ self.model = self.fabric.setup_module(self.model)
218
+ self.optimizer = self.fabric.setup_optimizers(optimizer)
218
219
  self.lr_scheduler = lr_scheduler
219
220
 
220
- def _clip_gradients_if_needed(self):
221
+ @override
222
+ def train_epoch(self, *args, **kwargs):
221
223
  fabric = self.fabric
222
224
 
223
- if self.gradient_clip_val is not None:
224
- if self.gradient_clip_algorithm == "value":
225
- fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
226
- elif self.gradient_clip_algorithm == "norm":
227
- fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
228
- else:
229
- raise ValueError(
230
- f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
231
- )
232
-
233
- def train_epoch(self):
234
- fabric = self.fabric
225
+ accumulated_loss = 0
235
226
  for step_idx, batch in enumerate(
236
227
  pbar := tqdm(
237
228
  self.train_dataloader,
238
- desc="Training Steps",
229
+ desc="Training Batches",
239
230
  dynamic_ncols=True,
240
231
  leave=False,
241
232
  disable=not fabric.is_global_zero,
@@ -250,6 +241,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
250
241
  batch["input_ids"] = batch["input_ids"][:, : self.max_length]
251
242
  batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
252
243
  batch["labels"] = batch["labels"][:, : self.max_length]
244
+
253
245
  # disable gradient synchronization if accumulating gradients across steps for improved performance
254
246
  with fabric.no_backward_sync(self.model, enabled=is_accumulating):
255
247
  # use_cache=True is not compatible with gradient checkpointing, so we disable it here
@@ -259,20 +251,13 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
259
251
  labels=batch["labels"],
260
252
  use_cache=self.use_cache,
261
253
  )
262
- loss = output["loss"]
254
+ loss = output["loss"] / self.accumulate_grad_batches
263
255
 
264
256
  fabric.backward(loss)
265
-
266
- metrics = {
267
- "train/loss": loss.item(),
268
- "train/epoch_idx": self.epoch_idx,
269
- "train/lr": self.optimizer.param_groups[0]["lr"],
270
- }
271
- fabric.log_dict(metrics, step=self.global_step_idx)
272
- pbar.set_postfix(metrics)
257
+ accumulated_loss += loss.item()
273
258
 
274
259
  if not is_accumulating:
275
- self._clip_gradients_if_needed()
260
+ self.clip_gradients_if_needed(self.model, self.optimizer)
276
261
 
277
262
  # run lr_scheduler at the end of the step if interval is set to "step"
278
263
  if (
@@ -285,105 +270,30 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
285
270
  self.optimizer.step()
286
271
  self.optimizer.zero_grad()
287
272
 
288
- # save the model at the end of the step if interval is set to "step" and frequency is met
289
- self._try_save_checkpoint(stage="end_of_step")
273
+ metrics = {
274
+ "train/loss": accumulated_loss,
275
+ "train/epoch_idx": self.epoch_idx,
276
+ "train/lr": self.optimizer.param_groups[0]["lr"],
277
+ }
278
+ fabric.log_dict(metrics, step=self.global_step_idx)
279
+ pbar.set_postfix(metrics)
290
280
 
291
- # break if max_steps_per_epoch is set, and exit epoch
292
- if (
293
- self.max_steps_per_epoch > 0
294
- and step_idx + 1 >= self.max_steps_per_epoch
295
- ):
296
- break
297
- # break if max_steps is set, and exit training
298
- if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
299
- self.is_training = False
300
- break
281
+ # save the model at the end of the step if interval is set to "step" and frequency is met
282
+ self.conditional_checkpoint_save(stage="end_of_step")
301
283
 
302
- self.global_step_idx += 1
284
+ # break if max_steps_per_epoch is set, and exit epoch
285
+ if (
286
+ self.max_steps_per_epoch > 0
287
+ and step_idx + 1 >= self.max_steps_per_epoch
288
+ ):
289
+ break
290
+ # break if max_steps is set, and exit training
291
+ if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
292
+ self.is_training = False
293
+ break
303
294
 
304
- def train(self):
305
- fabric = self.fabric
306
- self.is_training = True
307
- self.global_step_idx = 0
308
- self.model.train()
309
- for epoch_idx in tqdm(
310
- range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
311
- "Training Epoch",
312
- dynamic_ncols=True,
313
- leave=False,
314
- disable=not fabric.is_global_zero,
315
- ):
316
- self.epoch_idx = epoch_idx
317
- self.train_epoch()
318
- # run lr_scheduler at the end of the epoch if interval is set to "epoch"
319
- if (
320
- self.lr_scheduler_interval == "epoch"
321
- and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
322
- ):
323
- self.lr_scheduler.step()
324
-
325
- # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
326
- self._try_save_checkpoint(stage="end_of_epoch")
327
-
328
- if not self.is_training:
329
- break
330
-
331
- # save the model at the end of training
332
- self._try_save_checkpoint(stage="end_of_training")
333
-
334
- def _try_save_checkpoint(
335
- self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
336
- ):
337
- if stage == "end_of_step":
338
- if (
339
- self.checkpoint_save_interval == "step"
340
- and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
341
- ):
342
- self.save_checkpoint(
343
- os.path.join(
344
- self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
345
- )
346
- )
347
- elif stage == "end_of_epoch":
348
- if (
349
- self.checkpoint_save_interval == "epoch"
350
- and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
351
- ):
352
- self.save_checkpoint(
353
- os.path.join(
354
- self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
355
- )
356
- )
357
- elif stage == "end_of_training":
358
- # if the checkpoint has not been saved yet, save it
359
- if self.global_step_idx > self._latest_saved_checkpoint_global_step:
360
- self.save_checkpoint(
361
- os.path.join(
362
- self.log_dir,
363
- "checkpoints",
364
- f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
365
- )
366
- )
367
- try:
368
- os.symlink(
369
- os.path.join(
370
- self.log_dir,
371
- "checkpoints",
372
- "latest_model.ckpt",
373
- ),
374
- dst := os.path.join(
375
- self.log_dir,
376
- "checkpoints",
377
- f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
378
- ),
379
- target_is_directory=os.path.isdir(dst),
380
- )
381
- except Exception as e:
382
- log.error(f"Failed to create symlink: {e}")
383
- else:
384
- raise ValueError(
385
- f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
386
- )
295
+ self.global_step_idx += 1
296
+ accumulated_loss = 0
387
297
 
388
298
  def save_checkpoint(
389
299
  self,
@@ -418,7 +328,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
418
328
  if self.save_full_model
419
329
  else {"model": lambda k, p: k in trainable_param_names}
420
330
  )
421
-
331
+ os.makedirs(os.path.dirname(path), exist_ok=True)
422
332
  fabric.save(path, state=state, filter=filter)
423
333
  elif self.save_ckpt_type == "peft":
424
334
  self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
@@ -1,7 +1,7 @@
1
- from typing import Literal, Optional, Union
1
+ from typing import Dict, Literal, Optional, Union
2
2
 
3
3
  import torch
4
- from torch import Dict, nn
4
+ from torch import nn
5
5
  from tqdm.auto import tqdm
6
6
  from transformers import LlamaForCausalLM, LlamaModel
7
7