fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from collections import OrderedDict
|
|
3
|
+
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
# Model conversion utils
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def state_dict_to_vector(state_dict, remove_keys=None):
    """
    Convert a state dictionary to a flat 1-D vector.

    Entries whose key is in ``remove_keys`` are skipped; the remaining tensors
    are flattened with ``reshape(-1)`` and concatenated in ascending key order.
    ``state_dict`` itself is never modified.

    Args:
        state_dict (dict): Mapping from parameter names to tensors.
        remove_keys (Iterable[str], optional): Keys to exclude from the vector.
            Keys that are not present in ``state_dict`` are ignored.
            Defaults to None (keep every key).

    Returns:
        torch.Tensor: The concatenated 1-D vector. It is a new tensor and does
        not share memory with the tensors in ``state_dict``.
    """
    removed = set(remove_keys) if remove_keys is not None else set()
    # A shallow, filtered view of the items is enough: parameters_to_vector
    # concatenates into freshly allocated memory, so deep-copying every tensor
    # beforehand would only double peak memory.
    kept_items = sorted(
        (item for item in state_dict.items() if item[0] not in removed),
        key=lambda item: item[0],
    )
    return nn.utils.parameters_to_vector(
        [value.reshape(-1) for _, value in kept_items]
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def vector_to_state_dict(vector, state_dict, remove_keys=None):
    """
    Convert a flat vector back into a state dictionary.

    This inverts ``state_dict_to_vector``: the reference ``state_dict`` (minus
    ``remove_keys``) is ordered by key, and consecutive slices of ``vector``
    are reshaped to the shape of each reference tensor. ``state_dict`` itself
    is never modified.

    Args:
        vector (torch.Tensor): The 1-D vector to convert.
        state_dict (dict): Reference state dictionary defining the keys, the
            shapes and the order of the slices.
        remove_keys (Iterable[str], optional): Keys that were excluded when the
            vector was built. Defaults to None (no keys removed).

    Returns:
        OrderedDict: The converted state dictionary. Its tensors are views into
        ``vector`` (they share memory with it), detached from ``vector``'s
        autograd graph and carrying the reference tensor's ``requires_grad``.
        If "transformer.shared.weight" is present, every key in
        ``remove_keys`` is added back as an alias of that tensor.

    Raises:
        ValueError: If ``vector`` does not have exactly as many elements as the
            (filtered) reference tensors combined.
    """
    remove_keys = [] if remove_keys is None else list(remove_keys)
    removed = set(remove_keys)

    # Only the reference shapes are needed, so filter/sort the items instead of
    # deep-copying every tensor and then throwing the copied data away.
    references = sorted(
        (item for item in state_dict.items() if item[0] not in removed),
        key=lambda item: item[0],
    )

    expected = sum(value.numel() for _, value in references)
    if vector.numel() != expected:
        raise ValueError(
            f"vector has {vector.numel()} elements but the reference state dict "
            f"requires {expected}"
        )

    sorted_reference_dict = OrderedDict()
    pointer = 0
    for key, value in references:
        numel = value.numel()
        sorted_reference_dict[key] = (
            vector[pointer : pointer + numel]
            .view_as(value)
            .detach()
            .requires_grad_(value.requires_grad)
        )
        pointer += numel

    # add back the encoder and decoder embedding weights, which are tied to
    # the shared embedding in checkpoints that contain it.
    if "transformer.shared.weight" in sorted_reference_dict:
        for key in remove_keys:
            sorted_reference_dict[key] = sorted_reference_dict[
                "transformer.shared.weight"
            ]
    return sorted_reference_dict
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from transformers import CLIPModel, CLIPProcessor
|
|
10
|
+
from transformers.models.clip.modeling_clip import CLIPEncoder
|
|
11
|
+
|
|
12
|
+
from fusion_bench.dataset import CLIPDataset
|
|
13
|
+
from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
|
|
14
|
+
from fusion_bench.mixins import CLIPClassificationMixin
|
|
15
|
+
from fusion_bench.modelpool import CLIPVisionModelPool
|
|
16
|
+
from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
17
|
+
from fusion_bench.models.we_moe import WeightEnsemblingMoE
|
|
18
|
+
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
19
|
+
from fusion_bench.utils import timeit_context
|
|
20
|
+
from fusion_bench.utils.data import InfiniteDataLoader
|
|
21
|
+
|
|
22
|
+
from .we_moe import WeightEnsemblingMoEAlgorithm
|
|
23
|
+
|
|
24
|
+
log = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class CLIPWeightEnsemblingMoEAlgorithm(
    WeightEnsemblingMoEAlgorithm,
    CLIPClassificationMixin,
):
    """
    Weight-Ensembling MoE fusion specialised for CLIP vision models.

    The generic flow (MoE construction hook, test-time adaptation loop, run)
    comes from ``WeightEnsemblingMoEAlgorithm``. This class reads several
    attributes it does not define itself (``clip_processor``, ``fabric``,
    ``setup_zero_shot_classification_head``, ``zeroshot_weights``,
    ``visual_projection``, ``logit_scale_exp``); presumably these are supplied
    by ``CLIPClassificationMixin`` — confirm there.

    Attributes:
        modelpool (CLIPVisionModelPool): Pool holding the pretrained model
            (loaded under the name ``"_pretrained_"``) and the fine-tuned
            expert models listed in ``model_names``.
    """

    modelpool: CLIPVisionModelPool = None

    def load_checkpoint(self, model, checkpoint):
        """
        Restore ``model`` from a checkpoint file through ``self._fabric.load``.

        Args:
            model: Module whose state is restored.
            checkpoint: Path of the checkpoint file to read.
        """
        self._fabric.load(checkpoint, {"model": model})

    def save_checkpoint(self, model, checkpoint):
        """
        Write ``model`` to a checkpoint file through ``self._fabric.save``.

        Args:
            model: Module whose state is saved.
            checkpoint: Path of the checkpoint file to write.
        """
        payload = {"model": model}
        self._fabric.save(checkpoint, payload)

    def construct_moe_model(self) -> WeightEnsemblingMoE:
        """
        Build the MoE model from the pretrained and fine-tuned CLIP models.

        The non-MLP weights come from a task-arithmetic merge of the experts
        (scaled by ``config.init_lambda``) and are frozen; the MLP of every
        vision-encoder layer is then replaced by a ``WeightEnsemblingMoE``
        built from the pretrained MLP and the experts' MLPs of that layer.

        Returns:
            The merged CLIP model whose encoder MLPs are WeightEnsemblingMoE
            modules.
        """
        pool = self.modelpool
        cfg = self.config

        pretrained = pool.load_model("_pretrained_")
        finetuned = list(map(pool.load_model, pool.model_names))

        # task_arithmetic_merge updates the model it receives in place; merge
        # into a copy so `pretrained` stays untouched as the MoE base.
        merged = task_arithmetic_merge(
            deepcopy(pretrained),
            finetuned,
            scaling_factor=cfg.init_lambda,
        )
        merged.requires_grad_(False)

        pretrained_encoder: CLIPEncoder = pretrained.vision_model.encoder
        merged_encoder: CLIPEncoder = merged.vision_model.encoder
        finetuned_encoders = [model.vision_model.encoder for model in finetuned]

        for idx, pretrained_layer in enumerate(pretrained_encoder.layers):
            merged_encoder.layers[idx].mlp = WeightEnsemblingMoE(
                hidden_size=pretrained_encoder.config.hidden_size,
                base_model=pretrained_layer.mlp,
                expert_models=[enc.layers[idx].mlp for enc in finetuned_encoders],
                init_lambda=cfg.init_lambda,
                batch_first=True,  # For open_clip models this is False
                router_hidden_layers=cfg.router_hidden_layers,
                batch_reduce=cfg.batch_reduce,
            )

        return merged

    @functools.cache
    def get_shuffled_test_loader_iter(self, tta_dataset: str):
        """
        Return an iterator over shuffled batches of a test dataset.

        Because of ``functools.cache``, each (instance, dataset name) pair
        builds its loader once; later calls return the same iterator, so
        successive ``next()`` calls keep advancing through it.
        NOTE(review): ``functools.cache`` on a method keeps ``self`` alive for
        the life of the process (ruff B019).

        Args:
            tta_dataset (str): Name of the test-time adaptation dataset.

        Returns:
            Iterator over an ``InfiniteDataLoader`` wrapping the Fabric-prepared
            shuffled DataLoader.
        """
        raw_dataset = self.modelpool.load_test_dataset(tta_dataset)
        clip_dataset = CLIPDataset(raw_dataset, processor=self.clip_processor)
        log.info("get_shuffled_test_loader_iter")
        shuffled_loader = self.fabric.setup_dataloaders(
            DataLoader(
                clip_dataset,
                batch_size=self.config.batch_size,
                shuffle=True,
                num_workers=self.config.num_workers,
                pin_memory=True,
            )
        )
        return iter(InfiniteDataLoader(shuffled_loader))

    def on_test_time_adaptation_start(self):
        """
        Prepare the zero-shot classification heads before adaptation begins.
        """
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Compute zero-shot classification logits for one batch of ``task``.

        Args:
            module: Vision model; element ``[1]`` of its output is projected
                with ``visual_projection`` (presumably the pooled output).
            batch: ``(images, labels)`` pair; the labels are ignored.
            task: Name of the task whose zero-shot weights are used.

        Returns:
            Tensor: ``logit_scale_exp`` times the similarity between each
            L2-normalised image embedding and each row of
            ``zeroshot_weights[task]`` — one row per image. Assumes the text
            weights are (num_classes, dim) and already normalised; confirm.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        pooled = module(images)[1]
        image_embeds = self.visual_projection(pooled)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # text x image similarities, scaled, then transposed to image x text.
        scaled_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        return scaled_per_text.t()
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import cast # noqa: F401
|
|
4
|
+
|
|
5
|
+
import lightning as L
|
|
6
|
+
import lightning.fabric.wrappers
|
|
7
|
+
import torch
|
|
8
|
+
from lightning.pytorch.profilers import SimpleProfiler
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
from torch.utils.data import DataLoader
|
|
12
|
+
from tqdm.autonotebook import tqdm
|
|
13
|
+
|
|
14
|
+
from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
|
|
15
|
+
from fusion_bench.compat.modelpool import ModelPool
|
|
16
|
+
from fusion_bench.models.we_moe import WeightEnsemblingMoE
|
|
17
|
+
from fusion_bench.utils import timeit_context
|
|
18
|
+
from fusion_bench.utils.parameters import print_parameters
|
|
19
|
+
|
|
20
|
+
log = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Mean Shannon entropy of the softmax distribution over the last dimension.

    A constant ``1e-8`` is added inside the logarithm so that probabilities of
    exactly zero never produce ``log(0)``.

    Args:
        logits (Tensor): Unnormalised scores; softmax is taken over dim -1.

    Returns:
        Tensor: Scalar tensor — the per-row entropy averaged over all leading
        dimensions.
    """
    p = logits.softmax(dim=-1)
    log_p = (p + 1e-8).log()
    per_row = (p * log_p).sum(dim=-1)
    return -per_row.mean()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
38
|
+
"""
|
|
39
|
+
Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
|
|
40
|
+
|
|
41
|
+
This class provides methods for constructing the MoE model, performing test-time adaptation,
|
|
42
|
+
and running the fusion process.
|
|
43
|
+
|
|
44
|
+
Attributes:
|
|
45
|
+
_fabric (L.Fabric): The fabric for distributed training.
|
|
46
|
+
modelpool (ModelPool): The pool of models to be fused.
|
|
47
|
+
profiler (SimpleProfiler): The profiler for measuring performance.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
_fabric: L.Fabric = None
|
|
51
|
+
modelpool: ModelPool = None
|
|
52
|
+
|
|
53
|
+
def __init__(self, algorithm_config: DictConfig):
|
|
54
|
+
"""
|
|
55
|
+
Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
algorithm_config (DictConfig): The configuration for the algorithm.
|
|
59
|
+
"""
|
|
60
|
+
super().__init__(algorithm_config)
|
|
61
|
+
|
|
62
|
+
if self._fabric is None and torch.cuda.is_available():
|
|
63
|
+
self._fabric = L.Fabric(
|
|
64
|
+
devices=self.config.get("devices", 1),
|
|
65
|
+
)
|
|
66
|
+
self._fabric.launch()
|
|
67
|
+
else:
|
|
68
|
+
assert "No CUDA device available."
|
|
69
|
+
self.profiler = SimpleProfiler(
|
|
70
|
+
self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
@abstractmethod
|
|
74
|
+
def load_checkpoint(self, model, checkpoint):
|
|
75
|
+
"""
|
|
76
|
+
Load the checkpoint file.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
model: The model to load the checkpoint into.
|
|
80
|
+
checkpoint: The checkpoint file to load.
|
|
81
|
+
"""
|
|
82
|
+
pass
|
|
83
|
+
|
|
84
|
+
@abstractmethod
|
|
85
|
+
def save_checkpoint(self, model, checkpoint):
|
|
86
|
+
"""
|
|
87
|
+
Save the checkpoint file.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
model: The model to save the checkpoint from.
|
|
91
|
+
checkpoint: The checkpoint file to save.
|
|
92
|
+
"""
|
|
93
|
+
pass
|
|
94
|
+
|
|
95
|
+
@abstractmethod
|
|
96
|
+
def construct_moe_model(self) -> WeightEnsemblingMoE:
|
|
97
|
+
"""
|
|
98
|
+
Construct the Mixture of Experts model using the models in the model pool.
|
|
99
|
+
|
|
100
|
+
Returns:
|
|
101
|
+
WeightEnsemblingMoE: The constructed MoE model.
|
|
102
|
+
"""
|
|
103
|
+
pass
|
|
104
|
+
|
|
105
|
+
def on_test_time_adaptation_start(self):
|
|
106
|
+
"""
|
|
107
|
+
Hook method called at the start of test-time adaptation.
|
|
108
|
+
"""
|
|
109
|
+
pass
|
|
110
|
+
|
|
111
|
+
@abstractmethod
|
|
112
|
+
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
|
|
113
|
+
"""
|
|
114
|
+
Get an iterator for the shuffled test data loader for a specific task.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
task (str): The task for which to get the test data loader.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
DataLoader: The shuffled test data loader iterator.
|
|
121
|
+
"""
|
|
122
|
+
pass
|
|
123
|
+
|
|
124
|
+
@abstractmethod
|
|
125
|
+
def compute_logits(self, module, batch, task) -> Tensor:
|
|
126
|
+
"""
|
|
127
|
+
Compute the logits for a given batch and task.
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
module: The model module to use for computing logits.
|
|
131
|
+
batch: The batch of data.
|
|
132
|
+
task: The task for which to compute logits.
|
|
133
|
+
|
|
134
|
+
Returns:
|
|
135
|
+
Tensor: The computed logits.
|
|
136
|
+
"""
|
|
137
|
+
pass
|
|
138
|
+
|
|
139
|
+
def test_time_adaptation(self, module: WeightEnsemblingMoE):
|
|
140
|
+
"""
|
|
141
|
+
Perform test-time adaptation for the given module.
|
|
142
|
+
|
|
143
|
+
Args:
|
|
144
|
+
module (WeightEnsemblingMoE): The MoE module to adapt.
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
WeightEnsemblingMoE: The adapted MoE module.
|
|
148
|
+
"""
|
|
149
|
+
self.on_test_time_adaptation_start()
|
|
150
|
+
|
|
151
|
+
# configure optimizer
|
|
152
|
+
if self.config.optimizer == "adam":
|
|
153
|
+
optimizer = torch.optim.Adam(
|
|
154
|
+
[p for p in module.parameters() if p.requires_grad], lr=self.config.lr
|
|
155
|
+
)
|
|
156
|
+
else:
|
|
157
|
+
raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
|
|
158
|
+
|
|
159
|
+
if self._fabric is not None:
|
|
160
|
+
module, optimizer = self._fabric.setup(module, optimizer)
|
|
161
|
+
|
|
162
|
+
module.train()
|
|
163
|
+
|
|
164
|
+
if self.config.get("fast_dev_run", False):
|
|
165
|
+
log.info("Running fast_dev_run, only one step")
|
|
166
|
+
pbar = tqdm(
|
|
167
|
+
range(1),
|
|
168
|
+
"Test-time adaptation",
|
|
169
|
+
dynamic_ncols=True,
|
|
170
|
+
)
|
|
171
|
+
else:
|
|
172
|
+
pbar = tqdm(
|
|
173
|
+
range(self.config.max_steps),
|
|
174
|
+
"Test-time adaptation",
|
|
175
|
+
dynamic_ncols=True,
|
|
176
|
+
)
|
|
177
|
+
for step_idx in pbar:
|
|
178
|
+
if self.config.use_grad_accumulate:
|
|
179
|
+
for task in self.modelpool.model_names:
|
|
180
|
+
with self.profiler.profile("data time"):
|
|
181
|
+
batch = next(self.get_shuffled_test_loader_iter(task))
|
|
182
|
+
with self.profiler.profile("forward pass"):
|
|
183
|
+
logits = self.compute_logits(module, batch, task)
|
|
184
|
+
assert (
|
|
185
|
+
logits.dim() == 2
|
|
186
|
+
), f"Expected logits to be 2D, got {logits.dim()}"
|
|
187
|
+
loss = entropy_loss(logits)
|
|
188
|
+
# .backward() accumulates when .zero_grad() wasn't called
|
|
189
|
+
# this can save memory
|
|
190
|
+
with self.profiler.profile("backward pass"):
|
|
191
|
+
self._fabric.backward(loss, retain_graph=True)
|
|
192
|
+
else:
|
|
193
|
+
loss = 0
|
|
194
|
+
for task in self.modelpool.model_names:
|
|
195
|
+
with self.profiler.profile("data time"):
|
|
196
|
+
batch = next(self.get_shuffled_test_loader_iter(task))
|
|
197
|
+
with self.profiler.profile("forward pass"):
|
|
198
|
+
logits = self.compute_logits(module, batch, task)
|
|
199
|
+
assert (
|
|
200
|
+
logits.dim() == 2
|
|
201
|
+
), f"Expected logits to be 2D, got {logits.dim()}"
|
|
202
|
+
loss = loss + entropy_loss(logits)
|
|
203
|
+
with self.profiler.profile("backward pass"):
|
|
204
|
+
self._fabric.backward(loss, retain_graph=True)
|
|
205
|
+
|
|
206
|
+
with self.profiler.profile("optimizer step"):
|
|
207
|
+
optimizer.step()
|
|
208
|
+
optimizer.zero_grad()
|
|
209
|
+
|
|
210
|
+
return module
|
|
211
|
+
|
|
212
|
+
def run(self, modelpool: ModelPool):
    """
    Upscale the pool into a weight-ensembling MoE model and prepare it for use.

    The MoE model is built with ``construct_moe_model``. Then one of two paths runs:

    * If ``config.checkpoint`` is set, that checkpoint is loaded and test-time
      adaptation is skipped.
    * Otherwise ``test_time_adaptation`` runs. The result is optionally saved to
      ``config.save_checkpoint``, and any Lightning Fabric wrapper is removed.

    Args:
        modelpool (ModelPool): The pool of models to be fused.

    Returns:
        WeightEnsemblingMoE: The fused MoE model, with ``batch_reduce`` set to False.
    """
    log.info("Fusing models using WeightEnsembling Mixture of Experts modules.")
    self.modelpool = modelpool

    with timeit_context("upscaling models to a weight-ensembling MoE model"):
        fused = self.construct_moe_model()
        print_parameters(fused)

    if not self.config.get("checkpoint", False):
        # No checkpoint configured: adapt the routers at test time.
        with self.profiler.profile("test-time adaptation"):
            fused = self.test_time_adaptation(fused)
        if self.config.get("save_checkpoint", False):
            log.info(f"save checkpoint to {self.config.save_checkpoint}")
            self.save_checkpoint(fused, self.config.save_checkpoint)

        # Hand back the plain module rather than the Fabric wrapper, if one was applied.
        if lightning.fabric.wrappers.is_wrapped(fused):
            fused = lightning.fabric.wrappers._unwrap_objects(fused)
    else:
        log.info(
            f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
        )
        self.load_checkpoint(fused, self.config.checkpoint)

    # enable sample-wise adaptation
    fused.batch_reduce = False
    print(self.profiler.summary())
    return fused
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Mapping, Union # noqa: F401
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from fusion_bench.method import BaseAlgorithm
|
|
9
|
+
from fusion_bench.modelpool import CausalLMPool
|
|
10
|
+
from fusion_bench.utils import timeit_context
|
|
11
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
|
|
12
|
+
from fusion_bench.utils.type import StateDictType
|
|
13
|
+
|
|
14
|
+
log = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class WeightedAverageForLLama(BaseAlgorithm):
    """
    A class to perform weighted averaging of LlaMa/Mistral models.

    Each model's state dict is scaled by its weight, and the scaled state dicts
    are summed. The merged state dict is loaded into a base model:

    * the pretrained model, if the pool has one;
    * otherwise the first model in the pool.

    The merged model can optionally be saved and pushed to the hub.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
        "backbone_only": "backbone_only",
        "merged_model_save_path": "merged_model_save_path",
        "save_tokenizer": "save_tokenizer",
        "push_to_hub": "push_to_hub",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        backbone_only: bool,
        merged_model_save_path: str,
        save_tokenizer: bool,
        push_to_hub: bool,
        **kwargs,
    ):
        """
        Initialize the WeightedAverageForLLama class with the given parameters.

        Args:
            normalize (bool): Whether to rescale the weights so they sum to 1.
            weights (List[float]): The weights for averaging the models, one per
                model, in ``modelpool.model_names`` order.
            backbone_only (bool): Forwarded to ``modelpool.load_model``. When True,
                the merged state dict is loaded into the base model non-strictly.
            merged_model_save_path (str): The path to save the merged model. If
                ``None``, the merged model is not saved.
            save_tokenizer (bool): Whether to save the tokenizer.
            push_to_hub (bool): Whether to push the model to the hub.
        """
        self.normalize = normalize
        self.weights = weights
        self.backbone_only = backbone_only
        self.merged_model_save_path = merged_model_save_path
        self.save_tokenizer = save_tokenizer
        self.push_to_hub = push_to_hub
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: CausalLMPool):
        """
        Executes the weighted averaging of models in the provided model pool.

        Args:
            modelpool (CausalLMPool): The pool of models to be averaged.

        Returns:
            base_model: The base model after loading the merged state dict of
            the models in the pool.

        Raises:
            ValueError: If the number of weights does not match the number of
                models in the pool, or if the pool contains no models.
        """
        model_names = modelpool.model_names

        # Validate the configuration before loading any (potentially large) model.
        weights = self.weights
        if len(weights) != len(model_names):
            raise ValueError(
                "Number of weights must match the number of models, "
                f"but got {len(weights)} weights and {len(model_names)} models. "
                f"weights: {weights}, models: {model_names}"
            )
        if len(model_names) == 0:
            raise ValueError("The model pool contains no models to average.")
        if self.normalize:
            weights = np.asarray(weights)
            weights = weights / np.sum(weights)

        if modelpool.has_pretrained:
            base_model = modelpool.load_model("_pretrained_")
        else:
            base_model = modelpool.load_model(model_names[0])

        # Accumulate sum_i weight_i * state_dict_i.
        merged_state_dict: StateDictType = None
        for model_name, weight in zip(model_names, weights):
            model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
            sd = state_dict_mul(model.state_dict(), weight)
            if merged_state_dict is None:
                merged_state_dict = sd
            else:
                merged_state_dict = state_dict_add(merged_state_dict, sd)

        # NOTE(review): with backbone_only=True the load is non-strict, so keys
        # that do not match base_model's state dict are silently ignored;
        # confirm the backbone key names line up with base_model.
        base_model.load_state_dict(merged_state_dict, strict=not self.backbone_only)

        if self.merged_model_save_path is not None:
            with timeit_context(
                f"Saving the merged model to {self.merged_model_save_path}"
            ):
                modelpool.save_model(
                    base_model,
                    path=self.merged_model_save_path,
                    save_tokenizer=self.save_tokenizer,
                    push_to_hub=self.push_to_hub,
                )
        return base_model
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
Examples:
|
|
3
|
+
|
|
4
|
+
The following command merges eight CLIP-ViT models using a weighted average.
|
|
5
|
+
Because `method.normalize` is set to true, the weights are normalized to sum to 1, which makes this equivalent to a simple average.
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
fusion_bench \
|
|
9
|
+
method=linear/weighted_average \
|
|
10
|
+
method.normalize=true \
|
|
11
|
+
method.weights=[0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3] \
|
|
12
|
+
modelpool=... \
|
|
13
|
+
taskpool=...
|
|
14
|
+
```
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import List, Mapping, Optional, Union # noqa: F401
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import torch
|
|
22
|
+
from typing_extensions import override
|
|
23
|
+
|
|
24
|
+
from fusion_bench.method import BaseAlgorithm
|
|
25
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
26
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
27
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
|
|
28
|
+
from fusion_bench.utils.type import StateDictType
|
|
29
|
+
|
|
30
|
+
log = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class WeightedAverageAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
    """
    Merge the models of a pool by a (optionally normalized) weighted average
    of their state dicts.

    The merged parameters are loaded into the first model loaded from the pool,
    which is returned.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        verbose: bool = True,
        **kwargs,
    ):
        """
        Args:
            normalize (bool): If True, rescale ``weights`` so they sum to 1.
            weights (List[float]): One weight per model, in
                ``modelpool.model_names`` order.
            verbose (bool): If True, log progress, print the weights used and
                print the profiling summary.
        """
        self.normalize = normalize
        self.weights = weights
        self.verbose = verbose
        # Verbosity is checked per instance instead of toggling the shared
        # module-level logger, which would affect every other instance too.
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Fuses the models in the model pool using a weighted average approach.

        Parameters
            modelpool (BaseModelPool): The pool of models to be fused. Any other
                value is wrapped with ``BaseModelPool(modelpool)``.

        Raises
            ValueError: If the number of weights does not match the number of
                models in the model pool, or if the pool contains no models.

        Returns
            forward_model (torch.nn.Module): The resulting model after fusion.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        if self.verbose:
            log.info("Fusing models using weighted average.")
        weights = np.asarray(self.weights)
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models, "
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models. "
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if len(modelpool.model_names) == 0:
            raise ValueError("The model pool contains no models to merge.")
        if self.normalize:
            weights = weights / np.sum(weights)
        if self.verbose:
            print(f"weights: {weights}, normalized: {self.normalize}")

        sd: Optional[StateDictType] = None
        forward_model = None

        for model_name, weight in zip(modelpool.model_names, weights):
            with self.profile("load_model"):
                model = modelpool.load_model(model_name)
            with self.profile("merge weights"):
                if sd is None:
                    # The first model also serves as the container for the result.
                    sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                    forward_model = model
                else:
                    sd = state_dict_add(
                        sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                    )

        forward_model.load_state_dict(sd)
        if self.verbose:
            self.print_profile_summary()
        return forward_model
|
|
File without changes
|