fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +3 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +64 -11
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +12 -1
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +12 -5
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +14 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
fusion_bench/method/lm_finetune/fullfinetune_sft.py

@@ -1,3 +1,4 @@
+import functools
 import itertools
 import logging
 import os
@@ -18,7 +19,7 @@ from typing_extensions import TYPE_CHECKING, override

 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.dataset.llama.collate import padded_collate_sft
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import FabricTrainingMixin
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import get_dtype
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


-class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
+class FullFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

     model: Union[nn.Module, "_FabricModule", "LlamaForCausalLM"]
     optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
@@ -59,7 +60,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         gradient_clip_algorithm: Literal["value", "norm"] = "norm",
         save_optimizer_state: bool = False,
         save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
         ckpt_path: Optional[str] = None,
+        max_length: int = 6144,
+        fix_token_embedding: bool = True,
         **kwargs,
     ):
         """
@@ -81,7 +85,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
             save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
             save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
             ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
+            max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
+            fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
         """
         self._optimizer = optimizer
         self._lr_scheduler = lr_scheduler
@@ -98,18 +105,28 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.gradient_clip_algorithm = gradient_clip_algorithm
         self.save_optimizer_state = save_optimizer_state
         self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
         self.ckpt_path = ckpt_path
+        self.max_length = max_length
+        self.fix_token_embedding = fix_token_embedding
         super().__init__(**kwargs)

     def run(self, modelpool: CausalLMPool):
         self.modelpool = modelpool
         self.setup()
-        self.train()
+        self.train(self.model, self.optimizer, self.lr_scheduler)
         return self.model

     def setup_model(self):
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
         model = self.modelpool.load_pretrained_model()
-        self.model = model
+        self.model: "LlamaForCausalLM" = model
+
+        if self.fix_token_embedding:
+            self.model.model.embed_tokens.requires_grad_(False)

         if self.fabric.strategy == "fsdp" or isinstance(
             self.fabric.strategy, FSDPStrategy
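The new setup_model resolves the tokenizer before any dataloader is built, because Llama-family tokenizers ship without a pad token and the padded collate function below needs a valid pad id. The same pattern outside fusion_bench, as a minimal sketch (the checkpoint name is illustrative):

from transformers import AutoTokenizer

# Illustrative checkpoint; any Llama-family tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
if tokenizer.pad_token_id is None:
    # Llama defines no pad token; reusing EOS gives batch padding a valid id.
    tokenizer.pad_token_id = tokenizer.eos_token_id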
@@ -125,17 +142,7 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):

     def configure_optimizer(self):
         # compute expected total steps
-        self.expected_total_steps = []
-        if self.max_steps > 0:
-            self.expected_total_steps.append(self.max_steps)
-        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
-            self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
-        if self.max_epochs > 0:
-            self.expected_total_steps.append(
-                len(self.train_dataloader) * self.max_epochs
-            )
-        self.expected_total_steps = min(self.expected_total_steps)
-        log.info(f"Expected total steps: {self.expected_total_steps}")
+        self.compute_expected_total_steps(self.train_dataloader)

         optimizer = instantiate(self._optimizer, self.model.parameters())
         if self._lr_scheduler is not None:
@@ -174,7 +181,9 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             train_dataset,
             **self.dataloader_kwargs,
             shuffle=True,
-            collate_fn=padded_collate_sft,
+            collate_fn=functools.partial(
+                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
+            ),
         )
         self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

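DataLoader calls its collate_fn with a single argument (the list of samples), so the pad id is bound in ahead of time with functools.partial. padded_collate_sft itself lives in fusion_bench/dataset/llama/collate.py (extended by 62 lines in this release); its exact code is not shown in this diff, but a typical right-padding SFT collate looks roughly like the sketch below. This is an assumption about its behavior, not the library's implementation; -100 is the conventional ignore index for the cross-entropy loss:

import torch
from torch.nn.utils.rnn import pad_sequence

def padded_collate_sft_sketch(batch, pad_token_id: int, ignore_idx: int = -100):
    lengths = torch.tensor([len(x["input_ids"]) for x in batch])
    # Right-pad token ids with the pad token and labels with the ignore index.
    input_ids = pad_sequence(
        [torch.as_tensor(x["input_ids"]) for x in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    labels = pad_sequence(
        [torch.as_tensor(x["labels"]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    # Build the mask from sequence lengths rather than the pad id,
    # since pad may equal EOS after the fallback above.
    attention_mask = (
        torch.arange(input_ids.shape[1])[None, :] < lengths[:, None]
    ).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}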
@@ -190,25 +199,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.model, self.optimizer = fabric.setup(self.model, optimizer)
         self.lr_scheduler = lr_scheduler

-
+    @override
+    def train_epoch(self, *args, **kwargs):
         fabric = self.fabric

-
-        if self.gradient_clip_algorithm == "value":
-            fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
-        elif self.gradient_clip_algorithm == "norm":
-            fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
-        else:
-            raise ValueError(
-                f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
-            )
-
-    def train_epoch(self):
-        fabric = self.fabric
+        accumulated_loss = 0
         for step_idx, batch in enumerate(
             pbar := tqdm(
                 self.train_dataloader,
-                desc="Training
+                desc="Training Batches",
                 dynamic_ncols=True,
                 leave=False,
                 disable=not fabric.is_global_zero,
@@ -216,6 +215,14 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         ):
             is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
+                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
+                batch["labels"] = batch["labels"][:, : self.max_length]
+
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
@@ -225,20 +232,13 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                     labels=batch["labels"],
                     use_cache=self.use_cache,
                 )
-                loss = output["loss"]
+                loss = output["loss"] / self.accumulate_grad_batches

                 fabric.backward(loss)
-
-                metrics = {
-                    "train/loss": loss.item(),
-                    "train/epoch_idx": self.epoch_idx,
-                    "train/lr": self.optimizer.param_groups[0]["lr"],
-                }
-                fabric.log_dict(metrics, step=self.global_step_idx)
-                pbar.set_postfix(metrics)
+                accumulated_loss += loss.item()

             if not is_accumulating:
-                self.
+                self.clip_gradients_if_needed(self.model, self.optimizer)

                 # run lr_scheduler at the end of the step if interval is set to "step"
                 if (
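Two behavioral changes are bundled here: the per-micro-batch loss is pre-divided by accumulate_grad_batches, so the accumulated_loss logged after each optimizer step is the mean loss over the accumulation window, and gradient clipping is delegated to the mixin's clip_gradients_if_needed. The underlying Fabric pattern, as a minimal sketch assuming a configured Fabric instance (names are illustrative, not fusion_bench API):

from lightning.fabric import Fabric

def sft_train_epoch_sketch(fabric: Fabric, model, optimizer, dataloader,
                           accumulate_grad_batches: int = 4, max_norm: float = 1.0):
    accumulated_loss = 0.0
    for step_idx, batch in enumerate(dataloader):
        is_accumulating = (step_idx + 1) % accumulate_grad_batches != 0
        # Skip the gradient all-reduce on micro-batches that do not step the optimizer.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            # Pre-divide so the summed micro-losses equal the window's mean loss.
            loss = model(**batch)["loss"] / accumulate_grad_batches
            fabric.backward(loss)
            accumulated_loss += loss.item()
        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=max_norm)
            optimizer.step()
            optimizer.zero_grad()
            accumulated_loss = 0.0  # the real code logs this before resetting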
@@ -251,105 +251,30 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 self.optimizer.step()
                 self.optimizer.zero_grad()

-
-
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)

-            # break if max_steps_per_epoch is set, and exit epoch
-            if (
-                self.max_steps_per_epoch > 0
-                and step_idx + 1 >= self.max_steps_per_epoch
-            ):
-                break
-            # break if max_steps is set, and exit training
-            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
-                self.is_training = False
-                break
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")

-
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break

-
-    def train(self):
-        self.is_training = True
-        self.global_step_idx = 0
-        self.model.train()
-        for epoch_idx in tqdm(
-            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
-            "Training Epoch",
-            dynamic_ncols=True,
-            leave=False,
-            disable=not fabric.is_global_zero,
-        ):
-            self.epoch_idx = epoch_idx
-            self.train_epoch()
-            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
-            if (
-                self.lr_scheduler_interval == "epoch"
-                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
-            ):
-                self.lr_scheduler.step()
-
-            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
-            self._try_save_checkpoint(stage="end_of_epoch")
-
-            if not self.is_training:
-                break
-
-        # save the model at the end of training
-        self._try_save_checkpoint(stage="end_of_training")
-
-    def _try_save_checkpoint(
-        self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
-    ):
-        if stage == "end_of_step":
-            if (
-                self.checkpoint_save_interval == "step"
-                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_epoch":
-            if (
-                self.checkpoint_save_interval == "epoch"
-                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_training":
-            # if the checkpoint has not been saved yet, save it
-            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    )
-                )
-                try:
-                    os.symlink(
-                        os.path.join(
-                            self.log_dir,
-                            "checkpoints",
-                            "latest_model.ckpt",
-                        ),
-                        dst := os.path.join(
-                            self.log_dir,
-                            "checkpoints",
-                            f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                        ),
-                        target_is_directory=os.path.isdir(dst),
-                    )
-                except Exception as e:
-                    log.error(f"Failed to create symlink: {e}")
-        else:
-            raise ValueError(
-                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
-            )
+                self.global_step_idx += 1
+                accumulated_loss = 0

     def save_checkpoint(
         self,
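The entire epoch loop, step accounting, and checkpoint scheduling removed above now live in the new FabricTrainingMixin (fusion_bench/mixins/fabric_training.py, +320 lines in this release), whose source is not part of this diff. A hypothetical skeleton, consistent with the calls the subclasses make (train, compute_expected_total_steps, clip_gradients_if_needed, conditional_checkpoint_save) and with the code deleted here, might look like:

import itertools

class FabricTrainingMixinSketch:
    """Hypothetical skeleton; the real FabricTrainingMixin ships in fabric_training.py."""

    def compute_expected_total_steps(self, train_dataloader):
        # Tightest bound among max_steps, max_steps_per_epoch * max_epochs,
        # and len(dataloader) * max_epochs -- mirrors the code removed above.
        candidates = []
        if self.max_steps > 0:
            candidates.append(self.max_steps)
        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
            candidates.append(self.max_steps_per_epoch * self.max_epochs)
        if self.max_epochs > 0:
            candidates.append(len(train_dataloader) * self.max_epochs)
        self.expected_total_steps = min(candidates)

    def clip_gradients_if_needed(self, model, optimizer):
        # Same value/norm dispatch as the deleted inline version.
        if self.gradient_clip_algorithm == "value":
            self.fabric.clip_gradients(model, optimizer, clip_val=self.gradient_clip_val)
        elif self.gradient_clip_algorithm == "norm":
            self.fabric.clip_gradients(model, optimizer, max_norm=self.gradient_clip_val)

    def conditional_checkpoint_save(self, stage):
        # Stub: the real mixin checks checkpoint_save_interval/frequency per stage.
        ...

    def train(self, model, optimizer, lr_scheduler):
        self.is_training = True
        self.global_step_idx = 0
        model.train()
        for epoch_idx in (
            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0)
        ):
            self.epoch_idx = epoch_idx
            self.train_epoch()  # overridden by FullFinetuneSFT / PeftFinetuneSFT
            self.conditional_checkpoint_save(stage="end_of_epoch")
            if not self.is_training:
                break
        self.conditional_checkpoint_save(stage="end_of_training")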
@@ -361,31 +286,36 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

         fabric = self.fabric
-        state = {"model": self.model}

-        # save the optimizer and lr_scheduler state if needed
-        if self.save_optimizer_state and save_optimizer_state is not False:
-            state.update(
-                {
-                    "optimizer": self.optimizer,
-                    "lr_scheduler": self.lr_scheduler,
-                    "global_step_idx": self.global_step_idx,
-                    "epoch_idx": self.epoch_idx,
-                }
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
             )

-        trainable_param_names = set(
-            name
-            for name, param in self.model.state_dict(keep_vars=True).items()
-            if param.requires_grad
-        )
-        filter = (
-            None
-            if self.save_full_model
-            else {"model": lambda k, p: k in trainable_param_names}
-        )
+            fabric.save(path, state=state, filter=filter)
+        else:
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)

-        fabric.save(path, state=state, filter=filter)
         self._latest_saved_checkpoint_global_step = self.global_step_idx

     def load_checkpoint(self, path: Union[str, Path]):
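With save_ckpt_type="lightning" and save_full_model=False, the checkpoint holds only the trainable subset of the state dict, so it has to be restored on top of the pretrained base weights. A hedged sketch of the corresponding load (the helper name and path are placeholders, not fusion_bench API):

from lightning.fabric import Fabric
from torch import nn

def load_partial_checkpoint(fabric: Fabric, model: nn.Module, path: str) -> None:
    # strict=False tolerates keys absent from the filtered checkpoint, which
    # holds only the parameters that had requires_grad=True at save time.
    fabric.load(path, {"model": model}, strict=False)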
@@ -425,9 +355,9 @@ if __name__ == "__main__":
|
|
|
425
355
|
import argparse
|
|
426
356
|
|
|
427
357
|
parser = argparse.ArgumentParser()
|
|
428
|
-
parser.add_argument("--
|
|
429
|
-
parser.add_argument("--
|
|
430
|
-
parser.add_argument("--
|
|
358
|
+
parser.add_argument("--base-model-path", type=str)
|
|
359
|
+
parser.add_argument("--ckpt-path", type=str)
|
|
360
|
+
parser.add_argument("--output-path", type=str)
|
|
431
361
|
|
|
432
362
|
args = parser.parse_args()
|
|
433
363
|
|
|
fusion_bench/method/lm_finetune/peftfinetune_sft.py

@@ -1,3 +1,4 @@
+import functools
 import itertools
 import logging
 import os
@@ -11,7 +12,7 @@ import torch
 from lightning.fabric.strategies.fsdp import FSDPStrategy
 from lightning.fabric.utilities import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
-from peft import PeftModel, get_peft_config, get_peft_model
+from peft import LoraConfig, PeftModel, get_peft_config, get_peft_model
 from torch import nn
 from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
@@ -19,7 +20,7 @@ from typing_extensions import TYPE_CHECKING, override

 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.dataset.llama.collate import padded_collate_sft
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import FabricTrainingMixin
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import get_dtype
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


-class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
+class PeftFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

     model: Union[
         nn.Module, "_FabricModule", "LlamaForCausalLM", PeftModel, peft.LoraModel
@@ -67,7 +68,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         save_full_model: bool = False,
         save_ckpt_type: Literal["lightning", "peft"] = "peft",
         ckpt_path: Optional[str] = None,
-        max_length: int =
+        max_length: int = 6144,
         **kwargs,
     ):
         """
@@ -121,17 +122,23 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
     def run(self, modelpool: CausalLMPool):
         self.modelpool = modelpool
         self.setup()
-        self.train()
+        self.train(self.model, self.optimizer, self.lr_scheduler)

         if self.merge_and_unload:
             self.model = self.model.merge_and_unload()
         return self.model

     def setup_model(self):
+        # https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
         model = self.modelpool.load_pretrained_model()

         # get the PEFT model
         peft_config = instantiate(self._peft_config, _convert_="all")
+        peft_config.save_pretrained(os.path.join(self.log_dir, "peft_config"))
         peft_model = get_peft_model(model, peft_config, self.adapter_name)
         peft_model.print_trainable_parameters()

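setup_model now also persists the instantiated PEFT config next to the run logs, which makes the adapter reproducible independently of the Hydra config. Stripped of the Hydra indirection, the equivalent wiring is roughly as follows (the hyperparameters and paths are illustrative, not the values shipped in fusion_bench's configs):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
peft_config = LoraConfig(
    r=16,                                  # illustrative rank
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
peft_config.save_pretrained("outputs/peft_config")  # mirrors the new line above
peft_model = get_peft_model(model, peft_config, adapter_name="default")
peft_model.print_trainable_parameters()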
@@ -149,20 +156,11 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.use_cache = True

         self.model_dtype = get_dtype(self.model)
+        self.model = self.model.to(dtype=self.model_dtype)

     def configure_optimizer(self):
         # compute expected total steps
-        self.expected_total_steps = []
-        if self.max_steps > 0:
-            self.expected_total_steps.append(self.max_steps)
-        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
-            self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
-        if self.max_epochs > 0:
-            self.expected_total_steps.append(
-                len(self.train_dataloader) * self.max_epochs
-            )
-        self.expected_total_steps = min(self.expected_total_steps)
-        log.info(f"Expected total steps: {self.expected_total_steps}")
+        self.compute_expected_total_steps(self.train_dataloader)

         optimizer = instantiate(self._optimizer, self.model.parameters())
         if self._lr_scheduler is not None:
@@ -201,7 +199,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             train_dataset,
             **self.dataloader_kwargs,
             shuffle=True,
-            collate_fn=padded_collate_sft,
+            collate_fn=functools.partial(
+                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
+            ),
         )
         self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

@@ -214,28 +214,19 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         optimizer = self.configure_optimizer()
         optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

-        self.model
+        self.model = self.fabric.setup_module(self.model)
+        self.optimizer = self.fabric.setup_optimizers(optimizer)
         self.lr_scheduler = lr_scheduler

-
+    @override
+    def train_epoch(self, *args, **kwargs):
         fabric = self.fabric

-
-        if self.gradient_clip_algorithm == "value":
-            fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
-        elif self.gradient_clip_algorithm == "norm":
-            fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
-        else:
-            raise ValueError(
-                f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
-            )
-
-    def train_epoch(self):
-        fabric = self.fabric
+        accumulated_loss = 0
         for step_idx, batch in enumerate(
             pbar := tqdm(
                 self.train_dataloader,
-                desc="Training
+                desc="Training Batches",
                 dynamic_ncols=True,
                 leave=False,
                 disable=not fabric.is_global_zero,
@@ -250,6 +241,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 batch["input_ids"] = batch["input_ids"][:, : self.max_length]
                 batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
                 batch["labels"] = batch["labels"][:, : self.max_length]
+
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
@@ -259,20 +251,13 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                     labels=batch["labels"],
                     use_cache=self.use_cache,
                 )
-                loss = output["loss"]
+                loss = output["loss"] / self.accumulate_grad_batches

                 fabric.backward(loss)
-
-                metrics = {
-                    "train/loss": loss.item(),
-                    "train/epoch_idx": self.epoch_idx,
-                    "train/lr": self.optimizer.param_groups[0]["lr"],
-                }
-                fabric.log_dict(metrics, step=self.global_step_idx)
-                pbar.set_postfix(metrics)
+                accumulated_loss += loss.item()

             if not is_accumulating:
-                self.
+                self.clip_gradients_if_needed(self.model, self.optimizer)

                 # run lr_scheduler at the end of the step if interval is set to "step"
                 if (
@@ -285,105 +270,30 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 self.optimizer.step()
                 self.optimizer.zero_grad()

-
-
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)

-            # break if max_steps_per_epoch is set, and exit epoch
-            if (
-                self.max_steps_per_epoch > 0
-                and step_idx + 1 >= self.max_steps_per_epoch
-            ):
-                break
-            # break if max_steps is set, and exit training
-            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
-                self.is_training = False
-                break
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")

-
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break

-
-    def train(self):
-        self.is_training = True
-        self.global_step_idx = 0
-        self.model.train()
-        for epoch_idx in tqdm(
-            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
-            "Training Epoch",
-            dynamic_ncols=True,
-            leave=False,
-            disable=not fabric.is_global_zero,
-        ):
-            self.epoch_idx = epoch_idx
-            self.train_epoch()
-            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
-            if (
-                self.lr_scheduler_interval == "epoch"
-                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
-            ):
-                self.lr_scheduler.step()
-
-            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
-            self._try_save_checkpoint(stage="end_of_epoch")
-
-            if not self.is_training:
-                break
-
-        # save the model at the end of training
-        self._try_save_checkpoint(stage="end_of_training")
-
-    def _try_save_checkpoint(
-        self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
-    ):
-        if stage == "end_of_step":
-            if (
-                self.checkpoint_save_interval == "step"
-                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_epoch":
-            if (
-                self.checkpoint_save_interval == "epoch"
-                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_training":
-            # if the checkpoint has not been saved yet, save it
-            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    )
-                )
-                try:
-                    os.symlink(
-                        os.path.join(
-                            self.log_dir,
-                            "checkpoints",
-                            "latest_model.ckpt",
-                        ),
-                        dst := os.path.join(
-                            self.log_dir,
-                            "checkpoints",
-                            f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                        ),
-                        target_is_directory=os.path.isdir(dst),
-                    )
-                except Exception as e:
-                    log.error(f"Failed to create symlink: {e}")
-        else:
-            raise ValueError(
-                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
-            )
+                self.global_step_idx += 1
+                accumulated_loss = 0

     def save_checkpoint(
         self,
@@ -418,7 +328,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 if self.save_full_model
                 else {"model": lambda k, p: k in trainable_param_names}
             )
-
+            os.makedirs(os.path.dirname(path), exist_ok=True)
             fabric.save(path, state=state, filter=filter)
         elif self.save_ckpt_type == "peft":
             self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)