fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +3 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +64 -11
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +12 -1
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +12 -5
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +14 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
fusion_bench/tasks/clip_classification/__init__.py
CHANGED
```diff
@@ -58,11 +58,24 @@ class CLIPTemplateFactory:
             "templates": "templates",
         },
         "nateraw/rendered-sst2": ".rendered_sst2",
+        "rendered-sst2": ".rendered_sst2",
         "tanganke/stl10": ".stl10",
+        "stl10": ".stl10",
         "dpdl-benchmark/oxford_flowers102": ".flower102",
+        "oxford_flowers102": ".flower102",
         "timm/oxford-iiit-pet": ".oxford_iiit_pet",
+        "oxford-iiit-pet": ".oxford_iiit_pet",
         "imagenet": ".imagenet",
         "tiny-imagenet": ".tiny_imagenet",
+        "pcam": ".pcam",
+        "fer2013": ".fer2013",
+        "emnist_mnist": ".emnist_mnist",
+        "emnist_letters": ".emnist_letters",
+        "kmnist": ".kmnist",
+        "food101": ".food101",
+        "fashion_mnist": ".fashion_mnist",
+        "cub-200-2011": ".cub_200_2011",
+        "mango-leaf-disease": ".mango_leaf_disease",
     }
 
     @staticmethod
@@ -168,48 +181,3 @@ class CLIPTemplateFactory:
 
 def get_classnames_and_templates(dataset_name: str):
     return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
-
-
-def _load_hf_dataset(dataset_name: str):
-    """
-    Load a dataset from the Hugging Face datasets library based on the specified dataset name.
-
-    This function handles specific preprocessing steps for certain datasets to ensure consistency in dataset format.
-    For example, it renames columns, removes unnecessary columns, and specifies subsets for certain datasets.
-
-    Expected dataset format:
-    - The dataset should have an "image" column containing the image data.
-    - The dataset should have a "label" column containing the class labels.
-
-    Args:
-        dataset_name (str): The name of the dataset to load. Can be one of "svhn", "cifar10", "cifar100", "timm/oxford-iiit-pet", or any other dataset name supported by the Hugging Face datasets library. By default, the datasets have two columns: "image" and "label".
-
-    Returns:
-        A dataset object loaded from the Hugging Face datasets library, with any necessary preprocessing applied.
-    """
-    if dataset_name == "svhn":
-        return load_dataset(dataset_name, "cropped_digits")
-    elif dataset_name == "cifar10":
-        dataset = load_dataset(dataset_name)
-        dataset = dataset.rename_columns({"img": "image"})
-        return dataset
-    elif dataset_name == "cifar100":
-        dataset = load_dataset(dataset_name)
-        dataset = dataset.remove_columns(["coarse_label"]).rename_columns(
-            {"img": "image", "fine_label": "label"}
-        )
-        return dataset
-    elif dataset_name == "timm/oxford-iiit-pet":
-        dataset = load_dataset(dataset_name)
-        dataset = dataset.remove_columns(["image_id", "label_cat_dog"])
-        return dataset
-    else:
-        return load_dataset(dataset_name)
-
-
-def load_clip_dataset(dataset: str, processor):
-    hf_dataset = _load_hf_dataset(dataset)
-    return (
-        CLIPDataset(hf_dataset["train"], processor),
-        CLIPDataset(hf_dataset["test"], processor),
-    )
```
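With these additions, short dataset keys such as `"kmnist"` or `"cub-200-2011"` resolve to the new per-dataset modules in addition to the fully qualified Hugging Face ids. A minimal usage sketch, assuming `get_classnames_and_templates` returns the `(classnames, templates)` pair defined in each dataset module:

```python
from fusion_bench.tasks.clip_classification import get_classnames_and_templates

# Short keys added in 0.2.8 resolve to the new dataset modules.
classnames, templates = get_classnames_and_templates("kmnist")
print(len(classnames))              # 10 kana characters
print(templates[0](classnames[0]))  # 'a photo of the character お.'
```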
fusion_bench/tasks/clip_classification/clip_dataset.py
CHANGED
```diff
@@ -1,16 +1 @@
-import torch
-
-
-class CLIPDataset(torch.utils.data.Dataset):
-    def __init__(self, dataset, processor):
-        self.dataset = dataset
-        self.processor = processor
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, idx):
-        item = self.dataset[idx]
-        image = item["image"]
-        inputs = self.processor(images=[image], return_tensors="pt")["pixel_values"][0]
-        return inputs, item["label"]
+from fusion_bench.dataset.clip_dataset import CLIPDataset
```
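The implementation now lives in `fusion_bench.dataset.clip_dataset`, and the old module only re-exports it, so existing imports keep working. A small sketch of how the wrapper is used; the Hugging Face dataset id and split below are illustrative assumptions:

```python
from datasets import load_dataset
from transformers import CLIPProcessor

from fusion_bench.dataset.clip_dataset import CLIPDataset

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Any HF dataset with "image" and "label" columns; the split name is assumed here.
hf_train = load_dataset("tanganke/stl10", split="train")

train_data = CLIPDataset(hf_train, processor)
pixel_values, label = train_data[0]  # preprocessed image tensor and integer label
```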
fusion_bench/tasks/clip_classification/cub_200_2011.py
ADDED
```diff
@@ -0,0 +1,208 @@
+classname_mapping = {
+    "0": "Black_footed_Albatross",
+    "1": "Laysan_Albatross",
+    "2": "Sooty_Albatross",
+    "3": "Groove_billed_Ani",
+    "4": "Crested_Auklet",
+    "5": "Least_Auklet",
+    "6": "Parakeet_Auklet",
+    "7": "Rhinoceros_Auklet",
+    "8": "Brewer_Blackbird",
+    "9": "Red_winged_Blackbird",
+    "10": "Rusty_Blackbird",
+    "11": "Yellow_headed_Blackbird",
+    "12": "Bobolink",
+    "13": "Indigo_Bunting",
+    "14": "Lazuli_Bunting",
+    "15": "Painted_Bunting",
+    "16": "Cardinal",
+    "17": "Spotted_Catbird",
+    "18": "Gray_Catbird",
+    "19": "Yellow_breasted_Chat",
+    "20": "Eastern_Towhee",
+    "21": "Chuck_will_Widow",
+    "22": "Brandt_Cormorant",
+    "23": "Red_faced_Cormorant",
+    "24": "Pelagic_Cormorant",
+    "25": "Bronzed_Cowbird",
+    "26": "Shiny_Cowbird",
+    "27": "Brown_Creeper",
+    "28": "American_Crow",
+    "29": "Fish_Crow",
+    "30": "Black_billed_Cuckoo",
+    "31": "Mangrove_Cuckoo",
+    "32": "Yellow_billed_Cuckoo",
+    "33": "Gray_crowned_Rosy_Finch",
+    "34": "Purple_Finch",
+    "35": "Northern_Flicker",
+    "36": "Acadian_Flycatcher",
+    "37": "Great_Crested_Flycatcher",
+    "38": "Least_Flycatcher",
+    "39": "Olive_sided_Flycatcher",
+    "40": "Scissor_tailed_Flycatcher",
+    "41": "Vermilion_Flycatcher",
+    "42": "Yellow_bellied_Flycatcher",
+    "43": "Frigatebird",
+    "44": "Northern_Fulmar",
+    "45": "Gadwall",
+    "46": "American_Goldfinch",
+    "47": "European_Goldfinch",
+    "48": "Boat_tailed_Grackle",
+    "49": "Eared_Grebe",
+    "50": "Horned_Grebe",
+    "51": "Pied_billed_Grebe",
+    "52": "Western_Grebe",
+    "53": "Blue_Grosbeak",
+    "54": "Evening_Grosbeak",
+    "55": "Pine_Grosbeak",
+    "56": "Rose_breasted_Grosbeak",
+    "57": "Pigeon_Guillemot",
+    "58": "California_Gull",
+    "59": "Glaucous_winged_Gull",
+    "60": "Heermann_Gull",
+    "61": "Herring_Gull",
+    "62": "Ivory_Gull",
+    "63": "Ring_billed_Gull",
+    "64": "Slaty_backed_Gull",
+    "65": "Western_Gull",
+    "66": "Anna_Hummingbird",
+    "67": "Ruby_throated_Hummingbird",
+    "68": "Rufous_Hummingbird",
+    "69": "Green_Violetear",
+    "70": "Long_tailed_Jaeger",
+    "71": "Pomarine_Jaeger",
+    "72": "Blue_Jay",
+    "73": "Florida_Jay",
+    "74": "Green_Jay",
+    "75": "Dark_eyed_Junco",
+    "76": "Tropical_Kingbird",
+    "77": "Gray_Kingbird",
+    "78": "Belted_Kingfisher",
+    "79": "Green_Kingfisher",
+    "80": "Pied_Kingfisher",
+    "81": "Ringed_Kingfisher",
+    "82": "White_breasted_Kingfisher",
+    "83": "Red_legged_Kittiwake",
+    "84": "Horned_Lark",
+    "85": "Pacific_Loon",
+    "86": "Mallard",
+    "87": "Western_Meadowlark",
+    "88": "Hooded_Merganser",
+    "89": "Red_breasted_Merganser",
+    "90": "Mockingbird",
+    "91": "Nighthawk",
+    "92": "Clark_Nutcracker",
+    "93": "White_breasted_Nuthatch",
+    "94": "Baltimore_Oriole",
+    "95": "Hooded_Oriole",
+    "96": "Orchard_Oriole",
+    "97": "Scott_Oriole",
+    "98": "Ovenbird",
+    "99": "Brown_Pelican",
+    "100": "White_Pelican",
+    "101": "Western_Wood_Pewee",
+    "102": "Sayornis",
+    "103": "American_Pipit",
+    "104": "Whip_poor_Will",
+    "105": "Horned_Puffin",
+    "106": "Common_Raven",
+    "107": "White_necked_Raven",
+    "108": "American_Redstart",
+    "109": "Geococcyx",
+    "110": "Loggerhead_Shrike",
+    "111": "Great_Grey_Shrike",
+    "112": "Baird_Sparrow",
+    "113": "Black_throated_Sparrow",
+    "114": "Brewer_Sparrow",
+    "115": "Chipping_Sparrow",
+    "116": "Clay_colored_Sparrow",
+    "117": "House_Sparrow",
+    "118": "Field_Sparrow",
+    "119": "Fox_Sparrow",
+    "120": "Grasshopper_Sparrow",
+    "121": "Harris_Sparrow",
+    "122": "Henslow_Sparrow",
+    "123": "Le_Conte_Sparrow",
+    "124": "Lincoln_Sparrow",
+    "125": "Nelson_Sharp_tailed_Sparrow",
+    "126": "Savannah_Sparrow",
+    "127": "Seaside_Sparrow",
+    "128": "Song_Sparrow",
+    "129": "Tree_Sparrow",
+    "130": "Vesper_Sparrow",
+    "131": "White_crowned_Sparrow",
+    "132": "White_throated_Sparrow",
+    "133": "Cape_Glossy_Starling",
+    "134": "Bank_Swallow",
+    "135": "Barn_Swallow",
+    "136": "Cliff_Swallow",
+    "137": "Tree_Swallow",
+    "138": "Scarlet_Tanager",
+    "139": "Summer_Tanager",
+    "140": "Artic_Tern",
+    "141": "Black_Tern",
+    "142": "Caspian_Tern",
+    "143": "Common_Tern",
+    "144": "Elegant_Tern",
+    "145": "Forsters_Tern",
+    "146": "Least_Tern",
+    "147": "Green_tailed_Towhee",
+    "148": "Brown_Thrasher",
+    "149": "Sage_Thrasher",
+    "150": "Black_capped_Vireo",
+    "151": "Blue_headed_Vireo",
+    "152": "Philadelphia_Vireo",
+    "153": "Red_eyed_Vireo",
+    "154": "Warbling_Vireo",
+    "155": "White_eyed_Vireo",
+    "156": "Yellow_throated_Vireo",
+    "157": "Bay_breasted_Warbler",
+    "158": "Black_and_white_Warbler",
+    "159": "Black_throated_Blue_Warbler",
+    "160": "Blue_winged_Warbler",
+    "161": "Canada_Warbler",
+    "162": "Cape_May_Warbler",
+    "163": "Cerulean_Warbler",
+    "164": "Chestnut_sided_Warbler",
+    "165": "Golden_winged_Warbler",
+    "166": "Hooded_Warbler",
+    "167": "Kentucky_Warbler",
+    "168": "Magnolia_Warbler",
+    "169": "Mourning_Warbler",
+    "170": "Myrtle_Warbler",
+    "171": "Nashville_Warbler",
+    "172": "Orange_crowned_Warbler",
+    "173": "Palm_Warbler",
+    "174": "Pine_Warbler",
+    "175": "Prairie_Warbler",
+    "176": "Prothonotary_Warbler",
+    "177": "Swainson_Warbler",
+    "178": "Tennessee_Warbler",
+    "179": "Wilson_Warbler",
+    "180": "Worm_eating_Warbler",
+    "181": "Yellow_Warbler",
+    "182": "Northern_Waterthrush",
+    "183": "Louisiana_Waterthrush",
+    "184": "Bohemian_Waxwing",
+    "185": "Cedar_Waxwing",
+    "186": "American_Three_toed_Woodpecker",
+    "187": "Pileated_Woodpecker",
+    "188": "Red_bellied_Woodpecker",
+    "189": "Red_cockaded_Woodpecker",
+    "190": "Red_headed_Woodpecker",
+    "191": "Downy_Woodpecker",
+    "192": "Bewick_Wren",
+    "193": "Cactus_Wren",
+    "194": "Carolina_Wren",
+    "195": "House_Wren",
+    "196": "Marsh_Wren",
+    "197": "Rock_Wren",
+    "198": "Winter_Wren",
+    "199": "Common_Yellowthroat",
+}
+
+classnames = [classname_mapping[str(i)] for i in range(200)]
+templates = [
+    lambda c: f"a photo of a {c}.",
+    lambda c: f"a photo of the {c}.",
+]
```
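Each of the new dataset modules added below follows the same shape: a `classnames` list plus a list of `templates` callables. The following is an illustrative sketch of how such pairs are typically turned into zero-shot text prompts, not a literal excerpt from the task pool code:

```python
from fusion_bench.tasks.clip_classification.cub_200_2011 import classnames, templates

# Cross every class name with every prompt template.
prompts = [template(name) for name in classnames for template in templates]
print(prompts[0])    # 'a photo of a Black_footed_Albatross.'
print(len(prompts))  # 200 classes x 2 templates = 400 prompts
```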
fusion_bench/tasks/clip_classification/emnist_letters.py
ADDED
```diff
@@ -0,0 +1,31 @@
+classnames_mapping = {
+    "0": "A",
+    "1": "B",
+    "2": "C",
+    "3": "D",
+    "4": "E",
+    "5": "F",
+    "6": "G",
+    "7": "H",
+    "8": "I",
+    "9": "J",
+    "10": "K",
+    "11": "L",
+    "12": "M",
+    "13": "N",
+    "14": "O",
+    "15": "P",
+    "16": "Q",
+    "17": "R",
+    "18": "S",
+    "19": "T",
+    "20": "U",
+    "21": "V",
+    "22": "W",
+    "23": "X",
+    "24": "Y",
+    "25": "Z",
+}
+
+classnames = [classnames_mapping[str(i)] for i in range(26)]
+templates = [lambda c: f'a photo of the digit character: "{c}".']
```
fusion_bench/tasks/clip_classification/fashion_mnist.py
ADDED
```diff
@@ -0,0 +1,18 @@
+classname_mapping = {
+    "0": "T - shirt / top",
+    "1": "Trouser",
+    "2": "Pullover",
+    "3": "Dress",
+    "4": "Coat",
+    "5": "Sandal",
+    "6": "Shirt",
+    "7": "Sneaker",
+    "8": "Bag",
+    "9": "Ankle boot",
+}
+classnames = [classname_mapping[str(i)] for i in range(10)]
+
+templates = [
+    lambda c: f"a photo of a {c}.",
+    lambda c: f"a photo of the {c}.",
+]
```
fusion_bench/tasks/clip_classification/fer2013.py
ADDED
```diff
@@ -0,0 +1,18 @@
+classnames = [
+    "angry",
+    "disgusted",
+    "fearful",
+    "happy",
+    "neutral",
+    "sad",
+    "surprised",
+]
+
+templates = [
+    lambda c: f"a photo of a {c} looking face.",
+    lambda c: f"a photo of a face showing the emotion: {c}.",
+    lambda c: f"a photo of a face looking {c}.",
+    lambda c: f"a face that looks {c}.",
+    lambda c: f"they look {c}.",
+    lambda c: f"look at how {c} they are.",
+]
```
fusion_bench/tasks/clip_classification/food101.py
ADDED
```diff
@@ -0,0 +1,105 @@
+classnames = [
+    "apple pie",
+    "baby back ribs",
+    "baklava",
+    "beef carpaccio",
+    "beef tartare",
+    "beet salad",
+    "beignets",
+    "bibimbap",
+    "bread pudding",
+    "breakfast burrito",
+    "bruschetta",
+    "caesar salad",
+    "cannoli",
+    "caprese salad",
+    "carrot cake",
+    "ceviche",
+    "cheese plate",
+    "cheesecake",
+    "chicken curry",
+    "chicken quesadilla",
+    "chicken wings",
+    "chocolate cake",
+    "chocolate mousse",
+    "churros",
+    "clam chowder",
+    "club sandwich",
+    "crab cakes",
+    "creme brulee",
+    "croque madame",
+    "cup cakes",
+    "deviled eggs",
+    "donuts",
+    "dumplings",
+    "edamame",
+    "eggs benedict",
+    "escargots",
+    "falafel",
+    "filet mignon",
+    "fish and chips",
+    "foie gras",
+    "french fries",
+    "french onion soup",
+    "french toast",
+    "fried calamari",
+    "fried rice",
+    "frozen yogurt",
+    "garlic bread",
+    "gnocchi",
+    "greek salad",
+    "grilled cheese sandwich",
+    "grilled salmon",
+    "guacamole",
+    "gyoza",
+    "hamburger",
+    "hot and sour soup",
+    "hot dog",
+    "huevos rancheros",
+    "hummus",
+    "ice cream",
+    "lasagna",
+    "lobster bisque",
+    "lobster roll sandwich",
+    "macaroni and cheese",
+    "macarons",
+    "miso soup",
+    "mussels",
+    "nachos",
+    "omelette",
+    "onion rings",
+    "oysters",
+    "pad thai",
+    "paella",
+    "pancakes",
+    "panna cotta",
+    "peking duck",
+    "pho",
+    "pizza",
+    "pork chop",
+    "poutine",
+    "prime rib",
+    "pulled pork sandwich",
+    "ramen",
+    "ravioli",
+    "red velvet cake",
+    "risotto",
+    "samosa",
+    "sashimi",
+    "scallops",
+    "seaweed salad",
+    "shrimp and grits",
+    "spaghetti bolognese",
+    "spaghetti carbonara",
+    "spring rolls",
+    "steak",
+    "strawberry shortcake",
+    "sushi",
+    "tacos",
+    "takoyaki",
+    "tiramisu",
+    "tuna tartare",
+    "waffles",
+]
+
+templates = [lambda c: f"a photo of {c}, a type of food."]
```
fusion_bench/tasks/clip_classification/kmnist.py
ADDED
```diff
@@ -0,0 +1,17 @@
+classnames_mapping = {
+    "0": "お",
+    "1": "き",
+    "2": "す",
+    "3": "つ",
+    "4": "な",
+    "5": "は",
+    "6": "ま",
+    "7": "や",
+    "8": "れ",
+    "9": "を",
+}
+classnames = [classnames_mapping[str(c)] for c in range(10)]
+
+templates = [
+    lambda c: f"a photo of the character {c}.",
+]
```
fusion_bench/tasks/clip_classification/mongo_leaf_disease.py
ADDED
```diff
@@ -0,0 +1,19 @@
+classnames = [
+    "Anthracnose",
+    "Bacterial Canker",
+    "Cutting Weevil",
+    "Die Back",
+    "Gall Midge",
+    "Healthy",
+    "Powdery Mildew",
+    "Sooty Mould",
+]
+
+templates = [
+    lambda c: f"a photo of a mango leaf with {c}.",
+    lambda c: f"a mango leaf showing symptoms of {c}.",
+    lambda c: f"a close-up photo of {c} on a mango leaf.",
+    lambda c: f"this mango leaf is affected by {c}.",
+    lambda c: f"a mango leaf disease identified as {c}.",
+    lambda c: f"a {c} infection on a mango leaf.",
+]
```
fusion_bench/utils/hydra_utils.py
CHANGED
```diff
@@ -4,3 +4,25 @@ import hydra.core.hydra_config
 def get_hydra_output_dir():
     hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
     return hydra_cfg.runtime.output_dir
+
+
+def config_priority_get(priority_config, general_config, key, default):
+    """
+    Retrieve a configuration value with priority.
+
+    This function retrieves the value associated with `key` from `priority_config` if it exists.
+    If the key is not found in `priority_config`, it retrieves the value from `general_config`.
+    If the key is not found in either configuration, it returns the provided `default` value.
+
+    Args:
+        priority_config (dict): The configuration dictionary with higher priority.
+        general_config (dict): The general configuration dictionary.
+        key (str): The key to look up in the configuration dictionaries.
+        default: The default value to return if the key is not found in either configuration.
+
+    Returns:
+        The value associated with `key` from `priority_config` or `general_config`, or the `default` value if the key is not found.
+    """
+    if key in priority_config:
+        return priority_config[key]
+    return general_config.get(key, default)
```
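A quick sketch of the lookup order implemented above: the priority config wins, then the general config, then the caller-supplied default. The config dicts here are made up for illustration:

```python
from fusion_bench.utils.hydra_utils import config_priority_get

method_cfg = {"lr": 1e-4}              # higher-priority (e.g. per-method) options
global_cfg = {"lr": 1e-3, "seed": 42}  # general options

config_priority_get(method_cfg, global_cfg, "lr", default=None)    # -> 1e-4
config_priority_get(method_cfg, global_cfg, "seed", default=0)     # -> 42
config_priority_get(method_cfg, global_cfg, "epochs", default=10)  # -> 10
```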
fusion_bench/utils/parameters.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import copy
 from collections import OrderedDict
-from typing import List, Mapping, Union
+from typing import List, Mapping, Optional, Union
 
 import torch
 from torch import nn
@@ -43,7 +43,10 @@ def trainable_state_dict(
     return state_dict
 
 
-def state_dict_to_vector(state_dict, remove_keys=[]):
+def state_dict_to_vector(
+    state_dict: StateDictType,
+    remove_keys: Optional[List[str]] = None,
+):
     """
     Convert a state dictionary to a vector.
 
@@ -54,6 +57,7 @@ def state_dict_to_vector(state_dict, remove_keys=[])
     Returns:
         torch.Tensor: The converted vector.
     """
+    remove_keys = remove_keys if remove_keys is not None else []
     shared_state_dict = copy.deepcopy(state_dict)
     for key in remove_keys:
         if key in shared_state_dict:
@@ -64,7 +68,11 @@ def state_dict_to_vector(state_dict, remove_keys=[])
     )
 
 
-def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+def vector_to_state_dict(
+    vector: torch.Tensor,
+    state_dict: StateDictType,
+    remove_keys: Optional[List[str]] = None,
+):
     """
     Convert a vector to a state dictionary.
 
@@ -76,6 +84,7 @@ def vector_to_state_dict(vector, state_dict, remove_keys=[])
     Returns:
         dict: The converted state dictionary.
     """
+    remove_keys = remove_keys if remove_keys is not None else []
     # create a reference dict to define the order of the vector
     reference_dict = copy.deepcopy(state_dict)
     for key in remove_keys:
```
fusion_bench/utils/plot/__init__.py
File without changes
fusion_bench/utils/plot/token.py
ADDED
```diff
@@ -0,0 +1,52 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+
+
+def visualize_model_inputs(input_ids, attention_mask, labels, tokenizer=None):
+    """
+    Visualize model inputs: attention mask, labels and input_ids
+
+    Parameters:
+    -----------
+    attention_mask: numpy array or tensor
+        The attention mask array
+    labels: numpy array or tensor
+        The labels array
+    input_ids: numpy array or tensor
+        The input ids array
+    tokenizer: optional
+        The tokenizer object to decode input_ids
+    """
+
+    # Convert inputs to numpy if they're tensors
+    attention_mask = np.array(attention_mask)
+    labels = np.array(labels)
+    input_ids = np.array(input_ids)
+
+    # Create figure with 3 subplots
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 10))
+
+    # Plot attention mask
+    sns.heatmap(attention_mask.reshape(1, -1), ax=ax1, cmap="Blues", cbar=True)
+    ax1.set_title("**Attention Mask**")
+    ax1.set_ylabel("Sequence")
+
+    # Plot labels
+    sns.heatmap(labels.reshape(1, -1), ax=ax2, cmap="Reds", cbar=True)
+    ax2.set_title("**Labels**")
+    ax2.set_ylabel("Sequence")
+
+    # Plot input_ids
+    sns.heatmap(input_ids.reshape(1, -1), ax=ax3, cmap="Greens", cbar=True)
+    ax3.set_title("**Input IDs**")
+    ax3.set_ylabel("Sequence")
+
+    # If tokenizer is provided, add decoded tokens as x-axis labels
+    if tokenizer:
+        decoded_tokens = [tokenizer.decode(token_id) for token_id in input_ids]
+        ax3.set_xticks(np.arange(len(decoded_tokens)) + 0.5)
+        ax3.set_xticklabels(decoded_tokens, rotation=45, ha="right")
+
+    plt.tight_layout()
+    return fig
```