fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from omegaconf import DictConfig, open_dict
|
|
6
|
+
from transformers import CLIPProcessor, CLIPVisionModel
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from fusion_bench.dataset import CLIPDataset, load_dataset_from_config
|
|
10
|
+
from fusion_bench.utils import timeit_context
|
|
11
|
+
|
|
12
|
+
from .base_pool import ModelPool
|
|
13
|
+
|
|
14
|
+
# Module-level logger; the model pool below uses it to warn when no `_pretrained_` model is available.
log = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class HuggingFaceClipVisionPool(ModelPool):
    """
    A model pool for managing Hugging Face's CLIP Vision models.

    This class extends the base `ModelPool` class and overrides its methods to handle
    the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
    """

    def __init__(self, modelpool_config: DictConfig):
        """
        Args:
            modelpool_config (DictConfig): Configuration forwarded to `ModelPool`.
        """
        super().__init__(modelpool_config)

        # Created lazily by the `clip_processor` property.
        self._clip_processor = None

    @property
    def clip_processor(self):
        """
        Returns the CLIP processor, creating it on first access.

        The processor is loaded from the path of the `_pretrained_` model when the
        pool contains one; otherwise a warning is logged and the path of the first
        model in the pool is used instead.
        """
        if self._clip_processor is None:
            if "_pretrained_" in self._model_names:
                self._clip_processor = CLIPProcessor.from_pretrained(
                    self.get_model_config("_pretrained_")["path"]
                )
            else:
                log.warning(
                    "No pretrained model found in the model pool. Returning the first model."
                )
                self._clip_processor = CLIPProcessor.from_pretrained(
                    self.get_model_config(self.model_names[0])["path"]
                )
        return self._clip_processor

    @override
    def load_model(self, model_config: str | DictConfig) -> CLIPVisionModel:
        """
        Load a CLIP Vision model from the given configuration.

        Args:
            model_config (str | DictConfig): The configuration for the model to load,
                or the name of a model in the pool (resolved via `get_model_config`).

        Returns:
            CLIPVisionModel: The loaded CLIP Vision model.
        """
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)

        with timeit_context(
            f"Loading CLIP vision model: '{model_config.name}' from '{model_config.path}'."
        ):
            vision_model = CLIPVisionModel.from_pretrained(model_config.path)
        return vision_model

    @override
    def save_model(self, model: CLIPVisionModel, path: str):
        """
        Save a CLIP Vision model to the given path.

        Args:
            model (CLIPVisionModel): The model to save.
            path (str): The path to save the model to.
        """
        with timeit_context(f'Saving clip vision model to "{path}"'):
            model.save_pretrained(path)

    def get_tta_dataset_config(self, dataset: str):
        """
        Retrieve the configuration for a TTA (Test-Time Adaptation) dataset.

        Args:
            dataset (str): The name of the dataset for which to retrieve the configuration.

        Returns:
            DictConfig: The first entry of `config.tta_datasets` whose `name` matches.

        Raises:
            ValueError: If the specified dataset is not found in the configuration.
        """
        for dataset_config in self.config.tta_datasets:
            if dataset_config.name == dataset:
                return dataset_config
        raise ValueError(f"Dataset {dataset} not found in config")

    def prepare_dataset_config(self, dataset_config: DictConfig):
        """
        Prepare the dataset configuration by setting the dataset type if it's not already set.

        The config is modified in place (the default comes from `config.dataset_type`)
        and also returned for convenience.

        Args:
            dataset_config (DictConfig): The configuration dictionary for the dataset.

        Returns:
            DictConfig: The updated configuration dictionary for the dataset.
        """
        if not hasattr(dataset_config, "type"):
            # `open_dict` allows adding a new key even if the config is in struct mode.
            with open_dict(dataset_config):
                dataset_config["type"] = self.config.dataset_type
        return dataset_config

    def get_tta_test_dataset(
        self, tta_dataset: str, clip_processor: Optional[CLIPProcessor] = None
    ):
        """
        Load the test dataset for the task.

        Results are cached per pool instance, keyed by ``(tta_dataset, clip_processor)``
        (after resolving a ``None`` processor to the default one), so each combination
        is loaded only once. The processor must therefore be hashable.

        Args:
            tta_dataset (str): The name of the TTA dataset to load.
            clip_processor (Optional[CLIPProcessor]): The CLIP processor to use for preprocessing the dataset. If None, the default processor is used.

        Returns:
            CLIPDataset: The loaded and preprocessed TTA test dataset.

        Raises:
            ValueError: If `tta_dataset` is not present in `config.tta_datasets`.
        """
        if clip_processor is None:
            # if clip_processor is not provided, try to load the clip_processor from pre-trained model
            clip_processor = self.clip_processor

        # Per-instance cache, created lazily. A method-level `functools.cache`
        # would key on `self` and keep every pool instance alive forever.
        cache = getattr(self, "_tta_test_dataset_cache", None)
        if cache is None:
            cache = self._tta_test_dataset_cache = {}
        key = (tta_dataset, clip_processor)
        if key in cache:
            return cache[key]

        dataset_config = self.get_tta_dataset_config(tta_dataset)["dataset"]
        dataset_config = self.prepare_dataset_config(dataset_config)
        with timeit_context(f"Loading test dataset: {dataset_config.name}"):
            dataset = load_dataset_from_config(dataset_config)
            # Use the resolved processor; the caller-supplied one must not be ignored.
            dataset = CLIPDataset(dataset, clip_processor)
        cache[key] = dataset
        return dataset

    def get_train_dataset_config(self, model_name: str):
        """
        Retrieve the configuration for a specific training dataset.

        Args:
            model_name (str): The name of the model for which to retrieve the training dataset configuration.

        Returns:
            DictConfig: The first entry of `config.train_datasets` whose `name` matches.

        Raises:
            ValueError: If the specified training dataset is not found in the configuration.
        """
        for dataset_config in self.config.train_datasets:
            if dataset_config.name == model_name:
                return dataset_config
        raise ValueError(f"Dataset {model_name} not found in config")

    def get_train_dataset(
        self, model_name: str, clip_processor: Optional[CLIPProcessor] = None
    ):
        """
        Load the training dataset for the specified model.

        Args:
            model_name (str): The name of the model for which to load the training dataset.
            clip_processor (Optional[CLIPProcessor]): The CLIP processor to use for preprocessing the dataset. If None, the default processor is used.

        Returns:
            CLIPDataset: The loaded and preprocessed training dataset.

        Raises:
            ValueError: If `model_name` is not present in `config.train_datasets`.
        """
        if clip_processor is None:
            # if clip_processor is not provided, try to load the clip_processor from pre-trained model
            clip_processor = self.clip_processor
        dataset_config = self.get_train_dataset_config(model_name)["dataset"]
        dataset_config = self.prepare_dataset_config(dataset_config)
        with timeit_context(f"Loading train dataset: {dataset_config.name}"):
            dataset = load_dataset_from_config(dataset_config)
            # Use the resolved processor; the caller-supplied one must not be ignored.
            dataset = CLIPDataset(dataset, clip_processor)
        return dataset
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# flake8: noqa F401
|
|
2
|
+
from omegaconf import DictConfig
|
|
3
|
+
|
|
4
|
+
from fusion_bench.taskpool.dummy import DummyTaskPool
|
|
5
|
+
|
|
6
|
+
from .base_pool import TaskPool
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TaskPoolFactory:
    """
    Deprecated (v0.1.x) factory that builds, registers and lists task pools.

    For implementing new task pool, use `fusion_bench.taskpool.BaseTaskPool` instead.

    `_taskpool_types` maps a type name either to a class or to an import-path
    string. Strings are imported lazily in `create_taskpool`; a leading "."
    makes the path relative to `fusion_bench.compat.taskpool`.
    """

    _taskpool_types = {
        "dummy": DummyTaskPool,
        "clip_vit_classification": ".clip_image_classification.CLIPImageClassificationTaskPool",
        "FlanT5GLUETextGenerationTaskPool": ".flan_t5_glue_text_generation.FlanT5GLUETextGenerationTaskPool",
        "NYUv2TaskPool": "fusion_bench.taskpool.nyuv2_taskpool.NYUv2TaskPool",
    }

    @staticmethod
    def create_taskpool(taskpool_config: DictConfig):
        """
        Instantiate the task pool named by ``taskpool_config.type``.

        Args:
            taskpool_config (DictConfig): Task-pool configuration; its 'type' entry selects
                the registered task pool, and the whole config is passed to its constructor.

        Returns:
            TaskPool: The constructed task pool.

        Raises:
            ValueError: If 'type' is missing or is not a registered task pool name.
        """
        from fusion_bench.utils import import_object

        registry = TaskPoolFactory._taskpool_types
        kind = taskpool_config.get("type")
        if kind is None:
            raise ValueError("Task pool type not specified")
        if kind not in registry:
            raise ValueError(
                f"Unknown task pool: {kind}, available task pools: {registry.keys()}. You can register a new task pool using `TaskPoolFactory.register_taskpool()` method."
            )

        target = registry[kind]
        if isinstance(target, str):
            # Lazily import string entries; "." prefixes are package-relative.
            qualified = (
                "fusion_bench.compat.taskpool." + target[1:]
                if target.startswith(".")
                else target
            )
            target = import_object(qualified)
        return target(taskpool_config)

    @staticmethod
    def register_taskpool(name: str, taskpool_cls):
        """
        Add (or overwrite) a task pool entry in the factory's registry.

        Args:
            name (str): Type name used in configs' 'type' field.
            taskpool_cls: The task pool class, or an import-path string resolved lazily.
        """
        registry = TaskPoolFactory._taskpool_types
        registry[name] = taskpool_cls

    @classmethod
    def available_taskpools(cls):
        """
        List the registered task pool names.

        Returns:
            list: Registered names, in registration order.
        """
        return list(cls._taskpool_types)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def load_taskpool_from_config(taskpool_config: DictConfig):
    """
    Build a task pool from its configuration (thin wrapper over the factory).

    Delegates to `TaskPoolFactory.create_taskpool`, which looks up the
    configuration's 'type' entry in the factory registry.

    Args:
        taskpool_config (DictConfig): The configuration for the task pool. Must contain a 'type' attribute that specifies the type of the task pool.

    Returns:
        An instance of the specified task pool.

    Raises:
        ValueError: If 'type' attribute is not found in the configuration or does not match any known task pool types.
    """
    build = TaskPoolFactory.create_taskpool
    return build(taskpool_config)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from omegaconf import DictConfig
