fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from transformers import CLIPModel, CLIPProcessor
|
|
10
|
+
|
|
11
|
+
from fusion_bench.dataset import CLIPDataset
|
|
12
|
+
from fusion_bench.modelpool import CLIPVisionModelPool
|
|
13
|
+
from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
14
|
+
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
15
|
+
from fusion_bench.utils import timeit_context
|
|
16
|
+
|
|
17
|
+
from .task_wise_adamerging import TaskWiseAdaMergingAlgorithm
|
|
18
|
+
|
|
19
|
+
log = logging.getLogger(__name__)


class InfiniteDataLoader:
    """
    Endlessly cycle over a DataLoader (or any re-iterable object).

    Handy when training is budgeted in optimisation steps rather than epochs:
    whenever the current pass over the wrapped loader is exhausted, a fresh
    iterator is created from the same loader and iteration carries on.
    If the wrapped loader yields nothing at all, ``StopIteration`` still
    propagates to the caller.

    Attributes:
        data_loader (DataLoader): The wrapped loader.
        data_iter (iterator): The iterator of the pass currently being consumed.
    """

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(self.data_loader)

    def __iter__(self):
        # The wrapper is its own iterator.
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            pass
        # Current pass finished: start a new one over the same loader.
        self.data_iter = iter(self.data_loader)
        return next(self.data_iter)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
    """
    A class for task-wise adaptive merging of CLIP models.

    This class extends the TaskWiseAdaMergingAlgorithm to provide specific
    functionality for CLIP models, including loading datasets, constructing
    zero-shot classification heads, and computing logits.

    Attributes:
        modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
        _clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (dict): Maps each task name to its zero-shot
            classification head (the class text embeddings). Created per
            instance in ``__init__``.
    """

    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    # Assigned per instance in __init__. A mutable class-level ``{}`` default
    # would be shared (and mutated) by every instance of this class.
    zeroshot_weights: dict = None

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)
        self.zeroshot_weights = {}

    def _instance_cache(self, name: str) -> dict:
        """Return the per-instance memo dict stored as ``self.<name>``, creating it on first use."""
        # Used instead of ``functools.cache`` on methods, which keys on ``self``
        # and keeps every instance (and its datasets/loaders) alive for the
        # lifetime of the process.
        return self.__dict__.setdefault(name, {})

    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.

        The result is memoized per instance and per task, so the dataset is
        loaded only once.

        Args:
            task (str): The name of the task.

        Returns:
            CLIPDataset: The test dataset for the task.
        """
        cache = self._instance_cache("_test_dataset_cache")
        if task not in cache:
            log.info(f"Loading test dataset: {task}")
            dataset = self.modelpool.load_test_dataset(task)
            cache[task] = CLIPDataset(dataset, self._clip_processor)
        return cache[task]

    def get_shuffled_test_loader_iter(self, task: str):
        """
        Get an iterator over the shuffled test DataLoader for the task.

        The iterator is memoized per instance and per task: repeated calls with
        the same ``task`` return the same iterator object rather than a new one.

        Args:
            task (str): The name of the task.

        Returns:
            iterator: An iterator over the shuffled test DataLoader that
            restarts the loader whenever it is exhausted.
        """
        cache = self._instance_cache("_shuffled_test_loader_iter_cache")
        if task in cache:
            return cache[task]

        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        cache[task] = iter(InfiniteDataLoader(loader))
        return cache[task]

    def on_test_time_adaptation_start(self):
        """
        Prepare for test-time adaptation.

        This method loads the CLIP processor and the pretrained CLIP model,
        keeps its (frozen) visual projection and logit scale, and builds the
        zero-shot classification head for every task in the model pool. Heads
        are read from ``config.cache_dir`` when a cache file exists; otherwise
        they are computed and written there (the directory is created if
        needed).
        """
        clip_model_config = self.modelpool.get_model_config("_pretrained_")
        pretrained_path = (
            clip_model_config.pretrained_model_name_or_path
            if hasattr(clip_model_config, "pretrained_model_name_or_path")
            else clip_model_config.path
        )

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
            clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            # Detached, like the frozen visual projection above: the logit scale
            # is a constant here, and keeping it attached would route every
            # backward pass during adaptation into the otherwise unused CLIP model.
            self.logit_scale_exp = clip_model.logit_scale.detach().exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

        # normpath strips a trailing separator; otherwise basename("dir/model/")
        # is "" and cache files of different models would collide.
        model_tag = os.path.basename(os.path.normpath(str(pretrained_path)))
        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{model_tag}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(task)
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                # torch.save does not create missing parent directories.
                os.makedirs(self.config.cache_dir, exist_ok=True)
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        This method computes the image embeddings, normalizes them, and calculates
        the scaled similarity with the task's text embeddings to produce
        classification logits.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data; its second element (presumably
                the labels) is ignored.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch, presumably of shape
            ``(batch_size, num_classes)`` — assumes ``zeroshot_weights[task]`` is
            ``(num_classes, embed_dim)``; TODO confirm.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        # NOTE(review): index 1 of the module output is assumed to be the pooled
        # image representation — confirm against the vision model's outputs.
        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
    """
    Mean entropy (natural log) of the softmax distribution of each row.

    Args:
        logits (Tensor): Unnormalized scores; must be 2-D, with the softmax
            taken over the last dimension.
        eps (float): Added to the probabilities before the log so that zero
            probabilities do not produce ``-inf``. Default is 1e-8.

    Returns:
        Tensor: Scalar tensor — the per-row entropy averaged over the first
        dimension.
    """
    # Only 2-D inputs are supported.
    assert (
        logits.dim() == 2
    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"

    p = logits.softmax(dim=-1)
    log_p = (p + eps).log()
    per_row = (p * log_p).sum(dim=-1)
    return -per_row.mean()
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This is an experimental implementation of the Layer-Wise AdaMerging Algorithm for Flan-T5 models.
|
|
3
|
+
The efficiency of the algorithm is not guaranteed, and it may not work as expected.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import functools
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
from abc import abstractmethod
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Dict, List, Mapping, Optional, Union, cast # noqa: F401
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
15
|
+
from omegaconf import DictConfig
|
|
16
|
+
from torch import Tensor, nn
|
|
17
|
+
from torch.utils.data import DataLoader
|
|
18
|
+
from tqdm.autonotebook import tqdm
|
|
19
|
+
from transformers import T5ForConditionalGeneration
|
|
20
|
+
from transformers.data import default_data_collator
|
|
21
|
+
|
|
22
|
+
from fusion_bench.method import BaseAlgorithm
|
|
23
|
+
from fusion_bench.method.simple_average import simple_average
|
|
24
|
+
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
|
|
25
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
26
|
+
from fusion_bench.modelpool import Seq2SeqLMPool
|
|
27
|
+
from fusion_bench.models.wrappers.layer_wise_fusion import (
|
|
28
|
+
LayerWiseMergedModel,
|
|
29
|
+
get_layer_wise_weights,
|
|
30
|
+
)
|
|
31
|
+
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
|
|
32
|
+
from fusion_bench.utils.instantiate import instantiate
|
|
33
|
+
|
|
34
|
+
from .entropy_loss import entropy_loss
|
|
35
|
+
from .min_norm_solvers import MinNormSolver
|
|
36
|
+
from .utils import get_memory_usage
|
|
37
|
+
|
|
38
|
+
# Module-level logger used by the algorithm below.
log = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class FlanT5LayerWiseAdaMergingAlgorithm(
|
|
42
|
+
BaseAlgorithm,
|
|
43
|
+
LightningFabricMixin,
|
|
44
|
+
SimpleProfilerMixin,
|
|
45
|
+
):
|
|
46
|
+
|
|
47
|
+
def __init__(
    self,
    optimizer: DictConfig,
    dataloader_kwargs: DictConfig,
    init_values: float,
    max_steps: int,
    merging_weights_load_path: Optional[Union[str, Path]] = None,
    merging_weights_save_path: Optional[Union[str, Path]] = None,
    clamp_weights: bool = False,
    tie_weights: bool = True,
    strict: bool = False,
    cache_dir: str = "outputs/cache",
    variant: Optional[str] = None,
    **kwargs,
):
    """
    Store the hyper-parameters of the layer-wise AdaMerging algorithm.

    Args:
        optimizer (DictConfig): Optimizer configuration, kept as
            ``self._optimizer`` (how it is instantiated is not shown here).
        dataloader_kwargs (DictConfig): Keyword arguments for the test
            DataLoaders built in ``get_shuffled_test_loader_iter``.
        init_values (float): Initial merging-weight value passed to
            ``get_layer_wise_weights``.
        max_steps (int): Stored as ``self.max_steps``; presumably the number
            of test-time adaptation steps.
        merging_weights_load_path: If set, merging weights are loaded from
            this file and test-time adaptation is skipped in ``run``.
        merging_weights_save_path: If set, the adapted merging weights are
            saved here after test-time adaptation.
        clamp_weights (bool): Forwarded to ``LayerWiseMergedModel``.
        tie_weights (bool): Forwarded to ``LayerWiseMergedModel``.
        strict (bool): Forwarded to ``LayerWiseMergedModel``.
        cache_dir (str): Stored as ``self.cache_dir``.
        variant (Optional[str]): Stored as ``self.variant``.
        **kwargs: Passed on to the base classes' ``__init__``.
    """
    # Optimisation and data settings.
    self._optimizer = optimizer
    self.dataloader_kwargs = dataloader_kwargs
    self.max_steps = max_steps
    # Merging-weight initialisation and persistence.
    self.init_values = init_values
    self.merging_weights_load_path = merging_weights_load_path
    self.merging_weights_save_path = merging_weights_save_path
    # Options forwarded to LayerWiseMergedModel.
    self.clamp_weights = clamp_weights
    self.tie_weights = tie_weights
    self.strict = strict
    # Miscellaneous settings.
    self.cache_dir = cache_dir
    self.variant = variant
    # Base classes are initialised last, with the remaining keyword arguments.
    super().__init__(**kwargs)
|
|
74
|
+
|
|
75
|
+
@torch.no_grad()
def construct_layer_wise_merged_model(self, modelpool: Seq2SeqLMPool):
    """
    Constructs a wrapped layer-wise merged model from model pool.

    This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
    The merging weights are either initialized from `init_values` or loaded from `merging_weights_load_path`.

    Args:
        modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

    Returns:
        LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.

    Raises:
        ValueError: If `merging_weights_load_path` is set but is neither a `str` nor a `Path`.
    """
    pretrained_model = modelpool.load_model("_pretrained_")
    finetuned_models = [
        modelpool.load_model(name) for name in modelpool.model_names
    ]

    # Initialize the layer-wise weights from `init_values`, or load them from
    # `merging_weights_load_path` when one is provided.
    if self.merging_weights_load_path is None:
        # One merging weight per trainable parameter tensor of the pretrained model.
        num_layers = sum(
            1 for p in pretrained_model.parameters() if p.requires_grad
        )
        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=num_layers,
            init_values=self.init_values,
        )
    elif isinstance(self.merging_weights_load_path, (str, Path)):
        # The constructor accepts `Union[str, Path]`, so both are supported here.
        layer_wise_weight = load_tensor_from_file(
            str(self.merging_weights_load_path)
        )
    else:
        raise ValueError(
            f"Unsupported weights format: {self.merging_weights_load_path}"
        )

    module = LayerWiseMergedModel(
        layer_wise_weight=layer_wise_weight,
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=self.clamp_weights,
        tie_weights=self.tie_weights,
        strict=self.strict,
    )
    log.info(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
    return module
|
|
127
|
+
|
|
128
|
+
@rank_zero_only
def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
    """
    Save the merging weights to a file (on the global-zero process only).

    Args:
        file_path (str): The path to save the merging weights. A string that
            does not start with "/" or "." is placed inside `self.log_dir`;
            any other path is used as given. If None, nothing is saved.
        merging_weights (torch.Tensor): The merging weights to save; they are
            detached and moved to CPU before saving.
    """
    # Gate on the path actually passed in, so an explicit `file_path` is
    # honoured even when `merging_weights_save_path` is not configured.
    if file_path is None or not self.fabric.is_global_zero:
        return

    if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
        # if the file path is not absolute or relative to current working directory, save it in the log directory
        save_path = os.path.join(self.log_dir, file_path)
    else:
        save_path = file_path
    log.info(f"saving merging weights to {save_path}.")
    if os.path.dirname(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(merging_weights.detach().cpu(), save_path)
|
|
147
|
+
|
|
148
|
+
def run(self, modelpool: Seq2SeqLMPool, **kwargs):
    """
    Execute layer-wise AdaMerging on the given model pool.

    The wrapped layer-wise model is always built first. When merging weights
    were loaded from `merging_weights_load_path`, test-time adaptation is
    skipped. Otherwise the weights are adapted and, if
    `merging_weights_save_path` is set, saved. In both cases the wrapper is
    merged and unloaded before returning.

    Args:
        modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

    Returns:
        The result of `merge_and_unload()` on the wrapped model.
    """
    log.info("Fusing models using layer-wise adaptive merging.")
    self.modelpool = modelpool

    with self.profile("construct the wrapped model"):
        module = self.construct_layer_wise_merged_model(modelpool)

    needs_adaptation = self.merging_weights_load_path is None
    if needs_adaptation:
        with self.profile("test-time adaptation"):
            module = self.test_time_adaptation(module)
        save_path = self.merging_weights_save_path
        if save_path is not None:
            self.save_merging_weights(save_path, module.merge_weight)

    return module.merge_and_unload()
|
|
177
|
+
|
|
178
|
+
def get_shuffled_test_loader_iter(self, task: str):
    """
    Return a cached, infinite iterator over the shuffled test set of ``task``.

    Used for test-time adaptation, so labels are not needed. The iterator for
    a task is created on first use and reused afterwards. Successive
    ``next()`` calls therefore keep advancing through the same stream.

    Args:
        task (str): The name of the task.

    Returns:
        An iterator over batches, i.e. ``iter(InfiniteDataLoader(loader))``.
    """
    # Per-instance cache instead of ``functools.cache`` on the method. The
    # latter keys on ``self`` and keeps every algorithm instance (and thus its
    # model pool) alive for the lifetime of the process (flake8-bugbear B019).
    cache = getattr(self, "_shuffled_test_loader_iters", None)
    if cache is None:
        cache = self._shuffled_test_loader_iters = {}
    if task in cache:
        return cache[task]

    dataloader_kwargs = dict(self.dataloader_kwargs)
    dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

    dataset = self.modelpool.load_test_dataset(task)
    loader = DataLoader(dataset, **dataloader_kwargs)

    if self.fabric is not None:
        loader = self.fabric.setup_dataloaders(loader)
    loader_iter = iter(InfiniteDataLoader(loader))
    cache[task] = loader_iter
    return loader_iter
|
|
198
|
+
|
|
199
|
+
def compute_logits(
    self,
    module: "Union[T5ForConditionalGeneration, LayerWiseMergedModel]",
    batch,
    task: str,
) -> Tensor:
    """
    Compute first-decoder-step logits for a batch of tokenized inputs.

    Trailing columns that are padding (mask == 0) for every example are
    trimmed, always keeping at least one column. The model is then run with a
    single decoder token per example (``decoder_input_ids`` filled with 1).
    The logits of that single position are returned.

    Args:
        module: The model to evaluate (plain T5 or the layer-wise merged wrapper).
        batch: Mapping with ``"input_ids"`` and ``"attention_mask"`` tensors.
            They are assumed to be 2-D ``(batch, seq_len)``; the column
            slicing below relies on that.
        task (str): The name of the task (not used by this implementation).

    Returns:
        Tensor: ``outputs.logits[:, 0, :]``.
    """
    input_ids: Tensor = batch["input_ids"]
    attention_mask: Tensor = batch["attention_mask"]

    # Remove trailing columns that are padding for every example in one pass.
    # This replaces a per-column loop, which synced once per column and raised
    # IndexError on an all-padding batch.
    non_pad_columns = attention_mask.ne(0).any(dim=0).nonzero()
    seq_len = int(non_pad_columns[-1, 0]) + 1 if non_pad_columns.numel() > 0 else 1
    input_ids = input_ids[:, :seq_len]
    attention_mask = attention_mask[:, :seq_len]

    outputs = module(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # NOTE(review): decoder start token id is 1 here; T5 configs usually use
        # the pad id 0 as decoder_start_token_id -- confirm this is intended.
        decoder_input_ids=torch.ones(
            input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
        ),
    )
    logits = outputs.logits[:, 0, :]
    return logits
|
|
233
|
+
|
|
234
|
+
def on_test_time_adaptation_start(self):
    """
    Hook called once at the very beginning of ``test_time_adaptation``.

    The base implementation is a no-op. Override it to prepare anything the
    adaptation loop needs beforehand, for example task-specific heads.
    """
|
|
239
|
+
|
|
240
|
+
def test_time_adaptation(self, module: LayerWiseMergedModel):
    """
    Adapt the layer-wise merging weights at test time.

    An optimizer over ``module.merge_weight`` is built from ``self._optimizer``
    and set up with Fabric. Each step (``self.max_steps`` steps, or 1 in debug
    mode) then does the following:

    - With ``self.variant == "mgda"``, the gradient is produced by
      ``_compute_gradients_using_mgda``.
    - Otherwise, one batch is drawn per task and ``entropy_loss`` is applied to
      the batch-averaged logits. Each task's loss is back-propagated separately,
      so the gradients accumulate until the optimizer step.

    After the optimizer step the merged weights are recomputed with
    ``module.merge_weights()`` and loss / weight statistics are logged.

    Args:
        module (LayerWiseMergedModel): The merged model.

    Returns:
        The module returned by ``self.fabric.setup`` (wrapping the input model),
        after adaptation.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer
    optimizer = instantiate(self._optimizer, [module.merge_weight])
    module, optimizer = self.fabric.setup(module, optimizer)

    module.train()
    # build merged weights from the initial merging weights before the first forward pass
    module.merge_weights()
    for step_idx in (
        pbar := tqdm(
            range(self.max_steps if not self.is_debug_mode else 1),
            ("[DEBUG MODE] " if self.is_debug_mode else "")
            + "AdaMerging Test-time adaptation",
            dynamic_ncols=True,
        )
    ):
        if self.variant == "mgda":
            total_loss = self._compute_gradients_using_mgda(module)
        else:
            total_loss = 0
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    logits = logits.mean(dim=0, keepdim=True)
                    loss = entropy_loss(logits)
                    # Accumulate a detached value: the total is only logged. Keeping
                    # it attached would hold every task's graph (with the buffers
                    # kept by retain_graph=True) alive until the end of the step.
                    total_loss += loss.detach()
                with self.profile("backward pass"):
                    # merge_weights() runs once per step and its output is presumably
                    # shared by every task's forward pass, hence retain_graph=True.
                    self.fabric.backward(loss, retain_graph=True)

        with self.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()
        with self.profile("merging weights"):
            module.merge_weights()

        metrics = {
            # float() also copes with total_loss still being the int 0 (no tasks)
            "train/loss": float(total_loss),
            "train/weight_max": module.merge_weight.max().item(),
            "train/weight_min": module.merge_weight.min().item(),
            "train/weight_mean": module.merge_weight.mean().item(),
        }
        self.fabric.log_dict(metrics, step=step_idx)
        pbar.set_postfix(metrics)

    log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
    self.print_profile_summary()
    return module
