fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, List, cast # noqa: F401
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.utils.data
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
from torch import Tensor, nn
|
|
8
|
+
from torch.nn.modules import Module
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
from tqdm.autonotebook import tqdm
|
|
11
|
+
|
|
12
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
13
|
+
from fusion_bench.mixins import CLIPClassificationMixin
|
|
14
|
+
|
|
15
|
+
from .regmean import RegMeanAlgorithm
|
|
16
|
+
|
|
17
|
+
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RegMeanAlgorithmForCLIP(
    RegMeanAlgorithm,
    CLIPClassificationMixin,
):
    """RegMean model merging for CLIP vision models.

    Extends ``RegMeanAlgorithm`` with the CLIP-specific pieces it needs: a
    zero-shot classification head (set up in ``on_regmean_start``) used to drive
    forward passes, and a dataloader over each task's training set wrapped in
    ``CLIPDataset``.
    """

    # Extend (rather than replace) the parent's mapping so the inherited RegMean
    # hyper-parameters remain part of this algorithm's serialized config,
    # consistent with RegMeanAlgorithmForGPT2.
    _config_mapping = RegMeanAlgorithm._config_mapping | {
        "_dataloader_kwargs": "dataloader_kwargs",
    }

    def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
        """
        Args:
            dataloader_kwargs: keyword arguments forwarded to
                ``torch.utils.data.DataLoader`` (``shuffle=True`` is always
                passed by this class, so it must not appear here).
            **kwargs: forwarded to the parent constructors.
        """
        super().__init__(**kwargs)
        self._dataloader_kwargs = dataloader_kwargs

    def on_regmean_start(self):
        """Prepare the zero-shot classification head before any RegMean statistics are collected."""
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task: str) -> Tensor:
        """Compute zero-shot classification logits for a batch of images.

        Args:
            module: the vision model; element ``[1]`` of its output is used as the
                image feature (presumably the pooled output of a CLIP vision
                model -- TODO confirm).
            batch: an ``(images, labels)`` pair; the labels are ignored.
            task: key into ``self.zeroshot_weights`` selecting the text embeddings.

        Returns:
            Tensor of logits with images along dim 0 and classes along dim 1
            (the transpose of ``text_embeds @ image_embeds.T``).
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def get_regmean_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset: torch.utils.data.Dataset,
        linear_modules_to_merge: Dict[str, Module],
    ):
        """Estimate the RegMean Gram matrices of one fine-tuned CLIP vision model.

        A forward hook on every module in ``linear_modules_to_merge`` records the
        running mean of ``x^T x`` over all rows ``x`` of that module's input
        (every leading dimension flattened together). Batches are drawn from
        ``train_dataset`` until the first hooked module has seen at least
        ``self.num_regmean_examples`` examples or the data is exhausted.

        Args:
            model_name: task name; selects the zero-shot head in
                ``compute_logits`` and labels the progress bar.
            model: the vision model to run; it is passed through ``self.fabric.setup``.
            train_dataset: raw training dataset, wrapped in ``CLIPDataset`` here.
            linear_modules_to_merge: module name -> submodule of ``model`` to hook.

        Returns:
            Dict[str, Tensor]: module name -> averaged ``x^T x`` matrix of shape
            ``(hidden_dim, hidden_dim)``, detached and moved to CPU.
        """
        # setup dataloader
        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
        train_dataloader = DataLoader(
            train_dataset, shuffle=True, **self._dataloader_kwargs
        )
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)

        # module name -> running mean of x^T x over every input row seen so far
        regmean_weights: Dict[str, Tensor] = {}
        # module name -> number of input rows (batch_size * sequence_length) seen so far
        num_computed_examples: Dict[str, int] = {}
        # module name -> number of actual examples (first input dimension) seen so far
        num_actual_examples: Dict[str, int] = {}

        def compute_regmean_weights(module_name: str):
            """Build a forward hook that accumulates ``x^T x`` for ``module_name``."""

            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
                # presumably (batch_size, sequence_length, hidden_dim); only the
                # last dimension is treated as the feature dimension below
                x = cast(Tensor, input[0]).detach()
                batch_num_actual_examples = x.shape[0]
                # flatten to (num_rows, hidden_dim)
                x = x.reshape(-1, x.shape[-1])
                # (hidden_dim, hidden_dim)
                xtx = torch.matmul(x.transpose(0, 1), x)
                if module_name not in regmean_weights:
                    regmean_weights[module_name] = xtx / x.shape[0]
                    num_computed_examples[module_name] = x.shape[0]
                    num_actual_examples[module_name] = batch_num_actual_examples
                else:
                    # incremental mean: fold this batch's rows into the running average
                    regmean_weights[module_name] = (
                        regmean_weights[module_name] * num_computed_examples[module_name]
                        + xtx
                    ) / (num_computed_examples[module_name] + x.shape[0])
                    num_computed_examples[module_name] += x.shape[0]
                    num_actual_examples[module_name] += batch_num_actual_examples

            return hook

        handles = [
            linear_module.register_forward_hook(
                compute_regmean_weights(module_name=module_name)
            )
            for module_name, linear_module in linear_modules_to_merge.items()
        ]
        try:
            # The forward passes exist only to feed the hooks; no gradients are
            # needed, so avoid building the autograd graph.
            with torch.no_grad():
                for batch in tqdm(
                    train_dataloader,
                    desc=f"computing regmean weights for model {model_name}",
                ):
                    if (
                        num_actual_examples
                        and next(iter(num_actual_examples.values()))
                        >= self.num_regmean_examples
                    ):
                        break
                    self.compute_logits(model, batch, model_name)
        finally:
            # Always detach the hooks, even if a forward pass raises, so the model
            # is not left with stale hooks writing into these dictionaries.
            for handle in handles:
                handle.remove()

        return {
            module_name: weights.detach().cpu()
            for module_name, weights in regmean_weights.items()
        }
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from functools import cache
|
|
5
|
+
from typing import Dict, List, cast
|
|
6
|
+
|
|
7
|
+
import lightning as L
|
|
8
|
+
import torch
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
from torch import Tensor, nn
|
|
11
|
+
from torch.nn.modules import Module
|
|
12
|
+
from torch.utils.data import DataLoader
|
|
13
|
+
from tqdm.autonotebook import tqdm
|
|
14
|
+
from transformers import GPT2ForSequenceClassification, GPT2Model
|
|
15
|
+
from transformers.data import default_data_collator
|
|
16
|
+
from transformers.models.gpt2.modeling_gpt2 import Conv1D
|
|
17
|
+
|
|
18
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
19
|
+
from fusion_bench.utils import timeit_context
|
|
20
|
+
|
|
21
|
+
from .regmean import RegMeanAlgorithm
|
|
22
|
+
|
|
23
|
+
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RegMeanAlgorithmForGPT2(
    RegMeanAlgorithm,
    LightningFabricMixin,
):
    """RegMean model merging for GPT-2 backbones.

    Hooks the ``Conv1D`` modules of each backbone, and drives forward passes
    through that task's ``GPT2ForSequenceClassification`` head, which is loaded
    in ``on_regmean_start``.
    """

    # GPT-2 implements its projections with transformers' Conv1D, not nn.Linear.
    _include_module_type = [Conv1D]
    # task name -> classification model whose backbone is swapped in per call.
    # Created per instance in __init__ so separate algorithm instances never
    # share (and overwrite) each other's heads.
    classifiers: Dict[str, GPT2ForSequenceClassification]
    _config_mapping = RegMeanAlgorithm._config_mapping | {
        "cache_dir": "cache_dir",
        "batch_size": "batch_size",
        "num_workers": "num_workers",
    }

    def __init__(self, cache_dir: str, batch_size: int, num_workers: int, **kwargs):
        """
        Args:
            cache_dir: stored and serialized as ``cache_dir``.
            batch_size: batch size of the dataloader used to collect statistics.
            num_workers: number of dataloader worker processes.
            **kwargs: forwarded to the parent constructors.
        """
        self.cache_dir = cache_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.classifiers = {}
        super().__init__(**kwargs)

    def on_regmean_start(self):
        """Load a frozen classifier for every model and move it to the fabric device.

        Each classifier's own ``transformer`` is dropped; ``compute_logits``
        plugs in the backbone being analysed instead.
        """
        for model_name in self.modelpool.model_names:
            classifier = cast(
                GPT2ForSequenceClassification,
                self.modelpool.load_classifier(model_name),
            ).requires_grad_(False)
            classifier.transformer = None
            classifier = classifier.to(self.fabric.device)
            self.classifiers[model_name] = classifier

    def compute_logits(self, module: GPT2Model, batch, task: str) -> Tensor:
        """Run ``module`` as the backbone of ``task``'s classifier and return 2-D logits.

        ``batch`` must provide ``"input_ids"`` and ``"attention_mask"`` entries.
        """
        self.classifiers[task].transformer = module
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        outputs = self.classifiers[task](input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        assert logits.dim() == 2
        return logits

    def get_regmean_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset,
        linear_modules_to_merge: Dict[str, Module],
    ):
        """Estimate the RegMean Gram matrices of one fine-tuned GPT-2 backbone.

        A forward hook on every module in ``linear_modules_to_merge`` records the
        running mean of ``x^T x`` over all rows ``x`` of that module's input
        (every leading dimension flattened together). Batches are drawn from
        ``train_dataset`` until the first hooked module has seen at least
        ``self.num_regmean_examples`` examples or the data is exhausted.

        Args:
            model_name: task name; selects the classifier in ``compute_logits``
                and labels the progress bar.
            model: the backbone to run; it is passed through ``self.fabric.setup``.
            train_dataset: dataset collated with ``default_data_collator``.
            linear_modules_to_merge: module name -> submodule of ``model`` to hook.

        Returns:
            Dict[str, Tensor]: module name -> averaged ``x^T x`` matrix of shape
            ``(hidden_dim, hidden_dim)``, detached and moved to CPU.
        """
        # Read the hyper-parameters directly: ``self.config`` would rebuild the
        # whole config object on every access (including once per loop step).
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=default_data_collator,
            pin_memory=True,
        )
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)

        # module name -> running mean of x^T x over every input row seen so far
        regmean_weights: Dict[str, Tensor] = {}
        # module name -> number of input rows (batch_size * sequence_length) seen so far
        num_computed_examples: Dict[str, int] = {}
        # module name -> number of actual examples (first input dimension) seen so far
        num_actual_examples: Dict[str, int] = {}

        def compute_regmean_weights(module_name: str):
            """Build a forward hook that accumulates ``x^T x`` for ``module_name``."""

            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
                # presumably (batch_size, sequence_length, hidden_dim); only the
                # last dimension is treated as the feature dimension below
                x = cast(Tensor, input[0]).detach()
                batch_num_actual_examples = x.shape[0]
                # flatten to (num_rows, hidden_dim)
                x = x.reshape(-1, x.shape[-1])
                # (hidden_dim, hidden_dim)
                xtx = torch.matmul(x.transpose(0, 1), x)
                if module_name not in regmean_weights:
                    regmean_weights[module_name] = xtx / x.shape[0]
                    num_computed_examples[module_name] = x.shape[0]
                    num_actual_examples[module_name] = batch_num_actual_examples
                else:
                    # incremental mean: fold this batch's rows into the running average
                    regmean_weights[module_name] = (
                        regmean_weights[module_name] * num_computed_examples[module_name]
                        + xtx
                    ) / (num_computed_examples[module_name] + x.shape[0])
                    num_computed_examples[module_name] += x.shape[0]
                    num_actual_examples[module_name] += batch_num_actual_examples

            return hook

        handles = [
            linear_module.register_forward_hook(
                compute_regmean_weights(module_name=module_name)
            )
            for module_name, linear_module in linear_modules_to_merge.items()
        ]
        try:
            # The forward passes exist only to feed the hooks; no gradients are
            # needed, so avoid building the autograd graph.
            with torch.no_grad():
                for batch in tqdm(
                    train_dataloader,
                    desc=f"computing regmean weights for model {model_name}",
                ):
                    if (
                        num_actual_examples
                        and next(iter(num_actual_examples.values()))
                        >= self.num_regmean_examples
                    ):
                        break
                    self.compute_logits(model, batch, model_name)
        finally:
            # Always detach the hooks, even if a forward pass raises, so the model
            # is not left with stale hooks writing into these dictionaries.
            for handle in handles:
                handle.remove()

        return {
            module_name: weights.detach().cpu()
            for module_name, weights in regmean_weights.items()
        }
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file contains the implementation of the regmean method for model merging.
|
|
3
|
+
modified from https://github.com/yule-BUAA/MergeLM/blob/6d49ad96fd69c92013654b837041b868aa806564/model_merging_methods/merging_methods.py
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import re
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from typing import Dict, List, cast
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from torch import Tensor, nn
|
|
13
|
+
from tqdm.autonotebook import tqdm
|
|
14
|
+
|
|
15
|
+
from fusion_bench.method import BaseAlgorithm
|
|
16
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_param_names_to_merge(
    input_param_names: List[str], exclude_param_names_regex: list
):
    """
    Get the names of parameters that need to be merged.

    A name is excluded when any pattern in ``exclude_param_names_regex`` matches
    at the *start* of the name (``re.match`` semantics, not ``re.search``).

    :param input_param_names: list, names of input parameters
    :param exclude_param_names_regex: list, regular expressions of names of parameters
        that need to be excluded; ``None`` or an empty list excludes nothing
    :return: list of the non-excluded parameter names, in their original order
    """
    # compile once instead of re-resolving every pattern for every name
    exclude_patterns = [
        re.compile(pattern) for pattern in (exclude_param_names_regex or ())
    ]
    return [
        param_name
        for param_name in input_param_names
        # generator (not a list) so the check short-circuits on the first match
        if not any(pattern.match(param_name) for pattern in exclude_patterns)
    ]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_modules_to_merge(model: nn.Module, include_module_types: list):
    """
    Get the model modules that need to be merged, whose type is in include_module_types.

    :param model: nn.Module, input model
    :param include_module_types: list, module types to include; when empty or
        ``None`` every module from ``model.named_modules()`` is returned,
        including the root module under the name ``""``
    :return: dict mapping module name to module, in ``named_modules()`` order
    """
    if not include_module_types:
        modules_to_merge: Dict[str, nn.Module] = dict(model.named_modules())
        return modules_to_merge
    # isinstance accepts a tuple of types, replacing the per-type any([...]) loop
    type_filter = tuple(include_module_types)
    modules_to_merge = {
        module_name: module
        for module_name, module in model.named_modules()
        if isinstance(module, type_filter)
    }
    return modules_to_merge
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def reduce_non_diagonal_elements(
    regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
):
    """
    Scale the off-diagonal entries of a square RegMean matrix.

    The diagonal is kept as is, while every off-diagonal entry is multiplied
    by ``reduce_non_diagonal_ratio``. The scaling mask is built in the default
    floating dtype, so the product follows normal tensor type promotion.

    :param regmean_weights: Tensor, shape (hidden_dim, hidden_dim), input regmean weights
    :param reduce_non_diagonal_ratio: float, factor applied to the non-diagonal elements
    :return: Tensor, the element-wise product of regmean_weights and the scaling mask
    """
    size = regmean_weights.shape[0]
    target_device = regmean_weights.device
    # (1 - ratio) on the diagonal, zeros elsewhere
    diagonal_part = torch.diag(torch.ones(size) - reduce_non_diagonal_ratio).to(
        target_device
    )
    # ratio everywhere; adding it brings the diagonal back to 1
    everywhere_part = torch.full_like(diagonal_part, reduce_non_diagonal_ratio)
    scaling_mask = diagonal_part + everywhere_part
    return regmean_weights * scaling_mask
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def merging_with_regmean_weights(
    models_to_merge_param_dict: dict,
    models_to_merge_regmean_weights_list: list,
    reduce_non_diagonal_ratio: float = 1.0,
    weight_transpose: bool = True,
):
    """
    Merge parameters of different models with computed RegMean weights.

    A parameter named ``<module>.weight``, where ``<module>`` is a key of the FIRST
    model's RegMean dictionary, is merged by solving

        (sum_i G_i) @ W = sum_i G_i @ W_i

    for ``W``. Here ``G_i`` is model i's RegMean matrix for that module (after
    ``reduce_non_diagonal_elements``), and ``W_i`` is model i's parameter (transposed
    first when ``weight_transpose`` is True). All other parameters are averaged
    element-wise.

    :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter
        name, value is a list of the corresponding parameters of all the models that need
        to be merged
    :param models_to_merge_regmean_weights_list: list, list of dictionaries with length
        len(models_to_merge); each dictionary records the regmean weights (matrix) of
        parameters for each model that needs to be merged, key is module name
    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean
        weights by multiplying this scalar
    :param weight_transpose: bool, if True each weight is transposed before being
        multiplied by its RegMean matrix and the merged result is transposed back
        (Linear weights are stored as (output_size, input_size))
    :return: dict, parameter name -> merged tensor
    """
    merged_params = {}

    for param_name, param_value_list in models_to_merge_param_dict.items():
        module_name = (
            param_name[: -len(".weight")] if param_name.endswith(".weight") else None
        )
        # only "weight" parameters of modules that have RegMean matrices are merged by
        # RegMean; the first model's dictionary decides which modules qualify
        if (
            module_name is None
            or module_name not in models_to_merge_regmean_weights_list[0]
        ):
            # use average merging for parameters that are not handled by RegMean
            merged_params[param_name] = torch.stack(param_value_list, dim=0).mean(dim=0)
            continue

        # two lists with length num_models_to_merge
        module_regmean_weights_list, param_multiplied_results = [], []
        for model_idx, model_to_merge_regmean_weights in enumerate(
            models_to_merge_regmean_weights_list
        ):
            module_regmean_weights = reduce_non_diagonal_elements(
                regmean_weights=model_to_merge_regmean_weights[module_name],
                reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
            )
            module_regmean_weights_list.append(module_regmean_weights)

            model_to_merge_param = param_value_list[model_idx]
            # the Linear weight shape is (output_size, input_size); transpose it so the
            # (input_size, input_size) RegMean matrix can left-multiply it
            if weight_transpose:
                model_to_merge_param = model_to_merge_param.transpose(0, 1)
            param_multiplied_results.append(
                torch.matmul(module_regmean_weights, model_to_merge_param)
            )

        # sum up over all individual models
        sum_module_regmean_weights = sum(module_regmean_weights_list)
        sum_param_multiplied_results = sum(param_multiplied_results)

        # Solve the linear system rather than forming the explicit inverse: it is
        # cheaper and numerically more accurate, with the same solution as
        # inverse(sum_G) @ sum_GW for a nonsingular system.
        merged_param = torch.linalg.solve(
            sum_module_regmean_weights, sum_param_multiplied_results
        )
        # transpose back to the original shape of "weight" in Linear module
        merged_params[param_name] = (
            merged_param.transpose(0, 1) if weight_transpose else merged_param
        )

    return merged_params
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def regmean_merging(
    models_to_merge: list,
    trainers: list,
    exclude_param_names_regex: list,
    nums_regmean_examples: list,
    reduce_non_diagonal_ratio: float = 1.0,
):
    """
    RegMean merging method.

    For each model, forward hooks on its ``nn.Linear`` modules keep a running mean of
    ``x^T x`` over the (flattened) module inputs. Batches from the matching trainer's
    train dataloader are fed through the model until at least ``num_regmean_examples``
    examples have been seen (the check happens once per batch). The collected
    parameters are then merged with ``merging_with_regmean_weights``.

    :param models_to_merge: list, individual models that need to be merged
    :param trainers: list, trainers of individual models; ``get_train_dataloader()``,
        ``_prepare_inputs()`` and ``_train_batch_size`` are used
    :param exclude_param_names_regex: list, regular expression of names of parameters
        that need to be excluded
    :param nums_regmean_examples: list, numbers of examples to compute regmean weights
    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean
        weights by multiplying this scalar
    :return: dict, parameter name -> merged tensor
    """

    def compute_regmean_weights(
        module_name: str,
        regmean_weights: dict,
        num_computed_examples: dict,
        num_actual_examples: dict,
    ):
        """
        Build a forward hook that accumulates RegMean statistics for one module.

        :param module_name: str, module name used as key in the state dicts
        :param regmean_weights: dict, module name -> running mean of x^T x
        :param num_computed_examples: dict, module name -> number of flattened input rows seen
        :param num_actual_examples: dict, module name -> number of examples (first input dim) seen
        :return: hook callable for ``register_forward_hook``
        """

        def hook(module: nn.Module, module_inputs: tuple, output: torch.Tensor):
            # first positional input; presumably (batch_size, sequence_length, hidden_dim)
            x = module_inputs[0].detach()
            batch_num_actual_examples = x.shape[0]
            # all leading dims flattened: (num_rows, hidden_dim)
            x = x.reshape(-1, x.shape[-1])
            num_rows = x.shape[0]
            # Tensor, shape (hidden_dim, hidden_dim)
            xtx = torch.matmul(x.transpose(0, 1), x)
            if module_name not in regmean_weights:
                regmean_weights[module_name] = xtx / num_rows
                num_computed_examples[module_name] = num_rows
                num_actual_examples[module_name] = batch_num_actual_examples
            else:
                # running mean of x^T x over every row seen so far
                regmean_weights[module_name] = (
                    regmean_weights[module_name] * num_computed_examples[module_name]
                    + xtx
                ) / (num_computed_examples[module_name] + num_rows)
                num_computed_examples[module_name] += num_rows
                num_actual_examples[module_name] += batch_num_actual_examples

        return hook

    # dictionary of list, where key is the parameter name,
    # value is a list of the corresponding parameters of all the models that need to be merged
    models_to_merge_param_dict = defaultdict(list)

    # list of dictionaries with length len(models_to_merge),
    # each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged
    models_to_merge_regmean_weights_list = []

    with torch.no_grad():
        for model_idx, (model_to_merge, trainer, num_regmean_examples) in enumerate(
            zip(models_to_merge, trainers, nums_regmean_examples)
        ):
            param_dict = dict(model_to_merge.named_parameters())
            # exclude parameter whose name matches element in exclude_param_names_regex
            param_names_to_merge = get_param_names_to_merge(
                input_param_names=list(param_dict.keys()),
                exclude_param_names_regex=exclude_param_names_regex,
            )
            for param_name in param_names_to_merge:
                models_to_merge_param_dict[param_name].append(param_dict[param_name])

            linear_modules_to_merge = get_modules_to_merge(
                model=model_to_merge, include_module_types=[nn.Linear]
            )
            # regmean matrices for each linear module's inputs
            regmean_weights = {}
            # number of input rows (examples multiplied by sequence length) used so far
            num_computed_examples = {}
            # number of actual examples used so far
            num_actual_examples = {}

            handles = [
                linear_module_to_merge.register_forward_hook(
                    compute_regmean_weights(
                        module_name,
                        regmean_weights,
                        num_computed_examples,
                        num_actual_examples,
                    )
                )
                for module_name, linear_module_to_merge in linear_modules_to_merge.items()
            ]
            try:
                train_dataloader = trainer.get_train_dataloader()
                if num_regmean_examples % trainer._train_batch_size != 0:
                    print(
                        f"warning: the number of examples for computing regmean cannot be fully divided by the batch size for model {model_idx}, "
                        "which may lead to a slightly different number of the actually used examples."
                    )
                for _, inputs in tqdm(
                    enumerate(train_dataloader),
                    desc=f"computing regmean weights for model {model_idx}",
                ):
                    # every hooked module sees the same batches, so any one counter suffices
                    if (
                        num_actual_examples
                        and next(iter(num_actual_examples.values()))
                        >= num_regmean_examples
                    ):
                        break
                    inputs = trainer._prepare_inputs(inputs)
                    # run the forward pass only for the side effects of the hooks
                    model_to_merge(**inputs)
            finally:
                # always remove the added hooks, even if a forward pass fails
                for handle in handles:
                    handle.remove()

            models_to_merge_regmean_weights_list.append(regmean_weights)

    # merging with regmean weights
    merged_params = merging_with_regmean_weights(
        models_to_merge_param_dict=models_to_merge_param_dict,
        models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
        reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
    )

    return merged_params
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class RegMeanAlgorithm(BaseAlgorithm):
    """
    RegMean model merging.

    ``run`` gathers, for every model in the pool, the parameters to merge and the RegMean
    matrices of its ``nn.Linear`` modules. The matrices come from ``get_regmean_weights``,
    which subclasses must implement. ``run`` then merges everything with
    ``merging_with_regmean_weights`` and loads the result into the pool's
    ``"_pretrained_"`` model.
    """

    # module types whose inputs are used to build RegMean matrices
    _include_module_type = [nn.Linear]
    # attribute name -> config key
    _config_mapping = {
        "num_regmean_examples": "num_regmean_examples",
        "exclude_param_names_regex": "exclude_param_names_regex",
        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
        "weight_transpose": "weight_transpose",
    }

    def __init__(
        self,
        *,
        num_regmean_examples: int,
        exclude_param_names_regex: list,
        reduce_non_diagonal_ratio: float,
        weight_transpose: bool,
        **kwargs,
    ):
        """
        :param num_regmean_examples: int, number of examples used to compute RegMean weights
        :param exclude_param_names_regex: list, regexes of parameter names excluded from merging
        :param reduce_non_diagonal_ratio: float, scale applied to off-diagonal RegMean entries
        :param weight_transpose: bool, whether weights are transposed before RegMean merging
        :param kwargs: forwarded to the base class
        """
        self.num_regmean_examples = num_regmean_examples
        self.exclude_param_names_regex = exclude_param_names_regex
        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
        self.weight_transpose = weight_transpose
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool, **kwargs):
        """
        Merge all models in ``modelpool`` with RegMean.

        :param modelpool: BaseModelPool, or a value accepted by ``BaseModelPool(...)``
        :return: the ``"_pretrained_"`` model loaded (non-strictly) with the merged parameters
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)
        self.modelpool = modelpool
        self.on_regmean_start()

        # dictionary of list, where key is the parameter name,
        # value is a list of the corresponding parameters of all the models that need to be merged
        models_to_merge_param_dict = defaultdict(list)

        # list of dictionaries with length len(models_to_merge),
        # each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged
        models_to_merge_regmean_weights_list = []

        # computed once from the first model and reused for all models
        param_names_to_merge = None

        with torch.no_grad():
            for name, model in modelpool.named_models():
                param_dict = model.state_dict()

                # exclude parameter whose name matches element in exclude_param_names_regex
                if param_names_to_merge is None:
                    param_names_to_merge = get_param_names_to_merge(
                        input_param_names=list(param_dict.keys()),
                        # read the attribute set in __init__ (consistent with
                        # reduce_non_diagonal_ratio below); None means no exclusions
                        exclude_param_names_regex=self.exclude_param_names_regex or [],
                    )

                for param_name in param_names_to_merge:
                    models_to_merge_param_dict[param_name].append(
                        param_dict[param_name]
                    )

                linear_modules_to_merge = get_modules_to_merge(
                    model=model, include_module_types=self._include_module_type
                )
                assert len(linear_modules_to_merge) > 0, "No linear modules to merge"

                regmean_weights = self.get_regmean_weights(
                    name,
                    model,
                    train_dataset=modelpool.load_train_dataset(name),
                    linear_modules_to_merge=linear_modules_to_merge,
                )
                models_to_merge_regmean_weights_list.append(regmean_weights)

            # merging with regmean weights
            merged_params = merging_with_regmean_weights(
                models_to_merge_param_dict=models_to_merge_param_dict,
                models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
                reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
                weight_transpose=self.weight_transpose,
            )

        merged_model = modelpool.load_model("_pretrained_")
        # strict=False: parameters excluded from merging keep their pretrained values
        merged_model.load_state_dict(merged_params, strict=False)
        return merged_model

    def on_regmean_start(self):
        """Hook called at the start of ``run`` after the model pool is set; no-op by default."""
        pass

    def get_regmean_weights(
        self,
        model_name: str,
        model: nn.Module,
        train_dataset,
        linear_modules_to_merge: Dict[str, nn.Module],
    ):
        """
        Compute the RegMean matrices for one model (to be implemented by subclasses).

        The result is consumed by ``merging_with_regmean_weights``, which looks matrices
        up by module name. It is therefore expected to be a dict mapping each name in
        ``linear_modules_to_merge`` to a square matrix over that module's input features.

        :param model_name: str, name of the model in the pool
        :param model: nn.Module, the model to analyse
        :param train_dataset: dataset returned by ``modelpool.load_train_dataset(model_name)``
        :param linear_modules_to_merge: Dict[str, nn.Module], modules whose inputs are used
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
|