fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,484 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This implementation is largely based on the implementation from https://github.com/yule-BUAA/MergeLM/
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from typing import Dict, List
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from torch import Tensor, nn
|
|
12
|
+
from tqdm.autonotebook import tqdm
|
|
13
|
+
|
|
14
|
+
from fusion_bench.method import BaseAlgorithm
|
|
15
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
16
|
+
|
|
17
|
+
# Module-level logger named after this module.
log = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_param_names_to_merge(
    input_param_names: List[str], exclude_param_names_regex: list
) -> List[str]:
    """
    Get the names of parameters that need to be merged.

    A name is excluded when any pattern in ``exclude_param_names_regex`` matches
    at the *start* of the name (``re.match`` semantics, not ``re.search``).

    Args:
        input_param_names (List[str]): List of input parameter names.
        exclude_param_names_regex (list): List of regular expressions (strings or
            compiled patterns) for parameter names to be excluded. ``None`` is
            treated as an empty list.

    Returns:
        List[str]: Parameter names to be merged, in their original order.
    """
    # Compile the patterns once instead of looking them up for every name.
    exclude_patterns = [
        re.compile(exclude_pattern)
        for exclude_pattern in (exclude_param_names_regex or [])
    ]
    # A generator inside any() short-circuits on the first matching pattern.
    return [
        param_name
        for param_name in input_param_names
        if not any(pattern.match(param_name) for pattern in exclude_patterns)
    ]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_param_squared_gradients(
    model: nn.Module, param_names_to_merge: List[str]
) -> Dict[str, Tensor]:
    """
    Get the squared gradients of parameters.

    Reads the ``.grad`` currently stored on each selected parameter (so a backward
    pass must have been run beforehand) and returns its detached square.

    Args:
        model (nn.Module): The model.
        param_names_to_merge (List[str]): List of parameter names to be merged.

    Returns:
        Dict[str, Tensor]: Dictionary of parameter names and their squared gradients.
            A parameter whose ``.grad`` is ``None`` (it did not take part in the
            backward pass, or its gradient was reset to ``None``) gets zeros of the
            parameter's shape instead of raising.
    """
    # Set lookup keeps this O(num_params) instead of O(num_params ** 2).
    names_to_keep = set(param_names_to_merge)
    param_squared_gradients = {}
    # keep_vars=True yields the Parameter objects themselves, which carry .grad.
    for param_name, param_value in model.state_dict(keep_vars=True).items():
        if param_name not in names_to_keep:
            continue
        grad = param_value.grad
        if grad is None:
            # No gradient flowed to this parameter: its squared gradient is zero.
            param_squared_gradients[param_name] = torch.zeros_like(
                param_value.detach()
            )
        else:
            param_squared_gradients[param_name] = grad.detach() ** 2
    return param_squared_gradients
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_models_fisher_norm(
    models_to_merge_param_dict: dict, models_to_merge_fisher_weights_list: list
) -> Tensor:
    """
    Get normalization of Fisher weights of all the models that need to be merged.

    For every model, the L2 norm is taken over the Fisher weights of all merged
    parameters. It is computed in two stages: an L2 norm per parameter over all of
    that parameter's entries, then an L2 norm over those per-parameter norms.

    Args:
        models_to_merge_param_dict (dict): Dictionary of list, where key is the parameter name,
            value is a list of the corresponding parameters of all the models that need to be merged.
            Only its keys are used here.
        models_to_merge_fisher_weights_list (list): List of dictionaries with length len(models_to_merge),
            each dictionary records the Fisher weights (matrix or vector) of parameters for each model that needs to be merged.

    Returns:
        Tensor: shape (num_models_to_merge,), L2 norm over all the parameters of each model that needs to be merged.
    """
    num_models = len(models_to_merge_fisher_weights_list)
    per_param_norms = []
    for param_name in models_to_merge_param_dict:
        # Tensor, shape (num_models_to_merge, *fisher_weight_shape)
        models_fisher = torch.stack(
            [
                model_to_merge_fisher_weights[param_name]
                for model_to_merge_fisher_weights in models_to_merge_fisher_weights_list
            ],
            dim=0,
        )
        # Flatten everything except the model axis. This also covers 0-dim
        # (scalar) parameters, for which there would be no dims left to reduce.
        # Tensor, shape (num_models_to_merge, numel)
        models_fisher = models_fisher.reshape(num_models, models_fisher[0].numel())
        # Tensor, shape (num_models_to_merge, ), L2 norm for each parameter
        per_param_norms.append(torch.linalg.vector_norm(models_fisher, dim=1))

    # Tensor, shape (num_models_to_merge, num_parameters)
    models_fisher_norm = torch.stack(per_param_norms, dim=1)
    # Tensor, shape (num_models_to_merge, ), L2 norm over all the parameters
    return torch.linalg.vector_norm(models_fisher_norm, dim=1)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def merging_with_fisher_weights(
    models_to_merge_param_dict: Dict[str, List[Tensor]],
    models_to_merge_fisher_weights_list: list,
    fisher_scaling_coefficients: torch.Tensor,
    normalize_fisher_weight: bool = True,
    minimal_fisher_weight: float = 1e-6,
) -> Dict[str, Tensor]:
    """
    Merge parameters of different models with computed Fisher weights.

    For every parameter name the merged value is the element-wise weighted mean

        merged = sum_i(c_i * F_i * theta_i) / sum_i(c_i * F_i)

    where ``theta_i`` is model i's parameter and ``F_i`` its Fisher weights, offset
    by ``minimal_fisher_weight``. ``c_i`` is the model's scaling coefficient. When
    ``normalize_fisher_weight`` is set, ``c_i`` is further multiplied by the model's
    inverse Fisher norm, normalized to sum to one.

    Args:
        models_to_merge_param_dict (Dict[str, List[Tensor]]): Dictionary of list, where key is the parameter name,
            value is a list of the corresponding parameters of all the models that need to be merged.
        models_to_merge_fisher_weights_list (list): List of dictionaries with length len(models_to_merge),
            each dictionary records the Fisher weights (matrix or vector) of parameters for each model that needs to be merged.
        fisher_scaling_coefficients (torch.Tensor): Scaling coefficients to merge Fisher weights, shape (num_models_to_merge,).
        normalize_fisher_weight (bool): Whether to normalize Fisher weights (L2 norm) or not.
        minimal_fisher_weight (float): The minimal value in Fisher weights, used for tackling the potential numerical issues.

    Returns:
        Dict[str, Tensor]: Dictionary of merged parameters. They are computed under
            ``torch.no_grad()``, so they carry no autograd history.
    """
    # dict, dictionary of model parameters
    merged_params = {}

    # No gradients are needed for merging. Stacking the (trainable) parameters
    # with autograd enabled would keep a graph, and every stacked copy, alive
    # for each merged tensor.
    with torch.no_grad():
        # Per-model coefficients, shape (num_models_to_merge, ). They do not depend
        # on the parameter being merged, so they are computed once, outside the loop.
        model_coefficients = fisher_scaling_coefficients
        if normalize_fisher_weight:
            # Tensor, shape (num_models_to_merge, ), L2 norm over all the parameters of models that need to be merged
            models_fisher_norm = get_models_fisher_norm(
                models_to_merge_param_dict=models_to_merge_param_dict,
                models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
            )
            # Models with larger Fisher norm get proportionally smaller weight.
            inverse_fisher_norm = 1.0 / (models_fisher_norm + minimal_fisher_weight)
            normalized_models_fisher_norm = (
                inverse_fisher_norm / inverse_fisher_norm.sum()
            )
            model_coefficients = (
                fisher_scaling_coefficients.to(normalized_models_fisher_norm.device)
                * normalized_models_fisher_norm
            )

        for param_name, param_value_list in models_to_merge_param_dict.items():
            # shape (num_models_to_merge, *parameter_shape)
            param_values = torch.stack(param_value_list, dim=0)
            # Tensor, shape (num_models_to_merge, *fisher_weight_shape), use minimal_fisher_weight to solve the potential numerical issues
            models_to_merge_fisher_weights = (
                torch.stack(
                    [
                        model_to_merge_fisher_weights[param_name]
                        for model_to_merge_fisher_weights in models_to_merge_fisher_weights_list
                    ],
                    dim=0,
                )
                + minimal_fisher_weight
            )

            # Tensor, shape (num_models_to_merge, 1, 1, ...) for broadcasting
            broadcast_shape = (-1,) + (1,) * (param_values.dim() - 1)
            reshaped_coefficients = model_coefficients.reshape(broadcast_shape).to(
                param_values.device
            )

            weighted_fisher = reshaped_coefficients * models_to_merge_fisher_weights
            # shape (*parameter_shape)
            numerator = (weighted_fisher * param_values).sum(dim=0)
            # shape (*parameter_shape)
            denominator = weighted_fisher.sum(dim=0)

            merged_params[param_name] = numerator / denominator
    return merged_params
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _compute_fisher_weights_for_model(
    model_to_merge: nn.Module,
    trainer,
    param_names_to_merge: List[str],
    num_fisher_examples: int,
    model_idx: int,
) -> Dict[str, Tensor]:
    """
    Estimate per-parameter Fisher weights of one model from squared gradients.

    Squared gradients are accumulated batch by batch into a running sum, so only
    one extra copy of the merged parameters is held in memory. The sum is finally
    divided by the number of examples actually processed.

    Args:
        model_to_merge (nn.Module): Model whose Fisher weights are estimated.
        trainer: Object providing ``get_train_dataloader()``, ``_prepare_inputs()``
            and ``_train_batch_size`` (presumably a Hugging Face ``Trainer`` —
            confirm against callers).
        param_names_to_merge (List[str]): Parameter names to compute weights for.
        num_fisher_examples (int): Number of examples to use; iteration stops once
            at least this many examples have been processed.
        model_idx (int): Index of the model, used in progress and log messages.

    Returns:
        Dict[str, Tensor]: Parameter name -> squared gradients averaged per example.

    Raises:
        ValueError: If no batch was processed (e.g. ``num_fisher_examples <= 0`` or
            an empty dataloader).
    """
    # running sum of squared gradients, keyed by parameter name
    fisher_weights: Dict[str, Tensor] = {}
    num_computed_examples = 0

    train_dataloader = trainer.get_train_dataloader()
    if num_fisher_examples % trainer._train_batch_size != 0:
        log.warning(
            f"warning: the number of examples for computing fisher cannot be fully divided by the batch size for model {model_idx}, "
            "which may lead to a slightly different number of the actually used examples."
        )

    for inputs in tqdm(
        train_dataloader,
        desc=f"computing fisher weights for model {model_idx}",
    ):
        if num_computed_examples >= num_fisher_examples:
            break
        inputs = trainer._prepare_inputs(inputs)
        outputs = model_to_merge(**inputs)
        # Tensor, shape (batch_size, num_label_classes)
        logits = outputs.logits

        if logits.shape[-1] == 1:
            # regression task: use the label information (the model's own loss)
            # to obtain gradients
            loss = outputs.loss
        else:
            # classification task
            # use detach() to detach from the computation graph
            # Tensor, shape (batch_size, num_label_classes)
            labels_probabilities = torch.softmax(logits, dim=-1).detach()
            labels_log_probabilities = torch.log_softmax(logits, dim=-1)
            # sqrt labels_probabilities, since torch.sqrt(labels_probabilities) would be squared in the following squared gradients
            labels_expectations = (
                torch.sqrt(labels_probabilities) * labels_log_probabilities
            )
            # sum over label classes and batch dimension
            loss = labels_expectations.sum(dim=-1).sum(dim=0)

        model_to_merge.zero_grad()
        loss.backward()
        # dict, fisher weights of a batch
        batch_fisher_weights = get_param_squared_gradients(
            model=model_to_merge, param_names_to_merge=param_names_to_merge
        )
        for key, value in batch_fisher_weights.items():
            if key in fisher_weights:
                fisher_weights[key] += value
            else:
                fisher_weights[key] = value

        # Count the examples really seen: the last batch may be smaller than
        # the configured batch size.
        num_computed_examples += logits.shape[0]

    # Release the gradient memory left on the model by the last backward pass.
    model_to_merge.zero_grad(set_to_none=True)

    if num_computed_examples == 0:
        raise ValueError(
            f"no examples were used to compute fisher weights for model {model_idx}; "
            "check num_fisher_examples and the training dataloader."
        )

    # mean over examples
    for key in fisher_weights:
        fisher_weights[key] /= num_computed_examples
    return fisher_weights