|
|
4
|
+
from tqdm.autonotebook import tqdm
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TaskPool:
    """
    A class to manage a pool of tasks for evaluation.
    This is the base class for version 0.1.x, deprecated.
    Use `fusion_bench.taskpool.BaseTaskPool` instead.

    Attributes:
        config (DictConfig): The configuration for the task pool.
        _all_task_names (List[str]): A list of all task names in the task pool
            (empty when the configuration has no 'tasks' entry).
    """

    # NOTE(review): not referenced in this class; presumably set by callers/subclasses — confirm.
    _program = None

    def __init__(self, taskpool_config: DictConfig):
        """
        Initialize the TaskPool with the given configuration.

        Args:
            taskpool_config (DictConfig): The configuration for the task pool.

        Raises:
            ValueError: If two entries in 'tasks' share the same name.
        """
        from collections import Counter

        super().__init__()
        self.config = taskpool_config

        tasks = self.config.get("tasks", None)
        if tasks is not None:
            task_names = [task["name"] for task in tasks]
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            duplicates = [name for name, count in Counter(task_names).items() if count > 1]
            if duplicates:
                raise ValueError(
                    f"Duplicate task names found in the task pool: {duplicates}"
                )
            self._all_task_names = task_names
        elif not hasattr(self, "_all_task_names"):
            # Without this, `task_names`/`evaluate` would raise AttributeError.
            # Names a subclass may have set already are left untouched.
            self._all_task_names = []

    def evaluate(self, model):
        """
        Evaluate the model on all tasks in the task pool, and return a report.

        Take image classification as an example, the report will look like:

        ```python
        {
            "mnist": {
                "accuracy": 0.8,
                "loss": 0.2,
            },
            <task_name>: {
                <metric_name>: <metric_value>,
                ...
            },
        }
        ```

        Args:
            model: The model to evaluate.

        Returns:
            report (dict): A dictionary containing the results of the evaluation for each task.
        """
        report = {}
        for task_name in tqdm(self.task_names, desc="Evaluating tasks"):
            task = self.load_task(task_name)
            result = task.evaluate(model)
            report[task_name] = result
        return report

    @property
    def task_names(self):
        """
        Return a list of all task names in the task pool.

        Returns:
            List[str]: A list of all task names.
        """
        return self._all_task_names

    def get_task_config(self, task_name: str):
        """
        Retrieve the configuration for a specific task from the task pool.

        Args:
            task_name (str): The name of the task for which to retrieve the configuration.

        Returns:
            DictConfig: The configuration dictionary for the specified task.

        Raises:
            ValueError: If the specified task is not found in the task pool
                (including when the configuration has no 'tasks' entry).
        """
        for task in self.config.get("tasks", None) or []:
            if task["name"] == task_name:
                return task
        raise ValueError(f"Task {task_name} not found in the task pool")

    def load_task(self, task_name_or_config: Union[str, DictConfig]):
        """
        Load a task from the task pool.

        Args:
            task_name_or_config (Union[str, DictConfig]): The name or configuration of the task to load.

        Returns:
            Any: The loaded task.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from functools import cached_property
|
|
7
|
+
from typing import Callable, List, cast
|
|
8
|
+
|
|
9
|
+
import lightning as L
|
|
10
|
+
import torch
|
|
11
|
+
from omegaconf import DictConfig, open_dict
|
|
12
|
+
from torch.utils.data import DataLoader
|
|
13
|
+
from tqdm.autonotebook import tqdm
|
|
14
|
+
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
|
|
15
|
+
|
|
16
|
+
from fusion_bench.dataset import CLIPDataset, load_dataset_from_config
|
|
17
|
+
from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
18
|
+
from fusion_bench.tasks.classification import ClassificationTask
|
|
19
|
+
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
20
|
+
from fusion_bench.utils.parameters import count_parameters
|
|
21
|
+
|
|
22
|
+
from .base_pool import TaskPool
|
|
23
|
+
|
|
24
|
+
# Disable HuggingFace `tokenizers` parallelism process-wide. This is presumably
# meant to avoid its fork-related warnings/deadlocks when DataLoader worker
# processes are spawned (see `num_workers` in the task config) -- confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Module-level logger used by the tasks and the task pool below.
log = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@functools.lru_cache(maxsize=None)
def load_dataset_from_config_cached(dataset_config: DictConfig):
    """
    Memoized wrapper around ``load_dataset_from_config``.

    Repeated calls with an equal (same-hash) dataset config return the dataset
    object loaded the first time. The cache is unbounded, so every loaded
    dataset stays alive for the lifetime of the process.

    NOTE(review): relies on ``DictConfig`` being hashable and hashing by content;
    mutating a config after it has been used as a key would defeat the cache.

    Args:
        dataset_config (DictConfig): The dataset configuration to load.

    Returns:
        Whatever ``load_dataset_from_config`` returns for this config.
    """
    dataset = load_dataset_from_config(dataset_config)
    return dataset
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CLIPImageClassificationTask(ClassificationTask):
    """
    Image-classification task evaluated through a CLIP model.

    Class names and prompt templates are resolved from the dataset name in the
    task config. The Fabric instance, the owning task pool and the CLIP
    processor are attached afterwards via the private attributes below; within
    this file, ``CLIPImageClassificationTaskPool.load_task`` does that.
    """

    # Optional Lightning Fabric; when set, the loader and classifier are set up with it.
    _fabric: L.Fabric = None
    # Processor shared with CLIPDataset and HFCLIPClassifier.
    _clip_processor: CLIPProcessor = None
    # Owning task pool; used to complete the dataset config before loading.
    _taskpool: "CLIPImageClassificationTaskPool" = None

    # Overwritten per instance in __init__.
    classnames: List[str] = []
    templates: List[Callable[[str], str]] = []

    def __init__(self, task_config: DictConfig):
        """
        Args:
            task_config (DictConfig): Per-task configuration. It must contain a
                ``dataset`` entry with a ``name`` attribute; ``name``,
                ``batch_size`` and ``num_workers`` are read later.
        """
        self.config = task_config
        dataset_name = self.config["dataset"].name
        self.classnames, self.templates = get_classnames_and_templates(dataset_name)

    @cached_property
    def test_dataset(self):
        """
        Test split for this task, wrapped in a ``CLIPDataset``.

        The dataset config is first completed by the owning task pool, then
        loaded through the process-wide memoized loader. The result is computed
        once per task instance.
        """
        prepared_config = self._taskpool.prepare_dataset_config(self.config["dataset"])
        log.info(f"Loading test dataset: {prepared_config.name}")
        raw_dataset = load_dataset_from_config_cached(prepared_config)
        return CLIPDataset(raw_dataset, self._clip_processor)

    @property
    def num_classes(self):
        """Number of class names resolved for this task's dataset."""
        return len(self.classnames)

    @property
    def test_loader(self):
        """
        Non-shuffled DataLoader over ``test_dataset``.

        ``batch_size`` and ``num_workers`` come from the task config. When a
        Fabric is attached, the loader is passed through
        ``Fabric.setup_dataloaders``. A new loader is built on every access.
        """
        plain_loader = DataLoader(
            self.test_dataset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=False,
        )
        fabric = self._fabric
        return plain_loader if fabric is None else fabric.setup_dataloaders(plain_loader)

    def evaluate(self, clip_model: CLIPModel):
        """
        Evaluate a CLIP model on this task through a zero-shot classifier head.

        Args:
            clip_model (CLIPModel): The CLIP model to evaluate.

        Returns:
            dict: Whatever ``ClassificationTask.evaluate`` returns for the
            wrapped classifier.
        """
        head = HFCLIPClassifier(clip_model=clip_model, processor=self._clip_processor)
        head.set_classification_task(self.classnames, self.templates)

        fabric = self._fabric
        if fabric is not None:
            # Fabric sets up a deep copy, so the caller's clip_model is left untouched.
            head = fabric.setup_module(deepcopy(head))

        results = super().evaluate(head)
        log.info(f"Results for task {self.config.name}: {results}")
        return results
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class CLIPImageClassificationTaskPool(TaskPool):
    """
    Task pool that evaluates a CLIP vision encoder on image-classification tasks.

    The vision model under test is plugged into a full ``CLIPModel`` loaded
    from ``config.clip_model``, and each configured task is evaluated with it.
    """

    # Optional Lightning Fabric; evaluate() creates one when CUDA is available
    # and none was attached beforehand.
    _fabric: L.Fabric = None

    # CLIP forward model and processor (lazily loaded from `config.clip_model`)
    _clip_model: CLIPModel = None
    _clip_processor: CLIPProcessor = None

    def __init__(self, taskpool_config: DictConfig):
        super().__init__(taskpool_config)

    def prepare_dataset_config(self, dataset_config: DictConfig):
        """
        Fill in the dataset ``type`` from ``config.dataset_type`` when it is missing.

        The config is modified in place (struct mode is temporarily lifted
        with ``open_dict``) and then returned.
        """
        if not hasattr(dataset_config, "type"):
            with open_dict(dataset_config):
                dataset_config["type"] = self.config.dataset_type
        return dataset_config

    def prepare_task_config(self, task_config: DictConfig):
        """
        Copy pool-level defaults into a task config that does not set them.

        For each of ``num_workers``, ``batch_size`` and ``fast_dev_run`` that is
        missing, the value is copied from the pool config. The task config is
        modified in place and then returned.
        """
        for key in ["num_workers", "batch_size", "fast_dev_run"]:
            if not hasattr(task_config, key):
                with open_dict(task_config):
                    task_config[key] = self.config[key]
        return task_config

    @property
    def clip_model(self):
        """``CLIPModel`` loaded from ``config["clip_model"]`` on first access."""
        if self._clip_model is None:
            self._clip_model = CLIPModel.from_pretrained(self.config["clip_model"])
        return self._clip_model

    @property
    def clip_processor(self):
        """``CLIPProcessor`` loaded from ``config["clip_model"]`` on first access."""
        if self._clip_processor is None:
            self._clip_processor = CLIPProcessor.from_pretrained(
                self.config["clip_model"]
            )
        return self._clip_processor

    def load_task(self, task_name_or_config: str | DictConfig):
        """
        Build a ``CLIPImageClassificationTask`` from a task name or task config.

        A name is first resolved with ``get_task_config``. Pool-level defaults
        are then filled in, and the pool's Fabric, the pool itself and the CLIP
        processor are attached to the task.
        """
        if isinstance(task_name_or_config, str):
            task_config = self.get_task_config(task_name_or_config)
        else:
            task_config = task_name_or_config
        task_config = self.prepare_task_config(task_config)

        # load the task from the configuration
        task = CLIPImageClassificationTask(task_config)
        task._fabric = self._fabric
        task._taskpool = self
        task._clip_processor = self.clip_processor

        return task

    @staticmethod
    def _add_average_metrics(report: dict) -> None:
        """Store the mean of all "accuracy"/"loss" entries of the report in ``report["average"]``."""
        average = report.setdefault("average", {})
        # Entries without the metric (e.g. "model_info") are skipped by the membership test.
        accuracies = [value["accuracy"] for value in report.values() if "accuracy" in value]
        if accuracies:
            average["accuracy"] = sum(accuracies) / len(accuracies)
        losses = [value["loss"] for value in report.values() if "loss" in value]
        if losses:
            average["loss"] = sum(losses) / len(losses)

    def _save_report(self, report: dict) -> None:
        """Write the report as ``report.json`` into the Fabric logger's log dir, if possible."""
        fabric = self._fabric
        # A Fabric is only auto-created when CUDA is available, so on CPU-only
        # runs there may be none; in that case there is nowhere to save.
        if fabric is None or not fabric.is_global_zero or len(fabric._loggers) == 0:
            return
        with open(os.path.join(fabric.logger.log_dir, "report.json"), "w") as fp:
            json.dump(report, fp)

    def evaluate(self, model: CLIPVisionModel, name=None):
        """
        Evaluate the model on the image classification task.

        Args:
            model (CLIPVisionModel): The vision model to evaluate. It replaces
                ``clip_model.vision_model`` on the pool's CLIP model.
            name (optional): Stored as ``report["model_info"]["name"]`` when given.

        Returns:
            dict: Contains a per-task result keyed by task name, plus
            ``"model_info"`` (parameter counts) and ``"average"`` (mean
            accuracy/loss across entries that report them).
        """
        # if the fabric is not set, and we have a GPU, create a fabric instance
        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=1)
            self._fabric.launch()

        # CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
        self.clip_model.vision_model = model
        report = {}
        training_params, all_params = count_parameters(model)
        report["model_info"] = {
            "trainable_params": training_params,
            "all_params": all_params,
            # Avoid ZeroDivisionError for a model without parameters.
            "trainable_percentage": (
                training_params / all_params if all_params > 0 else 0.0
            ),
        }
        if name is not None:
            report["model_info"]["name"] = name
        for task_name in tqdm(self.task_names, desc="Evaluating tasks"):
            task = self.load_task(task_name)
            result = task.evaluate(self.clip_model)
            report[task_name] = result

        # calculate the average accuracy and loss
        self._add_average_metrics(report)

        log.info(f"Results for taskpool {self.config.name}: {report}")
        self._save_report(report)
        return report
|