|
|
301
|
+
|
|
302
|
+
def _compute_gradients_using_mgda(self, module: LayerWiseMergedModel):
    """
    Compute per-task gradients of the merging weights and combine them with MGDA.

    For every task, one batch is drawn and ``entropy_loss`` of the
    batch-averaged logits is differentiated w.r.t. ``module.merge_weight``.
    The flattened per-task gradients go to
    ``MinNormSolver.find_min_norm_element``. The returned coefficients weight
    the gradients, and the weighted sum is assigned to
    ``module.merge_weight.grad``, overwriting any existing gradient. With no
    tasks, the gradient is left untouched.

    Args:
        module (LayerWiseMergedModel): The (possibly Fabric-wrapped) merged model.

    Returns:
        The sum of the per-task losses, detached (used for logging), or the
        int 0 when there are no tasks.
    """
    all_grads = []
    total_loss = 0
    for task in self.modelpool.model_names:
        with self.profile("data loading"):
            batch = next(self.get_shuffled_test_loader_iter(task))
        with self.profile("forward pass"):
            logits = self.compute_logits(module, batch, task)
            logits = logits.mean(dim=0, keepdim=True)
            loss = entropy_loss(logits)
            # detached so the running total does not keep each task's graph alive
            total_loss += loss.detach()
        with self.profile("backward pass"):
            # retain_graph=True: the merged weights' graph is reused by the next task
            (task_grad,) = torch.autograd.grad(
                loss,
                [module.merge_weight],
                create_graph=False,
                retain_graph=True,
            )
            all_grads.append(task_grad.flatten().detach())

    if not all_grads:
        return total_loss

    sol, _min_norm = MinNormSolver.find_min_norm_element(all_grads)
    # accepts a tensor, numpy array or sequence of coefficients
    sol = torch.as_tensor(
        sol,
        device=module.merge_weight.device,
        dtype=module.merge_weight.dtype,
    )
    grad = torch.stack(all_grads) * sol.view(-1, 1)
    module.merge_weight.grad = grad.sum(dim=0).view_as(module.merge_weight)
    return total_loss
|