def fisher_merging(
    models_to_merge: List[nn.Module],
    trainers: list,
    exclude_param_names_regex: list,
    nums_fisher_examples: List[int],
    fisher_scaling_coefficients: list = None,
    normalize_fisher_weight: bool = True,
    minimal_fisher_weight: float = 1e-6,
) -> Dict[str, Tensor]:
    """
    Fisher merging method.

    Args:
        models_to_merge (List[nn.Module]): List of individual models that need to be merged.
        trainers (list): List of trainers of individual models.
        exclude_param_names_regex (list): List of regular expressions for parameter names to be excluded.
        nums_fisher_examples (List[int]): List of numbers of examples to compute Fisher weights.
        fisher_scaling_coefficients (list, optional): Scaling coefficients to merge Fisher weights. Defaults to None,
            which weights all models equally.
        normalize_fisher_weight (bool): Whether to normalize Fisher weights (L2 norm) or not. Defaults to True.
        minimal_fisher_weight (float): The minimal value in Fisher weights, used for tackling the potential numerical issues. Defaults to 1e-6.

    Returns:
        Dict[str, Tensor]: Dictionary of merged parameters.
    """
    assert (
        len(models_to_merge) == len(trainers) == len(nums_fisher_examples)
    ), "sizes of lists are not identical!"

    # Validate / build the scaling coefficients before the expensive Fisher pass.
    # If fisher_scaling_coefficients is None, the fisher weights of different models contribute equally.
    if fisher_scaling_coefficients is None:
        fisher_scaling_coefficients = torch.ones(len(models_to_merge)) / len(
            models_to_merge
        )
    else:
        assert isinstance(
            fisher_scaling_coefficients, list
        ), "wrong type of fisher_scaling_coefficients, should be list!"
        assert len(fisher_scaling_coefficients) == len(
            models_to_merge
        ), "mismatched length of fisher_scaling_coefficients!"
        fisher_scaling_coefficients = torch.Tensor(fisher_scaling_coefficients)

    # dictionary of list, where key is the parameter name,
    # value is a list of the corresponding parameters of all the models that need to be merged
    models_to_merge_param_dict = defaultdict(list)
    # list of dictionaries with length len(models_to_merge),
    # each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
    models_to_merge_fisher_weights_list = []

    for model_idx, (model_to_merge, trainer, num_fisher_examples) in enumerate(
        zip(models_to_merge, trainers, nums_fisher_examples)
    ):
        param_dict = dict(model_to_merge.named_parameters())
        # exclude parameter whose name matches element in exclude_param_names_regex
        param_names_to_merge = get_param_names_to_merge(
            input_param_names=list(param_dict.keys()),
            exclude_param_names_regex=exclude_param_names_regex,
        )
        for param_name in param_names_to_merge:
            models_to_merge_param_dict[param_name].append(param_dict[param_name])

        models_to_merge_fisher_weights_list.append(
            _compute_fisher_weights_for_model(
                model_to_merge=model_to_merge,
                trainer=trainer,
                param_names_to_merge=param_names_to_merge,
                num_fisher_examples=num_fisher_examples,
                model_idx=model_idx,
            )
        )

    # merging with fisher weights
    return merging_with_fisher_weights(
        models_to_merge_param_dict=models_to_merge_param_dict,
        models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
        fisher_scaling_coefficients=fisher_scaling_coefficients,
        normalize_fisher_weight=normalize_fisher_weight,
        minimal_fisher_weight=minimal_fisher_weight,
    )
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def filter_state_dict(
|
|
336
|
+
state_dict: Dict[str, Tensor],
|
|
337
|
+
param_names: List[str],
|
|
338
|
+
) -> Dict[str, Tensor]:
|
|
339
|
+
"""
|
|
340
|
+
Filter the state dict with the param names.
|
|
341
|
+
|
|
342
|
+
Args:
|
|
343
|
+
state_dict (Dict[str, Tensor]): State dict of a model.
|
|
344
|
+
param_names (List[str]): List of parameter names to be filtered.
|
|
345
|
+
|
|
346
|
+
Returns:
|
|
347
|
+
Dict[str, Tensor]: Filtered state dict.
|
|
348
|
+
"""
|
|
349
|
+
filtered_state_dict = {}
|
|
350
|
+
for key in param_names:
|
|
351
|
+
filtered_state_dict[key] = state_dict[key]
|
|
352
|
+
return filtered_state_dict
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class FisherMergingAlgorithm(BaseAlgorithm):
|
|
356
|
+
"""
|
|
357
|
+
Implements the Fisher Merging Algorithm.
|
|
358
|
+
|
|
359
|
+
This class extends the BaseModelFusionAlgorithm to handle merging of models using Fisher weights.
|
|
360
|
+
It supports excluding certain parameters, normalizing Fisher weights, and setting a minimal value for Fisher weights.
|
|
361
|
+
|
|
362
|
+
Methods:
|
|
363
|
+
run(modelpool: BaseModelPool) -> nn.Module:
|
|
364
|
+
Executes the Fisher merging process on the model pool and returns the merged model.
|
|
365
|
+
"""
|
|
366
|
+
|
|
367
|
+
_config_mapping = BaseAlgorithm._config_mapping | {
|
|
368
|
+
"exclude_param_names_regex": "exclude_param_names_regex",
|
|
369
|
+
"normalize_fisher_weight": "normalize_fisher_weight",
|
|
370
|
+
"minimal_fisher_weight": "minimal_fisher_weight",
|
|
371
|
+
"num_fisher_examples": "num_fisher_examples",
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
def __init__(
|
|
375
|
+
self,
|
|
376
|
+
*,
|
|
377
|
+
exclude_param_names_regex: list,
|
|
378
|
+
normalize_fisher_weight: bool,
|
|
379
|
+
minimal_fisher_weight: float,
|
|
380
|
+
num_fisher_examples: int,
|
|
381
|
+
):
|
|
382
|
+
super().__init__()
|
|
383
|
+
self.exclude_param_names_regex = exclude_param_names_regex
|
|
384
|
+
self.normalize_fisher_weight = normalize_fisher_weight
|
|
385
|
+
self.minimal_fisher_weight = minimal_fisher_weight
|
|
386
|
+
self.num_fisher_examples = num_fisher_examples
|
|
387
|
+
|
|
388
|
+
def run(self, modelpool: BaseModelPool) -> nn.Module:
|
|
389
|
+
"""
|
|
390
|
+
Run the Fisher Merging Algorithm.
|
|
391
|
+
|
|
392
|
+
This method constructs the wrapped model and performs test-time adaptation if necessary.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.
|
|
396
|
+
|
|
397
|
+
Returns:
|
|
398
|
+
nn.Module: The merged model after test-time adaptation.
|
|
399
|
+
"""
|
|
400
|
+
log.info("Running Fisher Merging Algorithm")
|
|
401
|
+
if isinstance(modelpool, (dict, list, tuple)):
|
|
402
|
+
modelpool = BaseModelPool(modelpool)
|
|
403
|
+
|
|
404
|
+
assert len(modelpool) > 0, "model pool is empty"
|
|
405
|
+
assert (
|
|
406
|
+
modelpool.has_pretrained
|
|
407
|
+
), "no pretrained model (base model) in the model pool"
|
|
408
|
+
|
|
409
|
+
self.modelpool = modelpool
|
|
410
|
+
self.on_fisher_merging_start()
|
|
411
|
+
|
|
412
|
+
# dictionary of list, where key is the parameter name,
|
|
413
|
+
# value is a list of the corresponding parameters of all the models that need to be merged
|
|
414
|
+
models_to_merge_param_dict = defaultdict(list)
|
|
415
|
+
|
|
416
|
+
# list of dictionaries with length len(models_to_merge),
|
|
417
|
+
# each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
|
|
418
|
+
models_to_merge_fisher_weights_list = []
|
|
419
|
+
|
|
420
|
+
param_names_to_merge = None
|
|
421
|
+
|
|
422
|
+
for name, model in modelpool.named_models():
|
|
423
|
+
param_dict = model.state_dict()
|
|
424
|
+
if param_names_to_merge is None:
|
|
425
|
+
param_names_to_merge = get_param_names_to_merge(
|
|
426
|
+
input_param_names=list(param_dict.keys()),
|
|
427
|
+
exclude_param_names_regex=self.config.get(
|
|
428
|
+
"exclude_param_names_regex", []
|
|
429
|
+
),
|
|
430
|
+
)
|
|
431
|
+
|
|
432
|
+
for param_name in param_names_to_merge:
|
|
433
|
+
models_to_merge_param_dict[param_name].append(param_dict[param_name])
|
|
434
|
+
|
|
435
|
+
model_to_merge_fisher_weights = self.get_fisher_weights(
|
|
436
|
+
model_name=name,
|
|
437
|
+
model=model,
|
|
438
|
+
train_dataset=modelpool.load_train_dataset(name),
|
|
439
|
+
param_names_to_merge=param_names_to_merge,
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)
|
|
443
|
+
|
|
444
|
+
merged_params = merging_with_fisher_weights(
|
|
445
|
+
models_to_merge_param_dict=models_to_merge_param_dict,
|
|
446
|
+
models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
|
|
447
|
+
fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
|
|
448
|
+
normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
|
|
449
|
+
minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
merged_model = modelpool.load_model("_pretrained_")
|
|
453
|
+
merged_model.load_state_dict(merged_params, strict=False)
|
|
454
|
+
return merged_model
|
|
455
|
+
|
|
456
|
+
def get_fisher_weights(
|
|
457
|
+
self,
|
|
458
|
+
model_name: str,
|
|
459
|
+
model: nn.Module,
|
|
460
|
+
train_dataset,
|
|
461
|
+
param_names_to_merge: List[str],
|
|
462
|
+
) -> Dict[str, Tensor]:
|
|
463
|
+
"""
|
|
464
|
+
Compute the Fisher weights for the given model and training dataset.
|
|
465
|
+
|
|
466
|
+
Args:
|
|
467
|
+
model_name (str): The name of the model.
|
|
468
|
+
model (nn.Module): The model module.
|
|
469
|
+
train_dataset: The training dataset.
|
|
470
|
+
param_names_to_merge (List[str]): List of parameter names to merge.
|
|
471
|
+
|
|
472
|
+
Returns:
|
|
473
|
+
Dict[str, Tensor]: The computed Fisher weights for each parameter.
|
|
474
|
+
"""
|
|
475
|
+
# this function is used to compute fisher weights for a model
|
|
476
|
+
# it should be implemented in the subclass
|
|
477
|
+
raise NotImplementedError
|
|
478
|
+
|
|
479
|
+
def on_fisher_merging_start(self):
|
|
480
|
+
"""
|
|
481
|
+
Setup the zero-shot classification head before starting the Fisher merging process.
|
|
482
|
+
"""
|
|
483
|
+
# this function is used to initialize some variables before running fisher merging
|
|
484
|
+
pass
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from functools import cache
|
|
5
|
+
from typing import Dict, List, cast
|
|
6
|
+
|
|
7
|
+
import lightning as L
|
|
8
|
+
import torch
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
from torch import Tensor, nn
|
|
11
|
+
from torch.nn.modules import Module
|
|
12
|
+
from torch.utils.data import DataLoader
|
|
13
|
+
from tqdm.autonotebook import tqdm
|
|
14
|
+
from transformers import GPT2ForSequenceClassification, GPT2Model
|
|
15
|
+
from transformers.data import default_data_collator
|
|
16
|
+
from transformers.models.gpt2.modeling_gpt2 import Conv1D
|
|
17
|
+
|
|
18
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
19
|
+
from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
|
|
20
|
+
from fusion_bench.utils import timeit_context
|
|
21
|
+
|
|
22
|
+
from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FisherMergingAlgorithmForGPT2(
|
|
26
|
+
FisherMergingAlgorithm,
|
|
27
|
+
LightningFabricMixin,
|
|
28
|
+
):
|
|
29
|
+
"""
|
|
30
|
+
Implements the Fisher Merging Algorithm for GPT-2 models on text classification tasks.
|
|
31
|
+
|
|
32
|
+
This class extends the FisherMergingAlgorithm to handle GPT-2 models specifically.
|
|
33
|
+
It supports caching, batch processing, and multi-worker data loading.
|
|
34
|
+
|
|
35
|
+
Attributes:
|
|
36
|
+
classifiers (dict): A dictionary to store classifiers for each model.
|
|
37
|
+
modelpool (HuggingFaceGPT2ClassificationPool): The model pool containing the GPT-2 models.
|
|
38
|
+
cache_dir (str): Directory to cache data.
|
|
39
|
+
batch_size (int): Batch size for data loading.
|
|
40
|
+
num_workers (int): Number of workers for data loading.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
classifiers = {}
|
|
44
|
+
modelpool: GPT2ForSequenceClassificationPool = None
|
|
45
|
+
_config_mapping = FisherMergingAlgorithm._config_mapping | {
|
|
46
|
+
"cache_dir": "cache_dir",
|
|
47
|
+
"batch_size": "batch_size",
|
|
48
|
+
"num_workers": "num_workers",
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
def __init__(
|
|
52
|
+
self,
|
|
53
|
+
cache_dir: str,
|
|
54
|
+
batch_size: int,
|
|
55
|
+
num_workers: int,
|
|
56
|
+
**kwargs,
|
|
57
|
+
):
|
|
58
|
+
"""
|
|
59
|
+
Initialize the FisherMergingAlgorithmForGPT2 with the given configuration.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
cache_dir (str): Directory to cache data.
|
|
63
|
+
batch_size (int): Batch size for data loading.
|
|
64
|
+
num_workers (int): Number of workers for data loading.
|
|
65
|
+
**kwargs: Additional keyword arguments.
|
|
66
|
+
"""
|
|
67
|
+
self.cache_dir = cache_dir
|
|
68
|
+
self.batch_size = batch_size
|
|
69
|
+
self.num_workers = num_workers
|
|
70
|
+
super().__init__(**kwargs)
|
|
71
|
+
|
|
72
|
+
def on_fisher_merging_start(self):
|
|
73
|
+
"""
|
|
74
|
+
Setup the classifiers for each model in the model pool before starting the Fisher merging process.
|
|
75
|
+
"""
|
|
76
|
+
for model_name in self.modelpool.model_names:
|
|
77
|
+
classifier = cast(
|
|
78
|
+
GPT2ForSequenceClassification,
|
|
79
|
+
self.modelpool.load_classifier(model_name),
|
|
80
|
+
).requires_grad_(False)
|
|
81
|
+
classifier.transformer = None
|
|
82
|
+
classifier = classifier.to(self.fabric.device)
|
|
83
|
+
self.classifiers[model_name] = classifier
|
|
84
|
+
|
|
85
|
+
def compute_logits(self, module: GPT2Model, batch, task: str) -> Tensor:
|
|
86
|
+
"""
|
|
87
|
+
Compute the logits for the given batch and task.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
module (GPT2Model): The GPT-2 model module.
|
|
91
|
+
batch (dict): The input batch.
|
|
92
|
+
task (str): The name of the task.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Tensor: The computed logits.
|
|
96
|
+
"""
|
|
97
|
+
self.classifiers[task].transformer = module
|
|
98
|
+
input_ids = batch["input_ids"]
|
|
99
|
+
attention_mask = batch["attention_mask"]
|
|
100
|
+
|
|
101
|
+
outputs = self.classifiers[task](input_ids, attention_mask=attention_mask)
|
|
102
|
+
logits = outputs.logits
|
|
103
|
+
assert logits.dim() == 2
|
|
104
|
+
return logits
|
|
105
|
+
|
|
106
|
+
def get_fisher_weights(
|
|
107
|
+
self,
|
|
108
|
+
model_name: str,
|
|
109
|
+
model: Module,
|
|
110
|
+
train_dataset,
|
|
111
|
+
param_names_to_merge: List[str],
|
|
112
|
+
) -> Dict[str, Tensor]:
|
|
113
|
+
"""
|
|
114
|
+
Compute the Fisher weights for the given model and training dataset.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
model_name (str): The name of the model.
|
|
118
|
+
model (Module): The model module.
|
|
119
|
+
train_dataset: The training dataset.
|
|
120
|
+
param_names_to_merge (List[str]): List of parameter names to merge.
|
|
121
|
+
|
|
122
|
+
Returns:
|
|
123
|
+
Dict[str, Tensor]: The computed Fisher weights for each parameter.
|
|
124
|
+
"""
|
|
125
|
+
# setup dataloader
|
|
126
|
+
train_dataloader = DataLoader(
|
|
127
|
+
train_dataset,
|
|
128
|
+
batch_size=self.config.batch_size,
|
|
129
|
+
shuffle=True,
|
|
130
|
+
collate_fn=default_data_collator,
|
|
131
|
+
num_workers=self.config.num_workers,
|
|
132
|
+
pin_memory=True,
|
|
133
|
+
)
|
|
134
|
+
train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
|
|
135
|
+
model = self.fabric.setup(model)
|
|
136
|
+
num_fisher_examples = self.config.num_fisher_examples
|
|
137
|
+
if num_fisher_examples % train_dataloader.batch_size != 0:
|
|
138
|
+
print(
|
|
139
|
+
f"warning: the number of examples for computing fisher cannot be fully divided by the batch size for model, "
|
|
140
|
+
"which may lead to a slightly different number of the actually used examples."
|
|
141
|
+
)
|
|
142
|
+
num_computed_examples = 0
|
|
143
|
+
batches_fisher_weights_list = []
|
|
144
|
+
for step, batch in tqdm(
|
|
145
|
+
enumerate(train_dataloader),
|
|
146
|
+
desc=f"computing fisher weights",
|
|
147
|
+
total=num_fisher_examples // train_dataloader.batch_size,
|
|
148
|
+
):
|
|
149
|
+
if num_computed_examples >= num_fisher_examples:
|
|
150
|
+
break
|
|
151
|
+
logits = self.compute_logits(model, batch, model_name)
|
|
152
|
+
# Tensor, shape (batch_size, num_label_classes)
|
|
153
|
+
|
|
154
|
+
# compute fisher weights for classifxication task
|
|
155
|
+
# use detach() to detach from the computation graph
|
|
156
|
+
# Tensor, shape (batch_size, num_label_classes)
|
|
157
|
+
labels_probabilities = torch.softmax(logits, dim=-1).detach()
|
|
158
|
+
labels_log_probabilities = torch.log_softmax(logits, dim=-1)
|
|
159
|
+
# sqrt labels_probabilities, since torch.sqrt(labels_probabilities) would be squared in the following squared gradients
|
|
160
|
+
labels_expectations = (
|
|
161
|
+
torch.sqrt(labels_probabilities) * labels_log_probabilities
|
|
162
|
+
)
|
|
163
|
+
# sum over label classes and batch dimension
|
|
164
|
+
sum_labels_expectations = labels_expectations.sum(dim=-1).sum(dim=0)
|
|
165
|
+
model.zero_grad()
|
|
166
|
+
sum_labels_expectations.backward()
|
|
167
|
+
# dict, fisher weights of a batch
|
|
168
|
+
batch_fisher_weights = get_param_squared_gradients(
|
|
169
|
+
model=model, param_names_to_merge=param_names_to_merge
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
# move fisher weights to cpu to save GPU memory
|
|
173
|
+
for key, weights in batch_fisher_weights.items():
|
|
174
|
+
batch_fisher_weights[key] = weights.detach().cpu()
|
|
175
|
+
|
|
176
|
+
batches_fisher_weights_list.append(batch_fisher_weights)
|
|
177
|
+
num_computed_examples += batch["input_ids"].size(0)
|
|
178
|
+
|
|
179
|
+
model_to_merge_fisher_weights = {}
|
|
180
|
+
for batch_fisher_weights in batches_fisher_weights_list:
|
|
181
|
+
for key in batch_fisher_weights:
|
|
182
|
+
if key not in model_to_merge_fisher_weights:
|
|
183
|
+
model_to_merge_fisher_weights[key] = batch_fisher_weights[key]
|
|
184
|
+
else:
|
|
185
|
+
model_to_merge_fisher_weights[key] += batch_fisher_weights[key]
|
|
186
|
+
|
|
187
|
+
# mean over batches
|
|
188
|
+
for key in model_to_merge_fisher_weights:
|
|
189
|
+
model_to_merge_fisher_weights[key] /= num_computed_examples
|
|
190
|
+
model_to_merge_fisher_weights[key] = (
|
|
191
|
+
model_to_merge_fisher_weights[key].detach().cpu()
|
|
192
|
+
)
|
|
193
|
+
return model_to_merge_fisher_weights
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
# flake8: noqa F401
|
|
2
|
+
from .expo import ExPOAlgorithm
|
|
3
|
+
from .linear_interpolation import LinearInterpolationAlgorithm
|
|
4
|
+
from .llama_expo import ExPOAlgorithmForLlama
|
|
5
|
+
from .simple_average_for_llama import SimpleAverageForLlama
|
|
6
|
+
from .task_arithmetic_for_llama import TaskArithmeticForLlama
|