fusion-bench 0.2.1__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/PKG-INFO +11 -6
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/README.md +10 -5
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/method/__init__.py +33 -0
- fusion_bench-0.2.2/fusion_bench/compat/method/base_algorithm.py +50 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/modelpool/__init__.py +36 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/modelpool/base_pool.py +95 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/modelpool/huggingface_clip_vision.py +53 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/taskpool/__init__.py +35 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/taskpool/base_pool.py +47 -3
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +1 -6
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/constants/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/dataset/clip_dataset.py +3 -1
- fusion_bench-0.2.2/fusion_bench/dataset/gsm8k.py +57 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/__init__.py +21 -3
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/ada_svd/clip_vision.py +1 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/adamerging/clip_task_wise_adamerging.py +58 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/adamerging/entropy_loss.py +6 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/adamerging/layer_wise_adamerging.py +59 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/base_algorithm.py +10 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/classification/clip_finetune.py +54 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +40 -0
- fusion_bench-0.2.2/fusion_bench/method/dare/__init__.py +2 -0
- fusion_bench-0.2.2/fusion_bench/method/dare/task_arithmetic.py +68 -0
- fusion_bench-0.2.2/fusion_bench/method/dare/utils.py +87 -0
- fusion_bench-0.2.2/fusion_bench/method/dawe/warppers/__init__.py +12 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/depth_upscaling/depth_upscaling.py +12 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +24 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/dummy.py +8 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/fisher_merging/clip_fisher_merging.py +46 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/fisher_merging/fisher_merging.py +105 -46
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +48 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/linear/__init__.py +1 -0
- fusion_bench-0.2.2/fusion_bench/method/linear/linear_interpolation.py +60 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/mixture_of_experts/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/mixture_of_experts/mixtral_merging.py +12 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +61 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/llama_magnitude_prune.py +44 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/llama_random_prune.py +49 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/llama_wanda_prune.py +78 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/magnitude_diff_pruning.py +49 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/prune_utils.py +50 -1
- fusion_bench-0.2.2/fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/wanda_utils/data.py +2 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/wanda_utils/eval.py +61 -6
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/wanda_utils/layerwrapper.py +25 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/wanda_utils/prune.py +67 -3
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/wanda_utils/prune_opt.py +78 -4
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/wanda_utils/sparsegpt.py +38 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pwe_moe/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pwe_moe/module.py +2 -2
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pwe_moe/phn/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pwe_moe/phn/solvers.py +3 -3
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/regmean/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/simple_average.py +3 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/slerp/slerp.py +24 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/smile_upscaling/singular_projection_merging.py +50 -7
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +42 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/smile_upscaling/smile_upscaling.py +119 -0
- fusion_bench-0.2.2/fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench-0.2.2/fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench-0.2.2/fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/sparselo/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/task_arithmetic/task_arithmetic.py +27 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/ties_merging/ties_merging.py +33 -3
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/ties_merging/ties_merging_utils.py +125 -1
- fusion_bench-0.2.2/fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench-0.2.2/fusion_bench/method/trust_region/clip_task_arithmetic.py +196 -0
- fusion_bench-0.2.2/fusion_bench/method/trust_region/utils.py +58 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/we_moe/clip_we_moe.py +70 -55
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/we_moe/we_moe.py +69 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/weighted_average/llama.py +11 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/text_to_image_generation/__init__.py +1 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/mixins/clip_classification.py +1 -4
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/causal_lm/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/nyuv2_modelpool.py +1 -2
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/seq2seq_lm/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/seq2seq_lm/modelpool.py +2 -4
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/hf_clip.py +11 -2
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/masks/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_losparse_llama/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/nyuv2/lightning_module.py +1 -7
- fusion_bench-0.2.2/fusion_bench/models/sparse_we_moe.py +429 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/we_moe.py +1 -6
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/wrappers/layer_wise_fusion.py +1 -4
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/programs/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/programs/fabric_fusion_program.py +3 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/scripts/clip/convert_checkpoint.py +0 -2
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/taskpool/__init__.py +2 -2
- fusion_bench-0.2.2/fusion_bench/taskpool/clip_vision/__init__.py +3 -0
- fusion_bench-0.2.2/fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench-0.2.2/fusion_bench/taskpool/clip_vision/taskpool.py +331 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/taskpool/dummy.py +0 -4
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/taskpool/gpt2_text_classification.py +1 -2
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/__init__.py +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +1 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +9 -5
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/__init__.py +5 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/auto.py +1 -3
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/data.py +27 -6
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/instantiate.py +1 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/json.py +4 -7
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/parameters.py +21 -3
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench.egg-info/PKG-INFO +11 -6
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench.egg-info/SOURCES.txt +37 -1
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/dtd.yaml +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/eurosat.yaml +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/mnist.yaml +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/resisc45.yaml +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/sun397.yaml +1 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/svhn.yaml +1 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench-0.2.2/fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench-0.2.2/fusion_bench_config/fabric/auto.yaml +10 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench-0.2.2/fusion_bench_config/method/dare/task_arithmetic.yaml +5 -0
- fusion_bench-0.2.2/fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench-0.2.2/fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench-0.2.2/fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench-0.2.2/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench-0.2.2/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +7 -0
- fusion_bench-0.2.2/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench-0.2.1/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml → fusion_bench-0.2.2/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +16 -6
- fusion_bench-0.2.2/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/pyproject.toml +1 -1
- fusion_bench-0.2.1/fusion_bench/compat/method/base_algorithm.py +0 -29
- fusion_bench-0.2.1/fusion_bench/method/dawe/warppers/__init__.py +0 -1
- fusion_bench-0.2.1/fusion_bench/method/pruning/wanda_utils/__init__.py +0 -3
- fusion_bench-0.2.1/fusion_bench/taskpool/clip_vision/__init__.py +0 -1
- fusion_bench-0.2.1/fusion_bench/taskpool/clip_vision/taskpool.py +0 -196
- fusion_bench-0.2.1/fusion_bench_config/fabric/auto.yaml +0 -2
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/LICENSE +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/taskpool/clip_image_classification.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/constants/paths.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/dataset/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/dataset/gpt2_glue.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/dataset/image_dataset.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/dataset/nyuv2.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/ada_svd/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/adamerging/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/adamerging/task_wise_adamerging.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/analysis/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/analysis/task_vector_cos_similarity.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/classification/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/concrete_subspace/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/dawe/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/dawe/dawe_for_clip.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/dawe/warppers/dawe_model.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/depth_upscaling/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/ensemble.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/fisher_merging/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/linear/simple_average_for_llama.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/linear/task_arithmetic_for_llama.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/model_recombination.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pruning/wanda_utils/ablate.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pwe_moe/clip_pwe_moe.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/pwe_moe/utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/regmean/clip_regmean.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/regmean/gpt2_regmean.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/regmean/regmean.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/slerp/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/slerp/slerp_utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/smile_upscaling/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/sparselo/sparselo.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/task_arithmetic/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/ties_merging/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/we_moe/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/weighted_average/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/method/weighted_average/weighted_average.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/nyuv2/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/nyuv2/depth.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/nyuv2/loss.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/nyuv2/noise.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/nyuv2/normal.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/nyuv2/segmentation.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/text_to_image_generation/compressibility.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/mixins/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/mixins/lightning_fabric.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/mixins/rich_live.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/mixins/serialization.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/mixins/simple_profiler.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/base_pool.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/causal_lm/causal_lm.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/clip_vision/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/clip_vision/modelpool.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/huggingface_automodel.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/modelpool/huggingface_gpt2_classification.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/linearized/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/linearized/linearized_model_utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/linearized/vision_model.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/masks/mask_model.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_losparse_llama/losparse_linear.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_losparse_llama/register.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_losparse_llama/utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/modeling_smile_mistral/register.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/nyuv2/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/nyuv2/aspp.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/nyuv2/resnet.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/nyuv2/resnet_dilated.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/parameter_dict.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/separate_io.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/smile_moe/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/smile_moe/linear.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/wrappers/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/wrappers/ensemble.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/models/wrappers/task_wise_fusion.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/optim/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/optim/mezo.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/programs/base_program.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/scripts/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/scripts/cli.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/scripts/clip/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/scripts/imgui.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/scripts/nyuv2_mtl_train.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/scripts/webui.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/taskpool/base_pool.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/taskpool/nyuv2_taskpool.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/base_task.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/classification.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/cifar10.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/cifar100.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/clip_dataset.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/dtd.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/eurosat.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/flower102.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/gtsrb.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/imagenet.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/mnist.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/rendered_sst2.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/resisc45.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/stanford_cars.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/stl10.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/sun397.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/svhn.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/clip_classification/tiny_imagenet.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/cache_utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/devices.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/dtype.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/functools.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/hydra_utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/lazy_imports.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/pylogger.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/rich_utils.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/state_dict_arithmetic.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/timer.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/utils/type.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench.egg-info/dependency_links.txt +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench.egg-info/entry_points.txt +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench.egg-info/requires.txt +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench.egg-info/top_level.txt +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/README.md +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/cifar10.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/cifar100.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/dtd.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/eurosat.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/mnist.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/resisc45.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/sun397.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/svhn.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/cifar10.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/cifar100.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/dtd.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/eurosat.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/mnist.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/resisc45.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/sun397.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/svhn.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/fabric_logger/tensorboard_logger.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/hydra/default.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/hydra/help/fusion_bench_help.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/hydra/job_logging/rich_logging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/llama_magnitude_pruning.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/llama_weighted_average.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/adamerging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/clip_finetune.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/depth_upscaling.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/dummy.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/ensemble/max_model_predictor.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/fisher_merging/fisher_merging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/linear/simple_average_for_llama.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/linear/weighted_average.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/magnitude_diff_pruning.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/mixtral_moe_merging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/mixtral_moe_upscaling.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/model_recombination.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/pruning/llama_random_pruning.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/regmean/clip_regmean.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/regmean/regmean.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/simple_average.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/slerp/slerp.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/task_arithmetic.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/task_vector_cos_similarity.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/method/ties_merging.yaml +0 -0
- {fusion_bench-0.2.1/fusion_bench_config/method → fusion_bench-0.2.2/fusion_bench_config/method/wemoe}/weight_ensembling_moe.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/model/flan-t5/generate_flan-t5.sh +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/automodelpool.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/nyuv2_modelpool.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/nyuv2_config.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/nyuv2_mtl_train.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/dummy.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench_config/taskpool/nyuv2_taskpool.yaml +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/setup.cfg +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/tests/test_depth_upscaling.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/tests/test_simple_average.py +0 -0
- {fusion_bench-0.2.1 → fusion_bench-0.2.2}/tests/test_weighed_ensemble.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fusion_bench
|
|
3
|
-
Version: 0.2.1
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -47,7 +47,7 @@ Requires-Dist: scipy
|
|
|
47
47
|
Requires-Dist: h5py
|
|
48
48
|
Requires-Dist: pytest
|
|
49
49
|
|
|
50
|
-
# FusionBench: A Comprehensive Benchmark of Deep Model Fusion
|
|
50
|
+
# FusionBench: A Comprehensive Benchmark/ToolKit of Deep Model Fusion
|
|
51
51
|
|
|
52
52
|
[](http://arxiv.org/abs/2406.03280)
|
|
53
53
|
[](https://github.com/tanganke/fusion_bench/blob/main/LICENSE)
|
|
@@ -57,8 +57,6 @@ Requires-Dist: pytest
|
|
|
57
57
|
[](https://github.com/psf/black)
|
|
58
58
|
[](https://github.com/google/yamlfmt)
|
|
59
59
|
|
|
60
|
-
> [!WARNING]
|
|
61
|
-
> This project is still in testing phase as the API may be subject to change. Please report any issues you encounter.
|
|
62
60
|
|
|
63
61
|
> [!TIP]
|
|
64
62
|
> Documentation is available at [tanganke.github.io/fusion_bench/](https://tanganke.github.io/fusion_bench/).
|
|
@@ -70,6 +68,12 @@ FusionBench is a benchmark suite designed to evaluate the performance of various
|
|
|
70
68
|
|
|
71
69
|
Projects based on FusionBench:
|
|
72
70
|
|
|
71
|
+
<details>
|
|
72
|
+
<summary>Jinluan Yang et al. Mitigating the Backdoor Effect for Multi-Task Model Merging via Safety-Aware Subspace. Oct, 2024. http://arxiv.org/abs/2410.13910</summary>
|
|
73
|
+
|
|
74
|
+
<img width="1018" alt="image" src="https://github.com/user-attachments/assets/679aaa7e-0506-4e09-a12a-345c12cf529f">
|
|
75
|
+
|
|
76
|
+
</details>
|
|
73
77
|
<details>
|
|
74
78
|
<summary>Anke Tang et al. SMILE: Zero-Shot Sparse Mixture of Low-Rank Experts Construction From Pre-Trained Foundation Models. Aug, 2024. http://arxiv.org/abs/2408.10174</summary>
|
|
75
79
|
|
|
@@ -123,8 +127,7 @@ Read the [CLI documentation](https://tanganke.github.io/fusion_bench/cli/fusion_
|
|
|
123
127
|
## Implement your own model fusion algorithm
|
|
124
128
|
|
|
125
129
|
```python
|
|
126
|
-
from fusion_bench
|
|
127
|
-
from fusion_bench.modelpool import BaseModelPool
|
|
130
|
+
from fusion_bench import BaseModelFusionAlgorithm, BaseModelPool
|
|
128
131
|
|
|
129
132
|
class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
|
|
130
133
|
"""
|
|
@@ -132,6 +135,8 @@ class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
|
|
|
132
135
|
"""
|
|
133
136
|
|
|
134
137
|
# _config_mapping maps the attribute to the corresponding key in the configuration file.
|
|
138
|
+
# This is optional and can be used to serialize the object to a configuration file.
|
|
139
|
+
# `self.config.hyperparam_1` will be mapped to the attribute `hyperparam_attr_1`.
|
|
135
140
|
_config_mapping = BaseModelFusionAlgorithm._config_mapping | {
|
|
136
141
|
"hyperparam_attr_1": "hyperparam_1",
|
|
137
142
|
"hyperparam_attr_2": "hyperparam_2",
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# FusionBench: A Comprehensive Benchmark of Deep Model Fusion
|
|
1
|
+
# FusionBench: A Comprehensive Benchmark/ToolKit of Deep Model Fusion
|
|
2
2
|
|
|
3
3
|
[](http://arxiv.org/abs/2406.03280)
|
|
4
4
|
[](https://github.com/tanganke/fusion_bench/blob/main/LICENSE)
|
|
@@ -8,8 +8,6 @@
|
|
|
8
8
|
[](https://github.com/psf/black)
|
|
9
9
|
[](https://github.com/google/yamlfmt)
|
|
10
10
|
|
|
11
|
-
> [!WARNING]
|
|
12
|
-
> This project is still in testing phase as the API may be subject to change. Please report any issues you encounter.
|
|
13
11
|
|
|
14
12
|
> [!TIP]
|
|
15
13
|
> Documentation is available at [tanganke.github.io/fusion_bench/](https://tanganke.github.io/fusion_bench/).
|
|
@@ -21,6 +19,12 @@ FusionBench is a benchmark suite designed to evaluate the performance of various
|
|
|
21
19
|
|
|
22
20
|
Projects based on FusionBench:
|
|
23
21
|
|
|
22
|
+
<details>
|
|
23
|
+
<summary>Jinluan Yang et al. Mitigating the Backdoor Effect for Multi-Task Model Merging via Safety-Aware Subspace. Oct, 2024. http://arxiv.org/abs/2410.13910</summary>
|
|
24
|
+
|
|
25
|
+
<img width="1018" alt="image" src="https://github.com/user-attachments/assets/679aaa7e-0506-4e09-a12a-345c12cf529f">
|
|
26
|
+
|
|
27
|
+
</details>
|
|
24
28
|
<details>
|
|
25
29
|
<summary>Anke Tang et al. SMILE: Zero-Shot Sparse Mixture of Low-Rank Experts Construction From Pre-Trained Foundation Models. Aug, 2024. http://arxiv.org/abs/2408.10174</summary>
|
|
26
30
|
|
|
@@ -74,8 +78,7 @@ Read the [CLI documentation](https://tanganke.github.io/fusion_bench/cli/fusion_
|
|
|
74
78
|
## Implement your own model fusion algorithm
|
|
75
79
|
|
|
76
80
|
```python
|
|
77
|
-
from fusion_bench
|
|
78
|
-
from fusion_bench.modelpool import BaseModelPool
|
|
81
|
+
from fusion_bench import BaseModelFusionAlgorithm, BaseModelPool
|
|
79
82
|
|
|
80
83
|
class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
|
|
81
84
|
"""
|
|
@@ -83,6 +86,8 @@ class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
|
|
|
83
86
|
"""
|
|
84
87
|
|
|
85
88
|
# _config_mapping maps the attribute to the corresponding key in the configuration file.
|
|
89
|
+
# This is optional and can be used to serialize the object to a configuration file.
|
|
90
|
+
# `self.config.hyperparam_1` will be mapped to the attribute `hyperparam_attr_1`.
|
|
86
91
|
_config_mapping = BaseModelFusionAlgorithm._config_mapping | {
|
|
87
92
|
"hyperparam_attr_1": "hyperparam_1",
|
|
88
93
|
"hyperparam_attr_2": "hyperparam_2",
|
|
@@ -4,6 +4,13 @@ from .base_algorithm import ModelFusionAlgorithm
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class AlgorithmFactory:
|
|
7
|
+
"""
|
|
8
|
+
Factory class to create and manage different model fusion algorithms.
|
|
9
|
+
|
|
10
|
+
This class provides methods to create algorithms based on a given configuration,
|
|
11
|
+
register new algorithms, and list available algorithms.
|
|
12
|
+
"""
|
|
13
|
+
|
|
7
14
|
_aglorithms = {
|
|
8
15
|
# single task learning (fine-tuning)
|
|
9
16
|
"clip_finetune": ".classification.clip_finetune.ImageClassificationFineTuningForCLIP",
|
|
@@ -32,6 +39,7 @@ class AlgorithmFactory:
|
|
|
32
39
|
"clip_weight_ensembling_moe": ".we_moe.clip_we_moe.CLIPWeightEnsemblingMoEAlgorithm",
|
|
33
40
|
"model_recombination": ".model_recombination.ModelRecombinationAlgorithm",
|
|
34
41
|
"smile_upscaling": ".smile_upscaling.smile_upscaling.SmileUpscalingAlgorithm",
|
|
42
|
+
"sparse_clip_weight_ensembling_moe": "fusion_bench.method.SparseCLIPWeightEnsemblingMoEAlgorithm",
|
|
35
43
|
"smile_mistral_upscaling": ".smile_upscaling.smile_mistral_upscaling.SmileMistralUpscalingAlgorithm",
|
|
36
44
|
# pruning methods
|
|
37
45
|
"magnitude_diff_pruning": ".pruning.MagnitudeDiffPruningAlgorithm",
|
|
@@ -41,6 +49,18 @@ class AlgorithmFactory:
|
|
|
41
49
|
|
|
42
50
|
@staticmethod
|
|
43
51
|
def create_algorithm(method_config: DictConfig) -> ModelFusionAlgorithm:
|
|
52
|
+
"""
|
|
53
|
+
Create an instance of a model fusion algorithm based on the provided configuration.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
method_config (DictConfig): The configuration for the algorithm. Must contain a 'name' attribute that specifies the type of the algorithm.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
ModelFusionAlgorithm: An instance of the specified algorithm.
|
|
60
|
+
|
|
61
|
+
Raises:
|
|
62
|
+
ValueError: If 'name' attribute is not found in the configuration or does not match any known algorithm names.
|
|
63
|
+
"""
|
|
44
64
|
from fusion_bench.utils import import_object
|
|
45
65
|
|
|
46
66
|
algorithm_name = method_config.name
|
|
@@ -58,10 +78,23 @@ class AlgorithmFactory:
|
|
|
58
78
|
|
|
59
79
|
@staticmethod
|
|
60
80
|
def register_algorithm(name: str, algorithm_cls):
|
|
81
|
+
"""
|
|
82
|
+
Register a new algorithm with the factory.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
name (str): The name of the algorithm.
|
|
86
|
+
algorithm_cls: The class of the algorithm to register.
|
|
87
|
+
"""
|
|
61
88
|
AlgorithmFactory._aglorithms[name] = algorithm_cls
|
|
62
89
|
|
|
63
90
|
@classmethod
|
|
64
91
|
def available_algorithms(cls):
|
|
92
|
+
"""
|
|
93
|
+
Get a list of available algorithms.
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
list: A list of available algorithm names.
|
|
97
|
+
"""
|
|
65
98
|
return list(cls._aglorithms.keys())
|
|
66
99
|
|
|
67
100
|
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from omegaconf import DictConfig
|
|
5
|
+
|
|
6
|
+
__all__ = ["ModelFusionAlgorithm"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ModelFusionAlgorithm(ABC):
|
|
10
|
+
"""
|
|
11
|
+
Abstract base class for model fusion algorithms (for v0.1.x versions, deprecated).
|
|
12
|
+
To implement a new method, use `fusion_bench.method.BaseModelFusionAlgorithm` instead.
|
|
13
|
+
|
|
14
|
+
This class provides a template for implementing model fusion algorithms.
|
|
15
|
+
Subclasses must implement the `run` method to define the fusion logic.
|
|
16
|
+
|
|
17
|
+
Attributes:
|
|
18
|
+
config (DictConfig): Configuration for the algorithm.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
def __init__(self, algorithm_config: Optional[DictConfig] = None):
|
|
22
|
+
"""
|
|
23
|
+
Initialize the model fusion algorithm with the given configuration.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
algorithm_config (Optional[DictConfig]): Configuration for the algorithm. Defaults to an empty configuration if not provided.
|
|
27
|
+
Get access to the configuration using `self.config`.
|
|
28
|
+
"""
|
|
29
|
+
super().__init__()
|
|
30
|
+
if algorithm_config is None:
|
|
31
|
+
algorithm_config = DictConfig({})
|
|
32
|
+
self.config = algorithm_config
|
|
33
|
+
|
|
34
|
+
@abstractmethod
|
|
35
|
+
def run(self, modelpool):
|
|
36
|
+
"""
|
|
37
|
+
Fuse the models in the given model pool.
|
|
38
|
+
|
|
39
|
+
This method must be implemented by subclasses to define the fusion logic.
|
|
40
|
+
`modelpool` is an object responsible for managing the models to be fused and optional datasets to be used for fusion.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
modelpool: The pool of models to fuse.
|
|
44
|
+
|
|
45
|
+
Examples:
|
|
46
|
+
>>> algorithm = SimpleAverageAlgorithm()
|
|
47
|
+
>>> modelpool = ModelPool()
|
|
48
|
+
>>> merged_model = algorithm.run(modelpool)
|
|
49
|
+
"""
|
|
50
|
+
pass
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# flake8: noqa F401
|
|
1
2
|
from omegaconf import DictConfig
|
|
2
3
|
|
|
3
4
|
from fusion_bench.modelpool.huggingface_gpt2_classification import (
|
|
@@ -11,6 +12,13 @@ from .huggingface_clip_vision import HuggingFaceClipVisionPool
|
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class ModelPoolFactory:
|
|
15
|
+
"""
|
|
16
|
+
Factory class to create and manage different model pools.
|
|
17
|
+
|
|
18
|
+
This class provides methods to create model pools based on a given configuration,
|
|
19
|
+
register new model pools, and list available model pools.
|
|
20
|
+
"""
|
|
21
|
+
|
|
14
22
|
_modelpool = {
|
|
15
23
|
"NYUv2ModelPool": ".nyuv2_modelpool.NYUv2ModelPool",
|
|
16
24
|
"huggingface_clip_vision": HuggingFaceClipVisionPool,
|
|
@@ -27,6 +35,21 @@ class ModelPoolFactory:
|
|
|
27
35
|
|
|
28
36
|
@staticmethod
|
|
29
37
|
def create_modelpool(modelpool_config: DictConfig) -> ModelPool:
|
|
38
|
+
"""
|
|
39
|
+
Create an instance of a model pool based on the provided configuration.
|
|
40
|
+
This is for v0.1.x versions, deprecated.
|
|
41
|
+
To implement a new model pool, use `fusion_bench.modelpool.BaseModelPool` instead.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
modelpool_config (DictConfig): The configuration for the model pool.
|
|
45
|
+
Must contain a 'type' attribute that specifies the type of the model pool.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
ModelPool: An instance of the specified model pool.
|
|
49
|
+
|
|
50
|
+
Raises:
|
|
51
|
+
ValueError: If 'type' attribute is not found in the configuration or does not match any known model pool types.
|
|
52
|
+
"""
|
|
30
53
|
from fusion_bench.utils import import_object
|
|
31
54
|
|
|
32
55
|
modelpool_type = modelpool_config.get("type")
|
|
@@ -46,10 +69,23 @@ class ModelPoolFactory:
|
|
|
46
69
|
|
|
47
70
|
@staticmethod
|
|
48
71
|
def register_modelpool(name: str, modelpool_cls):
|
|
72
|
+
"""
|
|
73
|
+
Register a new model pool with the factory.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
name (str): The name of the model pool.
|
|
77
|
+
modelpool_cls: The class of the model pool to register.
|
|
78
|
+
"""
|
|
49
79
|
ModelPoolFactory._modelpool[name] = modelpool_cls
|
|
50
80
|
|
|
51
81
|
@classmethod
|
|
52
82
|
def available_modelpools(cls):
|
|
83
|
+
"""
|
|
84
|
+
Get a list of available model pools.
|
|
85
|
+
|
|
86
|
+
Returns:
|
|
87
|
+
list: A list of available model pool names.
|
|
88
|
+
"""
|
|
53
89
|
return list(cls._modelpool.keys())
|
|
54
90
|
|
|
55
91
|
|
|
@@ -18,11 +18,19 @@ log = logging.getLogger(__name__)
|
|
|
18
18
|
class ModelPool(ABC):
|
|
19
19
|
"""
|
|
20
20
|
This is the base class for all modelpools.
|
|
21
|
+
For version v0.1.x (deprecated).
|
|
22
|
+
To implement new model pools, please use `fusion_bench.modelpool.BaseModelPool`.
|
|
21
23
|
"""
|
|
22
24
|
|
|
23
25
|
_model_names = None
|
|
24
26
|
|
|
25
27
|
def __init__(self, modelpool_config: Optional[DictConfig] = None):
|
|
28
|
+
"""
|
|
29
|
+
Initialize the ModelPool with the given configuration.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
modelpool_config (Optional[DictConfig]): The configuration for the model pool.
|
|
33
|
+
"""
|
|
26
34
|
super().__init__()
|
|
27
35
|
self.config = modelpool_config
|
|
28
36
|
|
|
@@ -35,6 +43,12 @@ class ModelPool(ABC):
|
|
|
35
43
|
self._model_names = model_names
|
|
36
44
|
|
|
37
45
|
def __len__(self):
|
|
46
|
+
"""
|
|
47
|
+
Return the number of models in the model pool, excluding special models such as `_pretrained_`.
|
|
48
|
+
|
|
49
|
+
Returns:
|
|
50
|
+
int: The number of models in the model pool.
|
|
51
|
+
"""
|
|
38
52
|
return len(self.model_names)
|
|
39
53
|
|
|
40
54
|
@property
|
|
@@ -55,6 +69,9 @@ class ModelPool(ABC):
|
|
|
55
69
|
def has_pretrained(self):
|
|
56
70
|
"""
|
|
57
71
|
Check if the pretrained model is available in the model pool.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
bool: True if the pretrained model is available, False otherwise.
|
|
58
75
|
"""
|
|
59
76
|
for model_config in self.config["models"]:
|
|
60
77
|
if model_config.get("name", None) == "_pretrained_":
|
|
@@ -121,22 +138,46 @@ class ModelPool(ABC):
|
|
|
121
138
|
torch.save(model.state_dict(), path)
|
|
122
139
|
|
|
123
140
|
def models(self):
|
|
141
|
+
"""
|
|
142
|
+
Generator that yields models from the model pool.
|
|
143
|
+
|
|
144
|
+
Yields:
|
|
145
|
+
nn.Module: The next model in the model pool.
|
|
146
|
+
"""
|
|
124
147
|
for model_name in self.model_names:
|
|
125
148
|
yield self.load_model(model_name)
|
|
126
149
|
|
|
127
150
|
def named_models(self):
|
|
151
|
+
"""
|
|
152
|
+
Generator that yields model names and models from the model pool.
|
|
153
|
+
|
|
154
|
+
Yields:
|
|
155
|
+
tuple: A tuple containing the model name and the model.
|
|
156
|
+
"""
|
|
128
157
|
for model_name in self.model_names:
|
|
129
158
|
yield model_name, self.load_model(model_name)
|
|
130
159
|
|
|
131
160
|
def get_train_dataset(self, model_name: str):
|
|
132
161
|
"""
|
|
133
162
|
Get the training dataset for the model.
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
model_name (str): The name of the model for which to get the training dataset.
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
Any: The training dataset for the model.
|
|
134
169
|
"""
|
|
135
170
|
raise NotImplementedError
|
|
136
171
|
|
|
137
172
|
def get_test_dataset(self, model_name: str):
|
|
138
173
|
"""
|
|
139
174
|
Get the testing dataset for the model.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
model_name (str): The name of the model for which to get the testing dataset.
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
Any: The testing dataset for the model.
|
|
140
181
|
"""
|
|
141
182
|
raise NotImplementedError
|
|
142
183
|
|
|
@@ -144,18 +185,27 @@ class ModelPool(ABC):
|
|
|
144
185
|
"""
|
|
145
186
|
Setup the taskpool before evaluation.
|
|
146
187
|
Such as setting the fabric, processor, tokenizer, etc.
|
|
188
|
+
|
|
189
|
+
Args:
|
|
190
|
+
taskpool (Any): The taskpool to setup.
|
|
147
191
|
"""
|
|
148
192
|
pass
|
|
149
193
|
|
|
150
194
|
def to_modellist(self) -> List[nn.Module]:
|
|
151
195
|
"""
|
|
152
196
|
Convert the model pool to a list of models.
|
|
197
|
+
|
|
198
|
+
Returns:
|
|
199
|
+
list: A list of models.
|
|
153
200
|
"""
|
|
154
201
|
return [self.load_model(m) for m in self.model_names]
|
|
155
202
|
|
|
156
203
|
def to_modeldict(self) -> Dict[str, nn.Module]:
|
|
157
204
|
"""
|
|
158
205
|
Convert the model pool to a dictionary of models.
|
|
206
|
+
|
|
207
|
+
Returns:
|
|
208
|
+
dict: A dictionary of models.
|
|
159
209
|
"""
|
|
160
210
|
return {m: self.load_model(m) for m in self.model_names}
|
|
161
211
|
|
|
@@ -170,6 +220,13 @@ class ListModelPool(ModelPool):
|
|
|
170
220
|
models: List[nn.Module],
|
|
171
221
|
has_pretraned: bool = False,
|
|
172
222
|
):
|
|
223
|
+
"""
|
|
224
|
+
Initialize the ListModelPool with the given list of models.
|
|
225
|
+
|
|
226
|
+
Args:
|
|
227
|
+
models (List[nn.Module]): The list of models.
|
|
228
|
+
has_pretraned (bool): Whether the first model in the list is pretrained.
|
|
229
|
+
"""
|
|
173
230
|
modelpool_config = {}
|
|
174
231
|
modelpool_config["models"] = []
|
|
175
232
|
model_dict = {}
|
|
@@ -188,6 +245,16 @@ class ListModelPool(ModelPool):
|
|
|
188
245
|
super().__init__(DictConfig(modelpool_config))
|
|
189
246
|
|
|
190
247
|
def load_model(self, model_config: str | DictConfig, copy=True) -> nn.Module:
|
|
248
|
+
"""
|
|
249
|
+
Load the model from the model pool.
|
|
250
|
+
|
|
251
|
+
Args:
|
|
252
|
+
model_config (str | DictConfig): The model name or the configuration dictionary for the model to load.
|
|
253
|
+
copy (bool): Whether to return a copy of the model, defaults to `True`.
|
|
254
|
+
|
|
255
|
+
Returns:
|
|
256
|
+
nn.Module: The loaded model.
|
|
257
|
+
"""
|
|
191
258
|
if isinstance(model_config, str):
|
|
192
259
|
model_config = self.get_model_config(model_config)
|
|
193
260
|
model_name = model_config["name"]
|
|
@@ -203,6 +270,12 @@ class DictModelPool(ModelPool):
|
|
|
203
270
|
"""
|
|
204
271
|
|
|
205
272
|
def __init__(self, model_dict: Dict[str, nn.Module]):
|
|
273
|
+
"""
|
|
274
|
+
Initialize the DictModelPool with the given dictionary of models.
|
|
275
|
+
|
|
276
|
+
Args:
|
|
277
|
+
model_dict (Dict[str, nn.Module]): The dictionary of models.
|
|
278
|
+
"""
|
|
206
279
|
modelpool_config = {}
|
|
207
280
|
modelpool_config["models"] = []
|
|
208
281
|
for model_name, model in model_dict.items():
|
|
@@ -211,6 +284,16 @@ class DictModelPool(ModelPool):
|
|
|
211
284
|
super().__init__(DictConfig(modelpool_config))
|
|
212
285
|
|
|
213
286
|
def load_model(self, model_config: str | DictConfig, copy=True) -> nn.Module:
|
|
287
|
+
"""
|
|
288
|
+
Load the model from the model pool.
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
model_config (str | DictConfig): The configuration dictionary for the model to load.
|
|
292
|
+
copy (bool): Whether to return a copy of the model.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
nn.Module: The loaded model.
|
|
296
|
+
"""
|
|
214
297
|
if isinstance(model_config, str):
|
|
215
298
|
model_config = self.get_model_config(model_config)
|
|
216
299
|
model_name = model_config["name"]
|
|
@@ -221,6 +304,18 @@ class DictModelPool(ModelPool):
|
|
|
221
304
|
|
|
222
305
|
|
|
223
306
|
def to_modelpool(obj: List[nn.Module], **kwargs):
|
|
307
|
+
"""
|
|
308
|
+
Convert the given object to a model pool.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
obj (List[nn.Module]): The object to convert to a model pool.
|
|
312
|
+
|
|
313
|
+
Returns:
|
|
314
|
+
ModelPool: The converted model pool.
|
|
315
|
+
|
|
316
|
+
Raises:
|
|
317
|
+
ValueError: If the object cannot be converted to a model pool.
|
|
318
|
+
"""
|
|
224
319
|
if isinstance(obj, (ModelPool, BaseModelPool)):
|
|
225
320
|
return obj
|
|
226
321
|
elif isinstance(obj, (list, tuple)) and all(isinstance(m, nn.Module) for m in obj):
|
{fusion_bench-0.2.1 → fusion_bench-0.2.2}/fusion_bench/compat/modelpool/huggingface_clip_vision.py
RENAMED
|
@@ -29,6 +29,9 @@ class HuggingFaceClipVisionPool(ModelPool):
|
|
|
29
29
|
|
|
30
30
|
@property
|
|
31
31
|
def clip_processor(self):
|
|
32
|
+
"""
|
|
33
|
+
Returns the CLIP processor. If it's not already initialized, it initializes it using the path of the pretrained model.
|
|
34
|
+
"""
|
|
32
35
|
if self._clip_processor is None:
|
|
33
36
|
if "_pretrained_" in self._model_names:
|
|
34
37
|
self._clip_processor = CLIPProcessor.from_pretrained(
|
|
@@ -76,12 +79,33 @@ class HuggingFaceClipVisionPool(ModelPool):
|
|
|
76
79
|
model.save_pretrained(path)
|
|
77
80
|
|
|
78
81
|
def get_tta_dataset_config(self, dataset: str):
|
|
82
|
+
"""
|
|
83
|
+
Retrieve the configuration for a TTA (Test-Time Adaptation) dataset.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
dataset (str): The name of the dataset for which to retrieve the configuration.
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
DictConfig: The configuration dictionary for the specified dataset.
|
|
90
|
+
|
|
91
|
+
Raises:
|
|
92
|
+
ValueError: If the specified dataset is not found in the configuration.
|
|
93
|
+
"""
|
|
79
94
|
for dataset_config in self.config.tta_datasets:
|
|
80
95
|
if dataset_config.name == dataset:
|
|
81
96
|
return dataset_config
|
|
82
97
|
raise ValueError(f"Dataset {dataset} not found in config")
|
|
83
98
|
|
|
84
99
|
def prepare_dataset_config(self, dataset_config: DictConfig):
|
|
100
|
+
"""
|
|
101
|
+
Prepare the dataset configuration by setting the dataset type if it's not already set.
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
dataset_config (DictConfig): The configuration dictionary for the dataset.
|
|
105
|
+
|
|
106
|
+
Returns:
|
|
107
|
+
DictConfig: The updated configuration dictionary for the dataset.
|
|
108
|
+
"""
|
|
85
109
|
if not hasattr(dataset_config, "type"):
|
|
86
110
|
with open_dict(dataset_config):
|
|
87
111
|
dataset_config["type"] = self.config.dataset_type
|
|
@@ -94,6 +118,13 @@ class HuggingFaceClipVisionPool(ModelPool):
|
|
|
94
118
|
"""
|
|
95
119
|
Load the test dataset for the task.
|
|
96
120
|
This method is cached, so the dataset is loaded only once.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
tta_dataset (str): The name of the TTA dataset to load.
|
|
124
|
+
clip_processor (Optional[CLIPProcessor]): The CLIP processor to use for preprocessing the dataset. If None, the default processor is used.
|
|
125
|
+
|
|
126
|
+
Returns:
|
|
127
|
+
CLIPDataset: The loaded and preprocessed TTA test dataset.
|
|
97
128
|
"""
|
|
98
129
|
if clip_processor is None:
|
|
99
130
|
# if clip_processor is not provided, try to load the clip_processor from pre-trained model
|
|
@@ -106,6 +137,18 @@ class HuggingFaceClipVisionPool(ModelPool):
|
|
|
106
137
|
return dataset
|
|
107
138
|
|
|
108
139
|
def get_train_dataset_config(self, model_name: str):
|
|
140
|
+
"""
|
|
141
|
+
Retrieve the configuration for a specific training dataset.
|
|
142
|
+
|
|
143
|
+
Args:
|
|
144
|
+
model_name (str): The name of the model for which to retrieve the training dataset configuration.
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
DictConfig: The configuration dictionary for the specified training dataset.
|
|
148
|
+
|
|
149
|
+
Raises:
|
|
150
|
+
ValueError: If the specified training dataset is not found in the configuration.
|
|
151
|
+
"""
|
|
109
152
|
for dataset_config in self.config.train_datasets:
|
|
110
153
|
if dataset_config.name == model_name:
|
|
111
154
|
return dataset_config
|
|
@@ -114,6 +157,16 @@ class HuggingFaceClipVisionPool(ModelPool):
|
|
|
114
157
|
def get_train_dataset(
|
|
115
158
|
self, model_name: str, clip_processor: Optional[CLIPProcessor] = None
|
|
116
159
|
):
|
|
160
|
+
"""
|
|
161
|
+
Load the training dataset for the specified model.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
model_name (str): The name of the model for which to load the training dataset.
|
|
165
|
+
clip_processor (Optional[CLIPProcessor]): The CLIP processor to use for preprocessing the dataset. If None, the default processor is used.
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
CLIPDataset: The loaded and preprocessed training dataset.
|
|
169
|
+
"""
|
|
117
170
|
if clip_processor is None:
|
|
118
171
|
# if clip_processor is not provided, try to load the clip_processor from pre-trained model
|
|
119
172
|
clip_processor = self.clip_processor
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# flake8: noqa F401
|
|
1
2
|
from omegaconf import DictConfig
|
|
2
3
|
|
|
3
4
|
from fusion_bench.taskpool.dummy import DummyTaskPool
|
|
@@ -6,6 +7,15 @@ from .base_pool import TaskPool
|
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class TaskPoolFactory:
|
|
10
|
+
"""
|
|
11
|
+
Factory class to create and manage different task pools.
|
|
12
|
+
This is for v0.1.x versions, deprecated.
|
|
13
|
+
For implementing new task pool, use `fusion_bench.taskpool.BaseTaskPool` instead.
|
|
14
|
+
|
|
15
|
+
This class provides methods to create task pools based on a given configuration,
|
|
16
|
+
register new task pools, and list available task pools.
|
|
17
|
+
"""
|
|
18
|
+
|
|
9
19
|
_taskpool_types = {
|
|
10
20
|
"dummy": DummyTaskPool,
|
|
11
21
|
"clip_vit_classification": ".clip_image_classification.CLIPImageClassificationTaskPool",
|
|
@@ -15,6 +25,18 @@ class TaskPoolFactory:
|
|
|
15
25
|
|
|
16
26
|
@staticmethod
|
|
17
27
|
def create_taskpool(taskpool_config: DictConfig):
|
|
28
|
+
"""
|
|
29
|
+
Create an instance of a task pool based on the provided configuration.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
taskpool_config (DictConfig): The configuration for the task pool. Must contain a 'type' attribute that specifies the type of the task pool.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
TaskPool: An instance of the specified task pool.
|
|
36
|
+
|
|
37
|
+
Raises:
|
|
38
|
+
ValueError: If 'type' attribute is not found in the configuration or does not match any known task pool types.
|
|
39
|
+
"""
|
|
18
40
|
from fusion_bench.utils import import_object
|
|
19
41
|
|
|
20
42
|
taskpool_type = taskpool_config.get("type")
|
|
@@ -34,10 +56,23 @@ class TaskPoolFactory:
|
|
|
34
56
|
|
|
35
57
|
@staticmethod
|
|
36
58
|
def register_taskpool(name: str, taskpool_cls):
|
|
59
|
+
"""
|
|
60
|
+
Register a new task pool with the factory.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
name (str): The name of the task pool.
|
|
64
|
+
taskpool_cls: The class of the task pool to register.
|
|
65
|
+
"""
|
|
37
66
|
TaskPoolFactory._taskpool_types[name] = taskpool_cls
|
|
38
67
|
|
|
39
68
|
@classmethod
|
|
40
69
|
def available_taskpools(cls):
|
|
70
|
+
"""
|
|
71
|
+
Get a list of available task pools.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
list: A list of available task pool names.
|
|
75
|
+
"""
|
|
41
76
|
return list(cls._taskpool_types.keys())
|
|
42
77
|
|
|
43
78
|
|