fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +3 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +64 -11
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +12 -1
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +12 -5
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +14 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
class NoSparseGradientError(Exception):
    """Signal that an optimizer cannot handle sparse gradients.

    Args:
        optimizer_name (str): Name of the optimizer reported in the message.
        note (str): Optional extra condition, rendered as ``"w/ <note>"`` after
            the optimizer name; omitted when empty.
    """

    def __init__(self, optimizer_name: str, note: str = ""):
        # Padding spaces make the message read "<name> does ..." or
        # "<name> w/ <note> does ..." without a double space.
        self.note: str = f" w/ {note} " if note else " "
        self.message: str = (
            f"[-] {optimizer_name}{self.note}does not support sparse gradient."
        )
        super().__init__(self.message)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ZeroParameterSizeError(Exception):
    """Signal that a parameter has zero elements."""

    def __init__(self):
        text = "[-] parameter size is 0"
        self.message: str = text
        super().__init__(text)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NoClosureError(Exception):
    """Signal that an optimizer step was invoked without a required closure.

    Args:
        optimizer_name (str): Name of the optimizer reported in the message.
        note (str): Text appended verbatim right after ``"requires closure."``;
            no separator is inserted, so callers supply their own spacing.
    """

    def __init__(self, optimizer_name: str, note: str = ""):
        self.message: str = "[-] {} requires closure.{}".format(optimizer_name, note)
        super().__init__(self.message)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class NegativeLRError(Exception):
    """Signal that a learning-rate hyper-parameter is below zero.

    Args:
        lr (float): The offending value, echoed in the message.
        lr_type (str): Name of the hyper-parameter (e.g. ``"min_lr"``); falls
            back to ``"learning rate"`` when empty.
    """

    def __init__(self, lr: float, lr_type: str = ""):
        self.note: str = lr_type or "learning rate"
        self.message: str = f"[-] {self.note} must be positive. ({lr} > 0)"
        super().__init__(self.message)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class NegativeStepError(Exception):
    """Signal that a step-count hyper-parameter is below zero.

    Args:
        num_steps (int): The offending value, echoed in the message.
        step_type (str): Name of the hyper-parameter (e.g. ``"warmup_steps"``);
            falls back to ``"step"`` when empty.
    """

    def __init__(self, num_steps: int, step_type: str = ""):
        self.note: str = step_type or "step"
        self.message: str = f"[-] {self.note} must be positive. ({num_steps} > 0)"
        super().__init__(self.message)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .linear_warmup import *
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Modified from pytorch_optimizer: https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/lr_scheduler/linear_warmup.py
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from fusion_bench.optim.exception import NegativeLRError, NegativeStepError
|
|
13
|
+
|
|
14
|
+
# Public names; also what ``from .linear_warmup import *`` re-exports.
__all__ = [
    "BaseLinearWarmupScheduler",
    "LinearWarmupScheduler",
    "CosineDecayWithWarmup",
    "PolySchedulerWithWarmup",
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BaseLinearWarmupScheduler(ABC):
    r"""Base class for learning-rate schedulers with a linear warm-up phase.

    Each call to :meth:`step` computes one learning rate and writes it to every
    parameter group of ``optimizer``:

    * while ``step_t < warmup_steps`` the value is interpolated linearly from
      ``init_lr`` towards ``max_lr``;
    * at ``step_t == warmup_steps`` it is exactly ``max_lr``;
    * afterwards it is whatever the subclass's :meth:`_step` returns.

    Unlike ``torch.optim.lr_scheduler.LRScheduler``, :meth:`step` returns the
    new learning rate, and this class does not inherit from torch's scheduler
    base class.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose parameter groups receive
            the scheduled learning rate. May be ``None`` to only compute values.
        T_max (int): Total number of training steps.
        max_lr (float): Learning rate reached at the end of warm-up.
        min_lr (float): Lower learning rate used by the decay phase of subclasses.
        init_lr (float): Learning rate at the start of warm-up.
        warmup_steps (int): Number of warm-up steps.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        T_max: int,
        max_lr: float,
        min_lr: float = 0.0,
        init_lr: float = 0.0,
        warmup_steps: int = 0,
    ):
        """
        Initialize the BaseLinearWarmupScheduler.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate
                schedule to, or ``None``.
            T_max (int): Total number of training steps.
            max_lr (float): Maximum learning rate.
            min_lr (float): Minimum learning rate.
            init_lr (float): Initial learning rate.
            warmup_steps (int): Number of steps for the warm-up phase.

        Raises:
            NegativeLRError: If ``min_lr``, ``max_lr`` or ``init_lr`` is negative.
            NegativeStepError: If ``T_max`` or ``warmup_steps`` is negative.
        """
        self.optimizer = optimizer
        self.total_steps = T_max
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps

        # Number of times step() has been called.
        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # Record the current value in ``last_lr`` to mirror torch.optim.lr_scheduler.
        self.last_lr: List[float] = [init_lr]

        self.validate_parameters()

        self._init_lr()

    def validate_parameters(self):
        """
        Validate the parameters to ensure they are non-negative.

        Raises:
            NegativeLRError: If any of the learning rates are negative.
            NegativeStepError: If any of the step values are negative.
        """
        if self.min_lr < 0:
            raise NegativeLRError(self.min_lr, "min_lr")

        if self.max_lr < 0:
            raise NegativeLRError(self.max_lr, "max_lr")

        if self.init_lr < 0:
            raise NegativeLRError(self.init_lr, "init_lr")

        if self.total_steps < 0:
            raise NegativeStepError(self.total_steps, "T_max")

        if self.warmup_steps < 0:
            raise NegativeStepError(self.warmup_steps, "warmup_steps")

    def _init_lr(self):
        """
        Write ``init_lr`` into every parameter group of the optimizer.

        ``init_lr`` (rather than ``min_lr``) is used so that the optimizer agrees
        with ``last_lr`` / :meth:`get_lr` before the first :meth:`step`. With no
        optimizer (``None``, also accepted by :meth:`step`) only ``base_lrs`` is
        reset.
        """
        self.base_lrs = []
        if self.optimizer is None:
            return
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.init_lr
            self.base_lrs.append(self.init_lr)

    def step(self):
        """
        Compute the learning rate for the current step, apply it, and advance.

        The value is computed for ``step_t`` *before* it is incremented, so the
        first call returns the warm-up start value (``init_lr``).

        Returns:
            float: The updated learning rate.
        """
        if self.step_t < self.warmup_steps:
            value = (
                self.init_lr
                + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps
            )
        elif self.step_t == self.warmup_steps:
            value = self.max_lr
        else:
            value = self._step()

        self.step_t += 1

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = value

        self.last_lr = [value]

        return value

    @abstractmethod
    def _step(self) -> float:  # pragma: no cover
        """
        Compute the post-warm-up learning rate for ``self.step_t``.

        Only called when ``step_t > warmup_steps``.

        Returns:
            float: The calculated learning rate.
        """
        raise NotImplementedError

    def get_lr(self) -> float:
        """
        Get the most recently applied learning rate.

        Returns:
            float: The current learning rate.
        """
        return self.last_lr[0]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class LinearWarmupScheduler(BaseLinearWarmupScheduler):
    r"""Linear LR scheduler with linear warm-up.

    After warm-up the learning rate decays linearly from ``max_lr`` to
    ``min_lr`` over the remaining ``T_max - warmup_steps`` steps and then stays
    at ``min_lr``.
    """

    def _step(self) -> float:
        """
        Calculate the learning rate for the current step using a linear decay.

        Progress through the decay phase is clamped to 1, so the learning rate
        never drops below ``min_lr`` (or turns negative) once ``step_t`` passes
        ``T_max``. If ``T_max <= warmup_steps`` there is no decay phase and
        ``min_lr`` is returned instead of dividing by zero.

        Returns:
            float: The calculated learning rate.
        """
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            return self.min_lr
        progress = min((self.step_t - self.warmup_steps) / decay_steps, 1.0)
        return self.max_lr + (self.min_lr - self.max_lr) * progress
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class CosineDecayWithWarmup(BaseLinearWarmupScheduler):
    r"""Cosine LR scheduler with linear warm-up.

    After warm-up the learning rate follows half a cosine period from
    ``max_lr`` down to ``min_lr`` at ``T_max`` and then stays at ``min_lr``.
    """

    def _step(self) -> float:
        """
        Calculate the learning rate for the current step using a cosine decay.

        Progress through the decay phase is clamped to 1, so the phase never
        exceeds pi. Without the clamp, the learning rate would climb back
        towards ``max_lr`` after ``T_max``. If ``T_max <= warmup_steps`` there
        is no decay phase and ``min_lr`` is returned instead of dividing by zero.

        Returns:
            float: The calculated learning rate.
        """
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            return self.min_lr
        progress = min((self.step_t - self.warmup_steps) / decay_steps, 1.0)
        phase: float = progress * math.pi
        return self.min_lr + (self.max_lr - self.min_lr) * (math.cos(phase) + 1.0) / 2.0
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class PolySchedulerWithWarmup(BaseLinearWarmupScheduler):
    r"""Polynomial-decay LR scheduler with linear warm-up.

    After warm-up the learning rate is

    .. math::
        lr = min\_lr + (max\_lr - min\_lr) \cdot (1 - p)^{poly\_order},
        \qquad p = \min\left(1, \frac{step - warmup\_steps}{T\_max - warmup\_steps}\right)

    so it decreases monotonically from ``max_lr`` to ``min_lr`` (reached at
    ``T_max``) and stays there. ``poly_order == 1`` gives linear decay, values
    below 1 decay slowly at first, and values above 1 decay quickly at first.

    Args:
        poly_order (float): Exponent of the polynomial decay. Must be positive.
    """

    def __init__(self, optimizer, poly_order: float = 0.5, **kwargs):
        """
        Initialize the PolySchedulerWithWarmup.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate schedule.
            poly_order (float): Exponent of the polynomial learning rate decay.
            kwargs: Additional keyword arguments for the base class
                (``T_max``, ``max_lr``, ``min_lr``, ``init_lr``, ``warmup_steps``).

        Raises:
            ValueError: If poly_order is not positive.
        """
        self.poly_order = poly_order

        if poly_order <= 0:
            raise ValueError(f"[-] poly_order must be positive. {poly_order}")

        super().__init__(optimizer, **kwargs)

    def _step(self) -> float:
        """
        Calculate the learning rate for the current step using a polynomial decay.

        The previous formula, ``(step_t - warmup_steps) ** poly_order``, was
        neither normalised nor inverted. As a result the learning rate grew
        without bound instead of decaying. Here the decay progress is
        normalised to ``[0, 1]`` and the remaining fraction ``1 - progress`` is
        raised to ``poly_order``.

        Returns:
            float: The calculated learning rate.
        """
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # No decay phase configured; avoid dividing by zero.
            return self.min_lr
        progress = min((self.step_t - self.warmup_steps) / decay_steps, 1.0)
        return (
            self.min_lr
            + (self.max_lr - self.min_lr) * (1.0 - progress) ** self.poly_order
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .visualization import *
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides utilities for visualizing learning rate schedulers.
|
|
3
|
+
|
|
4
|
+
Functions:
|
|
5
|
+
simulate_scheduler(lr_scheduler, steps): Simulates the learning rate scheduler for a given number of steps.
|
|
6
|
+
plot_lr_schedulers(lr_schedulers, steps, titles): Plots the learning rates of one or more schedulers over a number of steps.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, List, Union
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from torch.optim.lr_scheduler import LRScheduler
|
|
16
|
+
|
|
17
|
+
# Public names; also what ``from .visualization import *`` re-exports.
__all__ = ["simulate_scheduler", "plot_lr_schedulers"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def simulate_scheduler(lr_scheduler, steps: int):
    """
    Simulates the learning rate scheduler for a given number of steps.

    Works with schedulers whose ``step()`` returns the new learning rate (such as
    the warm-up schedulers in ``fusion_bench.optim.lr_scheduler``). It also
    works with ``torch.optim.lr_scheduler.LRScheduler`` instances: their
    ``step()`` returns ``None``, so the value is read back through
    ``get_last_lr()``.

    Args:
        lr_scheduler: The learning rate scheduler object.
        steps (int): The number of steps to simulate.

    Returns:
        List[float]: The learning rate after each step (first parameter group
        for torch schedulers).
    """
    lrs = []
    for _ in range(steps):
        lr = lr_scheduler.step()
        if lr is None:
            # torch LRScheduler.step() returns None; read the value it just set.
            lr = lr_scheduler.get_last_lr()[0]
        lrs.append(lr)
    return lrs
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def plot_lr_schedulers(
    lr_schedulers: Union["LRScheduler", List["LRScheduler"]],
    steps: int,
    titles: Union[str, List[str]],
    show_plot: bool = True,
):
    """
    Plots the learning rates of one or more schedulers over a number of steps.

    Each scheduler is simulated with :func:`simulate_scheduler` (which advances
    it) and drawn on its own vertically stacked subplot.

    Args:
        lr_schedulers: A single scheduler, or a list/tuple of schedulers. Any
            object that is not a list or tuple is treated as one scheduler, so
            both torch ``LRScheduler`` instances and the warm-up schedulers in
            ``fusion_bench.optim.lr_scheduler`` are accepted.
        steps (int): The number of steps to simulate.
        titles (Union[str, List[str]]): Titles for the plots, one per scheduler.
        show_plot (bool): Whether to call ``plt.show()`` before returning.

    Returns:
        fig, axes: The matplotlib figure and axes objects.

    Raises:
        ValueError: If no scheduler is given.
    """
    # Wrap single schedulers. An isinstance check against torch's LRScheduler
    # would miss the warm-up schedulers, which do not subclass it. That
    # attribute also does not exist on older torch releases.
    if not isinstance(lr_schedulers, (list, tuple)):
        lr_schedulers = [lr_schedulers]
    if isinstance(titles, str):
        titles = [titles]
    if len(lr_schedulers) == 0:
        raise ValueError("At least one learning rate scheduler is required.")

    fig, axs = plt.subplots(len(lr_schedulers), 1, figsize=(5, 3 * len(lr_schedulers)))
    if len(lr_schedulers) == 1:
        # plt.subplots returns a bare Axes (not an array) for a single subplot.
        axs = [axs]

    for i, (scheduler, title) in enumerate(zip(lr_schedulers, titles)):
        lrs = simulate_scheduler(scheduler, steps)
        axs[i].plot(lrs, label=title)
        axs[i].set_title(title)
        axs[i].set_xlabel("Steps")
        axs[i].set_ylabel("Learning Rate")
        axs[i].legend()
        axs[i].grid(True)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig, axs
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Example usage
if __name__ == "__main__":
    from fusion_bench.optim.lr_scheduler.linear_warmup import (
        CosineDecayWithWarmup,
        LinearWarmupScheduler,
        PolySchedulerWithWarmup,
    )

    # Dummy optimizer: the schedulers only rewrite its param-group learning rates.
    optimizer = torch.optim.SGD(
        [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))], lr=0.1
    )

    # Define the schedulers (the total-steps keyword is ``T_max``).
    linear_scheduler = LinearWarmupScheduler(
        optimizer, T_max=100, max_lr=0.1, min_lr=0.01, init_lr=0.0, warmup_steps=10
    )
    cosine_scheduler = CosineDecayWithWarmup(
        optimizer, T_max=100, max_lr=0.1, min_lr=0.01, init_lr=0.0, warmup_steps=10
    )
    poly_scheduler = PolySchedulerWithWarmup(
        optimizer,
        T_max=100,
        max_lr=0.1,
        min_lr=0.01,
        init_lr=0.0,
        warmup_steps=40,
        poly_order=2.0,
    )

    # Plot the learning rates
    plot_lr_schedulers(
        [linear_scheduler, cosine_scheduler, poly_scheduler],
        steps=100,
        titles=[
            "Linear Warmup",
            "Cosine Decay with Warmup",
            "Poly Scheduler with Warmup",
        ],
    )
|
fusion_bench/optim/mezo.py
CHANGED
|
@@ -185,10 +185,13 @@ class FabricModelFusionProgram(
|
|
|
185
185
|
report = taskpool.evaluate(merged_model)
|
|
186
186
|
return report
|
|
187
187
|
elif isinstance(merged_model, Dict):
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
188
|
+
report = {}
|
|
189
|
+
for key, item in merged_model.items():
|
|
190
|
+
if isinstance(item, nn.Module):
|
|
191
|
+
report[key] = taskpool.evaluate(item)
|
|
192
|
+
else:
|
|
193
|
+
# metadata
|
|
194
|
+
report[key] = item
|
|
192
195
|
return report
|
|
193
196
|
elif isinstance(merged_model, Iterable):
|
|
194
197
|
return [
|
|
@@ -236,7 +239,11 @@ class FabricModelFusionProgram(
|
|
|
236
239
|
self.save_merged_model(merged_model)
|
|
237
240
|
if self.taskpool is not None:
|
|
238
241
|
report = self.evaluate_merged_model(self.taskpool, merged_model)
|
|
239
|
-
|
|
242
|
+
try:
|
|
243
|
+
print_json(report, print_type=False)
|
|
244
|
+
except Exception as e:
|
|
245
|
+
log.warning(f"Failed to pretty print the report: {e}")
|
|
246
|
+
print(report)
|
|
240
247
|
if self.report_save_path is not None:
|
|
241
248
|
# save report (Dict) to a file
|
|
242
249
|
# if the directory of `save_report` does not exists, create it
|
|
@@ -3,7 +3,17 @@ import json
|
|
|
3
3
|
import logging
|
|
4
4
|
import os
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
6
|
+
from typing import ( # noqa: F401
|
|
7
|
+
TYPE_CHECKING,
|
|
8
|
+
Any,
|
|
9
|
+
Callable,
|
|
10
|
+
Dict,
|
|
11
|
+
List,
|
|
12
|
+
Optional,
|
|
13
|
+
Tuple,
|
|
14
|
+
Union,
|
|
15
|
+
cast,
|
|
16
|
+
)
|
|
7
17
|
|
|
8
18
|
import torch
|
|
9
19
|
from omegaconf import DictConfig
|
|
@@ -25,6 +35,10 @@ from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
|
25
35
|
from fusion_bench.utils import instantiate
|
|
26
36
|
from fusion_bench.utils.parameters import count_parameters
|
|
27
37
|
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
|
|
40
|
+
|
|
41
|
+
# Disable tokenizers parallelism by default to avoid deadlocks. ``setdefault``
# keeps any value the user has explicitly exported, so this really is only a
# default rather than an unconditional override.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

log = logging.getLogger(__name__)
|
|
@@ -198,14 +212,16 @@ class CLIPVisionModelTaskPool(
|
|
|
198
212
|
classifier: HFCLIPClassifier,
|
|
199
213
|
test_loader: DataLoader,
|
|
200
214
|
num_classes: int,
|
|
215
|
+
task_name: str = None,
|
|
201
216
|
):
|
|
202
217
|
"""
|
|
203
|
-
Evaluate the classifier on the test dataset.
|
|
218
|
+
Evaluate the classifier on the test dataset (single-task evaluation).
|
|
204
219
|
|
|
205
220
|
Args:
|
|
206
221
|
classifier (HFCLIPClassifier): The classifier to evaluate.
|
|
207
222
|
test_loader (DataLoader): The data loader for the test dataset.
|
|
208
223
|
num_classes (int): The number of classes in the classification task.
|
|
224
|
+
task_name (str): The name of the task.
|
|
209
225
|
|
|
210
226
|
Returns:
|
|
211
227
|
Dict[str, float]: A dictionary containing the accuracy and loss of the classifier on the test dataset.
|
|
@@ -228,7 +244,12 @@ class CLIPVisionModelTaskPool(
|
|
|
228
244
|
)
|
|
229
245
|
):
|
|
230
246
|
inputs, targets = batch
|
|
231
|
-
outputs = classifier(
|
|
247
|
+
outputs = classifier(
|
|
248
|
+
inputs,
|
|
249
|
+
return_image_embeds=True,
|
|
250
|
+
return_dict=True,
|
|
251
|
+
task_name=task_name,
|
|
252
|
+
)
|
|
232
253
|
logits: Tensor = outputs["logits"]
|
|
233
254
|
|
|
234
255
|
loss = F.cross_entropy(logits, targets)
|
|
@@ -246,12 +267,18 @@ class CLIPVisionModelTaskPool(
|
|
|
246
267
|
results = {"accuracy": acc, "loss": loss}
|
|
247
268
|
return results
|
|
248
269
|
|
|
249
|
-
def evaluate(
|
|
270
|
+
def evaluate(
|
|
271
|
+
self,
|
|
272
|
+
model: Union[CLIPVisionModel, CLIPVisionTransformer],
|
|
273
|
+
name=None,
|
|
274
|
+
**kwargs,
|
|
275
|
+
):
|
|
250
276
|
"""
|
|
251
277
|
Evaluate the model on the image classification task.
|
|
252
278
|
|
|
253
279
|
Args:
|
|
254
280
|
model (Union[CLIPVisionModel, CLIPVisionTransformer]): The model to evaluate.
|
|
281
|
+
name (Optional[str]): The name of the model. This will be logged into the report if not None.
|
|
255
282
|
|
|
256
283
|
Returns:
|
|
257
284
|
Dict[str, Any]: A dictionary containing the evaluation results for each task.
|
|
@@ -261,8 +288,17 @@ class CLIPVisionModelTaskPool(
|
|
|
261
288
|
|
|
262
289
|
report = {}
|
|
263
290
|
# CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
|
|
264
|
-
|
|
265
|
-
|
|
291
|
+
if hasattr(model, "is_surgery_model") and model.is_surgery_model:
|
|
292
|
+
log.info("running evaluation on a surgery model.")
|
|
293
|
+
model: "SurgeryModelWrapper" = model
|
|
294
|
+
self.clip_model.vision_model = model
|
|
295
|
+
else:
|
|
296
|
+
# replace the vision encoder with the model
|
|
297
|
+
self.clip_model.vision_model = model
|
|
298
|
+
classifier = HFCLIPClassifier(
|
|
299
|
+
self.clip_model,
|
|
300
|
+
processor=self.processor,
|
|
301
|
+
)
|
|
266
302
|
classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
|
|
267
303
|
# collect basic model information
|
|
268
304
|
training_params, all_params = count_parameters(model)
|
|
@@ -285,6 +321,7 @@ class CLIPVisionModelTaskPool(
|
|
|
285
321
|
classifier,
|
|
286
322
|
test_dataloader,
|
|
287
323
|
num_classes=len(classnames),
|
|
324
|
+
task_name=task_name,
|
|
288
325
|
)
|
|
289
326
|
report[task_name] = result
|
|
290
327
|
self.on_task_evaluation_end()
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Each example in the test datasets is expected to contain the following fields:
|
|
3
|
+
|
|
4
|
+
- chosen_input_ids: The input token ids for the winner.
|
|
5
|
+
- chosen_attention_mask: The attention mask for the winner.
|
|
6
|
+
- rejected_input_ids: The input token ids for the loser.
|
|
7
|
+
- rejected_attention_mask: The attention mask for the loser.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import functools
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
|
12
|
+
|
|
13
|
+
import lightning as L
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
from omegaconf import DictConfig
|
|
17
|
+
from torch.utils.data import Subset
|
|
18
|
+
from tqdm.auto import tqdm
|
|
19
|
+
|
|
20
|
+
from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
|
|
21
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
22
|
+
from fusion_bench.taskpool import BaseTaskPool
|
|
23
|
+
from fusion_bench.utils import instantiate
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from transformers import LlamaForSequenceClassification
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def evaluate_batch(model: "LlamaForSequenceClassification", batch):
    """
    Compute the Bradley-Terry loss and pairwise accuracy for one batch.

    The first half of the batch is treated as the chosen sequences and the
    second half as the matching rejected sequences, i.e. rows ``i`` and
    ``i + batch_size // 2`` form one preference pair.

    Args:
        model: The reward model. It is called with ``input_ids`` and
            ``attention_mask``, and its first output (``outputs[0]``) is used
            as the reward.
        batch: Mapping with ``input_ids`` and ``attention_mask`` tensors.

    Returns:
        Dict[str, Any]: ``loss`` (float, mean of
        ``-log(sigmoid(r_chosen - r_rejected))`` over pairs), ``correct``
        (int, number of pairs with ``r_chosen > r_rejected``) and ``total``
        (int, number of pairs).

    Raises:
        ValueError: If the batch size is odd.
    """
    batch_size = batch["input_ids"].size(0)
    if batch_size % 2 != 0:
        raise ValueError("Batch size must be even.")

    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )

    rewards = outputs[0]
    num_pairs = batch_size // 2
    chosen_rewards = rewards[:num_pairs]
    rejected_rewards = rewards[num_pairs:]

    # logsigmoid is the numerically stable form of log(sigmoid(x)); the naive
    # version underflows to log(0) = -inf for large negative reward margins.
    loss = -torch.nn.functional.logsigmoid(chosen_rewards - rejected_rewards).mean()
    correct = (chosen_rewards > rejected_rewards).sum().item()

    return {
        "loss": loss.item(),
        "correct": correct,
        "total": num_pairs,
    }
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def evaluate_dataloader(model: "LlamaForSequenceClassification", dataloader):
    """
    Evaluate the reward model on every batch of the given dataloader.

    Args:
        model: The reward model, called through ``evaluate_batch``.
        dataloader: Iterable of batches in the format expected by
            ``evaluate_batch``.

    Returns:
        Dict[str, Any]: A dictionary with
            - ``loss``: mean pairwise loss over all pairs (each batch's mean
              loss weighted by its number of pairs),
            - ``correct``: number of pairs where the chosen reward is higher,
            - ``total``: number of pairs evaluated,
            - ``accuracy``: ``correct / total``, or ``nan`` if no pairs were
              evaluated.
    """
    metrics = {
        "loss": 0.0,
        "correct": 0,
        "total": 0,
    }
    with torch.no_grad():
        for batch in (pbar := tqdm(dataloader)):
            batch_result = evaluate_batch(model, batch)
            new_total = metrics["total"] + batch_result["total"]
            # incremental weighted mean so the progress bar shows the running loss
            metrics["loss"] = (
                metrics["loss"] * metrics["total"] / new_total
                + batch_result["loss"] * batch_result["total"] / new_total
            )
            metrics["correct"] += batch_result["correct"]
            metrics["total"] += batch_result["total"]
            pbar.set_postfix(metrics)

    # guard against an empty dataloader instead of raising ZeroDivisionError
    if metrics["total"] > 0:
        metrics["accuracy"] = metrics["correct"] / metrics["total"]
    else:
        metrics["accuracy"] = float("nan")
    return metrics
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RewardModelEvaluationTaskPool(
    BaseTaskPool,
    LightningFabricMixin,
):
    """
    Task pool that evaluates a sequence-classification reward model on one or
    more pairwise-preference test datasets.

    For each dataset the report contains the metrics produced by
    ``evaluate_dataloader`` (loss, correct, total, accuracy).
    """

    def __init__(
        self,
        test_datasets: Dict[str, DictConfig],
        dataloader_kwargs: DictConfig,
        tokenizer: Optional[DictConfig],
        max_num_samples: int = -1,
        seed: int = 0,
        **kwargs,
    ):
        """
        Args:
            test_datasets: Mapping from dataset name to an instantiable config
                for that dataset.
            dataloader_kwargs: Extra keyword arguments forwarded to
                ``torch.utils.data.DataLoader``.
            tokenizer: Instantiable tokenizer config. If ``None``, the
                tokenizer is loaded from the program's model pool in ``setup``.
            max_num_samples: If positive, each dataset is reduced to a random
                subset of at most this many examples; otherwise all examples
                are used.
            seed: Random seed passed to ``L.seed_everything``.
            **kwargs: Forwarded to the base class constructor.
        """
        self.seed = seed
        L.seed_everything(seed)
        self._test_datasets = test_datasets
        self.dataloader_kwargs = dataloader_kwargs
        self._tokenizer = tokenizer
        self.max_num_samples = max_num_samples
        # flipped to True by ``setup``; lets ``evaluate`` build the data pipeline once
        self._is_setup = False
        super().__init__(**kwargs)

    def setup(self):
        """
        Load the tokenizer and build one fabric-wrapped test dataloader per dataset.
        """
        if self._tokenizer is None:
            # try to load the tokenizer from the model pool; ``_program`` is
            # presumably attached by the running fusion program — verify.
            tokenizer = self._program.modelpool.load_tokenizer()
        else:
            tokenizer = instantiate(self._tokenizer)
        self.tokenizer = tokenizer

        test_datasets = {
            dataset_name: instantiate(self._test_datasets[dataset_name])
            for dataset_name in self._test_datasets
        }
        if self.max_num_samples > 0:
            # random subset drawn from numpy's global RNG (seeded in __init__)
            test_datasets = {
                dataset_name: Subset(
                    test_dataset,
                    np.random.permutation(len(test_dataset))[: self.max_num_samples],
                )
                for dataset_name, test_dataset in test_datasets.items()
            }
        test_dataloaders = {
            dataset_name: torch.utils.data.DataLoader(
                test_dataset,
                collate_fn=functools.partial(
                    bradley_terry_rm_collate,
                    pad_token_id=tokenizer.pad_token_id,
                ),
                **self.dataloader_kwargs,
            )
            for dataset_name, test_dataset in test_datasets.items()
        }

        self.test_dataloaders = {
            dataset_name: self.fabric.setup_dataloaders(test_dataloader)
            for dataset_name, test_dataloader in test_dataloaders.items()
        }
        self._is_setup = True

    @torch.no_grad()
    def evaluate(self, model: "LlamaForSequenceClassification"):
        """
        Evaluate ``model`` on every test dataset.

        The tokenizer and dataloaders are built on the first call only. Repeated
        calls (for example one per model when several merged models are
        evaluated) therefore use the same random subsets instead of drawing a
        new subset and re-instantiating every dataset each time.

        Args:
            model: The reward model to evaluate.

        Returns:
            Dict[str, Dict[str, Any]]: Per-dataset metrics from
            ``evaluate_dataloader``.
        """
        if not getattr(self, "_is_setup", False):
            self.setup()

        model = self.fabric.setup_module(model)
        if model.config.pad_token_id is None:
            # the model needs a pad token id to handle the padded batches
            model.config.pad_token_id = self.tokenizer.pad_token_id

        model.eval()
        report = {}
        for dataset_name, test_dataloader in self.test_dataloaders.items():
            report[dataset_name] = evaluate_dataloader(model, test_dataloader)

        print(report)
        return report
|