fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from typing import Dict, List, Mapping, Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
9
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
10
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
11
|
+
from fusion_bench.utils.state_dict_arithmetic import (
|
|
12
|
+
state_dict_add,
|
|
13
|
+
state_dict_avg,
|
|
14
|
+
state_dict_div,
|
|
15
|
+
state_dict_mul,
|
|
16
|
+
)
|
|
17
|
+
from fusion_bench.utils.type import StateDictType
|
|
18
|
+
|
|
19
|
+
log = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def simple_average(
    modules: List[Union[nn.Module, StateDictType]],
    base_module: Optional[nn.Module] = None,
):
    R"""
    Averages the parameters of a list of PyTorch modules or state dictionaries.

    This function takes a list of PyTorch modules or state dictionaries and returns a new module
    with the averaged parameters, or a new state dictionary with the averaged parameters.
    The kind of every element is decided by inspecting ``modules[0]``.

    Args:
        modules (List[Union[nn.Module, StateDictType]]): A non-empty list of PyTorch modules or
            state dictionaries.
        base_module (Optional[nn.Module]): A module to receive the averaged parameters. If provided,
            it is modified in place (``load_state_dict``) and returned. If not provided, a deep copy
            of the first module in the list is used instead. Ignored for state-dict input.

    Returns:
        module_or_state_dict (Union[nn.Module, StateDictType]): A PyTorch module with the averaged
        parameters, or a new state dictionary with the averaged parameters.

    Raises:
        ValueError: If ``modules`` is empty.
        TypeError: If the elements of ``modules`` are neither ``nn.Module`` instances nor mappings.

    Examples:
        >>> import torch.nn as nn
        >>> model1 = nn.Linear(10, 10)
        >>> model2 = nn.Linear(10, 10)
        >>> averaged_model = simple_average([model1, model2])

        >>> state_dict1 = model1.state_dict()
        >>> state_dict2 = model2.state_dict()
        >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
    """
    if not modules:
        raise ValueError("`modules` must contain at least one module or state dict.")

    first = modules[0]
    if isinstance(first, nn.Module):
        new_module = deepcopy(first) if base_module is None else base_module
        state_dict = state_dict_avg([module.state_dict() for module in modules])
        new_module.load_state_dict(state_dict)
        return new_module
    if isinstance(first, Mapping):
        return state_dict_avg(modules)

    # Previously this fell through and silently returned None.
    raise TypeError(
        "`modules` must contain nn.Module instances or state dicts (mappings), "
        f"got {type(first).__name__}."
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class SimpleAverageAlgorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
):
    """Merge all models of a model pool by taking the element-wise mean of their state dicts."""

    @torch.no_grad()
    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
        """
        Fuse the models in the given model pool using simple averaging.

        Models are loaded one at a time and their state dictionaries are summed into a
        running accumulator. The accumulated sum is then divided by the number of merged
        models. The average is loaded into the first loaded model, which is returned.

        Args:
            modelpool: The pool of models to fuse. A plain ``dict`` mapping names to
                modules is wrapped in a ``BaseModelPool`` first.

        Returns:
            The fused model obtained by simple averaging.

        Raises:
            ValueError: If the model pool contains no models.
        """
        if isinstance(modelpool, dict):
            modelpool = BaseModelPool(modelpool)

        model_names = modelpool.model_names
        if len(model_names) == 0:
            raise ValueError("The model pool is empty; there are no models to average.")

        log.info(
            f"Fusing models using simple average on {len(model_names)} models. "
            f"models: {model_names}"
        )
        sd: Optional[StateDictType] = None
        forward_model = None
        merged_model_names = []

        for model_name in model_names:
            with self.profile("load model"):
                model = modelpool.load_model(model_name)
                merged_model_names.append(model_name)
                log.info(f"load model of type: {type(model).__name__}")
            with self.profile("merge weights"):
                if sd is None:
                    # The first model seeds the accumulator and also serves as the
                    # container the averaged weights are finally loaded into.
                    sd = model.state_dict(keep_vars=True)
                    forward_model = model
                else:
                    # Add the current model's state dictionary to the running sum.
                    sd = state_dict_add(sd, model.state_dict(keep_vars=True))
        with self.profile("merge weights"):
            # Divide the accumulated sum by the number of merged models to get the mean.
            sd = state_dict_div(sd, len(merged_model_names))

        forward_model.load_state_dict(sd)
        # print profile report and log the merged models
        self.print_profile_summary()
        log.info(f"merged {len(merged_model_names)} models:")
        for model_name in merged_model_names:
            log.info(f" - {model_name}")
        return forward_model
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from typing_extensions import override
|
|
5
|
+
|
|
6
|
+
from fusion_bench.method import BaseAlgorithm
|
|
7
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
8
|
+
|
|
9
|
+
from .slerp_utils import slerp
|
|
10
|
+
|
|
11
|
+
log = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def slerp_on_state_dicts(
    t,
    primary_state_dict,
    secondary_state_dict,
    *,
    DOT_THRESHOLD: float = 0.9995,
    epsilon: float = 1e-8,
):
    """
    Perform spherical linear interpolation (slerp) on the state dictionaries of two models.

    The result contains exactly the keys of ``primary_state_dict``, so it can be loaded
    back into the primary model with a strict ``load_state_dict``. An entry keeps its
    primary value (with a warning) in two cases: the key is missing from
    ``secondary_state_dict``, or the two tensors have different shapes. Keys that exist
    only in ``secondary_state_dict`` are ignored, with a warning.

    Args:
        t (float): The interpolation factor, typically between 0 and 1.
        primary_state_dict (dict): The state dictionary of the primary model.
        secondary_state_dict (dict): The state dictionary of the secondary model.
        DOT_THRESHOLD (float, optional): Threshold for considering the vectors as collinear. Defaults to 0.9995.
        epsilon (float, optional): Small value to avoid division by zero. Defaults to 1e-8.

    Returns:
        dict: The interpolated state dictionary.
    """
    state_dict = {}
    for key, v0 in primary_state_dict.items():
        if key not in secondary_state_dict:
            log.warning(
                f"Key {key} is missing from the secondary state dict. Base model parameters will be used."
            )
            state_dict[key] = v0
            continue
        v1 = secondary_state_dict[key]
        if v0.shape != v1.shape:
            log.warning(
                f"Skipping key {key} because the shapes of the tensors are different: {v0.shape} vs {v1.shape}. Base model parameters will be used."
            )
            state_dict[key] = v0
        else:
            state_dict[key] = slerp(t, v0, v1, DOT_THRESHOLD, epsilon)

    extra_keys = [key for key in secondary_state_dict if key not in primary_state_dict]
    if extra_keys:
        log.warning(
            f"Ignoring {len(extra_keys)} key(s) that only exist in the secondary state dict: {extra_keys}"
        )
    return state_dict
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SlerpMergeAlgorithm(BaseAlgorithm):
    """
    General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.

    Exactly two models are merged. Each state-dict entry is interpolated with slerp and
    loaded back into the first (primary) model, which is returned.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "t": "t",
        "DOT_THRESHOLD": "DOT_THRESHOLD",
        "epsilon": "epsilon",
    }

    def __init__(self, t: float, DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8):
        """
        Initialize the SlerpMergeAlgorithm.

        Args:
            t (float): The interpolation parameter, typically in the range [0, 1]
                (0 keeps the primary model, 1 gives the secondary). It is not
                validated; values outside the range extrapolate.
            DOT_THRESHOLD (float, optional): If the absolute cosine similarity of two tensors exceeds
                this value, they are treated as collinear and linearly interpolated. Defaults to 0.9995.
            epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
        """
        self.t = t
        self.DOT_THRESHOLD = DOT_THRESHOLD
        self.epsilon = epsilon
        super().__init__()

    @override
    def run(self, modelpool: BaseModelPool):
        """
        Run the SlerpMergeAlgorithm on the given model pool.

        Args:
            modelpool (BaseModelPool): The pool of models to fuse. It must contain exactly
                two models; the first one is the primary model and receives the result.

        Returns:
            nn.Module: The fused model.

        Raises:
            ValueError: If the model pool does not contain exactly two models.
        """
        model_names = modelpool.all_model_names
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if len(model_names) != 2:
            raise ValueError(
                f"Slerp expects exactly 2 models, but the model pool contains {len(model_names)}: {model_names}"
            )
        primary_model = modelpool.load_model(model_names[0])
        secondary_model = modelpool.load_model(model_names[1])

        with torch.no_grad():
            state_dict = slerp_on_state_dicts(
                self.t,
                primary_model.state_dict(),
                secondary_model.state_dict(),
                DOT_THRESHOLD=self.DOT_THRESHOLD,
                epsilon=self.epsilon,
            )
        # The secondary model is no longer needed once its weights have been interpolated.
        del secondary_model

        primary_model.load_state_dict(state_dict)
        return primary_model
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Modification of: https://github.com/Digitous/LLM-SLERP-Merge/blob/main/slerpmergelm.py
|
|
2
|
+
# LLM HF SLERP Merge
|
|
3
|
+
|
|
4
|
+
# Retrofitted from dvschultz's script at https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
|
|
5
|
+
# to work with Huggingface Pretrained Language Models [by Chasm (AKA Digitous) and CalderaAI (on HuggingFace)].
|
|
6
|
+
# Original language model linear interpolation methods pioneered by Concedo AKA LostRuins on Github and HF.
|
|
7
|
+
|
|
8
|
+
# Idea for SLERP on LLMs sparked by discussion in Automatic1111 Stable Diffusion UI feature request for SLERP
|
|
9
|
+
# model merging for image diffusion domain models.
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from typing import TypeVar
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
log = logging.getLogger(__name__)
|
|
18
|
+
T = TypeVar("T", torch.Tensor, np.ndarray, float)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def lerp(t: float, v0: T, v1: T) -> T:
    """
    Linearly blend ``v0`` and ``v1``.

    At ``t == 0`` the result equals ``v0`` and at ``t == 1`` it equals ``v1``.
    Values of ``t`` outside [0, 1] extrapolate along the same line.

    Args:
        t (float): Blend weight given to ``v1``.
        v0 (T): Start point (tensor, array, or float).
        v1 (T): End point, of the same kind as ``v0``.

    Returns:
        T: ``(1 - t) * v0 + t * v1``.
    """
    weight_of_start = 1 - t
    return weight_of_start * v0 + t * v1
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def normalize(v: torch.Tensor, epsilon: float) -> torch.Tensor:
    """
    Scale a tensor to unit (Frobenius/2-) norm over all of its elements.

    Tensors whose norm does not exceed ``epsilon`` are returned unchanged, and a
    debug message is logged, to avoid dividing by (nearly) zero.

    Args:
        v (torch.Tensor): Tensor to rescale.
        epsilon (float): Norm threshold at or below which normalization is skipped.

    Returns:
        torch.Tensor: ``v / ||v||``, or ``v`` itself when the norm is too small.
    """
    magnitude = torch.linalg.norm(v)
    if magnitude > epsilon:
        return v / magnitude
    log.debug(f"Warning: Norm of v is very small ({magnitude}). Skipping normalization.")
    return v
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def slerp(
    t: float,
    v0: torch.Tensor,
    v1: torch.Tensor,
    DOT_THRESHOLD=0.9995,
    epsilon=1e-8,
):
    """
    Performs spherical linear interpolation (slerp) between two tensors v0 and v1.

    Each tensor is treated as one flat vector. If the two directions are nearly
    collinear (``|cos| > DOT_THRESHOLD``), plain linear interpolation is used instead.

    Args:
        t (float): The interpolation factor, typically between 0 and 1.
        v0 (torch.Tensor): The starting tensor.
        v1 (torch.Tensor): The ending tensor (same shape as ``v0``).
        DOT_THRESHOLD (float, optional): Threshold for considering the vectors as collinear. Defaults to 0.9995.
        epsilon (float, optional): Small value to avoid division by zero. Defaults to 1e-8.

    Returns:
        torch.Tensor: The interpolated tensor, on ``v0``'s device. Inputs that are not
        float32/float64 are upcast to float32 first, so the result is at least float32.
    """
    device = v0.device
    # Compute in at least float32; each input is upcast independently.
    if v0.dtype not in (torch.float32, torch.float64):
        v0 = v0.to(dtype=torch.float32)
    if v1.dtype not in (torch.float32, torch.float64):
        v1 = v1.to(dtype=torch.float32)

    # `normalize` returns new tensors and never mutates its input, so v0/v1 stay intact
    # for the final weighted sum without cloning.
    v0_unit = normalize(v0, epsilon)
    v1_unit = normalize(v1, epsilon)

    # Cosine of the angle between the two flattened directions.
    dot = torch.sum(v0_unit * v1_unit)

    # If the absolute dot product is almost 1, the vectors are ~collinear, so use lerp.
    if torch.abs(dot) > DOT_THRESHOLD:
        res = lerp(t, v0, v1)
    else:
        # Clamp guards arccos against round-off pushing |dot| slightly past 1.
        theta_0 = torch.arccos(torch.clamp(dot, -1.0, 1.0))
        # torch.sin (not np.sin) so this also works for CUDA tensors.
        sin_theta_0 = torch.sin(theta_0)
        # Angle at timestep t
        theta_t = theta_0 * t
        sin_theta_t = torch.sin(theta_t)
        # Finish the slerp algorithm
        s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        res = s0 * v0 + s1 * v1

    return res.to(device, non_blocking=True)
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from typing import Dict, List, Tuple # noqa: F401
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor, nn
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
|
|
11
|
+
from fusion_bench.compat.method import ModelFusionAlgorithm
|
|
12
|
+
from fusion_bench.compat.modelpool import ModelPool, to_modelpool
|
|
13
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
14
|
+
from fusion_bench.models.utils import get_attr, set_attr
|
|
15
|
+
|
|
16
|
+
log = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def svd(w: Tensor, full_matrices: bool) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Decompose ``w`` so that ``w == u @ diag(s) @ v.T`` (for real-valued input).

    On CUDA tensors the ``gesvd`` driver is requested explicitly. On other devices
    PyTorch picks its default driver.

    Args:
        w (Tensor): Matrix to decompose.
        full_matrices (bool): Forwarded to ``torch.linalg.svd``; when True, ``u`` and ``v``
            are square (full-sized) rather than reduced.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: ``(u, s, v)``, where ``v`` is the transpose of the
        ``Vh`` factor returned by ``torch.linalg.svd``.
    """
    chosen_driver = "gesvd" if w.is_cuda else None
    left, singular_values, right_h = torch.linalg.svd(
        w, full_matrices=full_matrices, driver=chosen_driver
    )
    return left, singular_values, right_h.T
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _is_name_matched(name: str, extract_names: List[str]) -> bool:
|
|
38
|
+
"""
|
|
39
|
+
Check if the given name matches any of the provided regular expressions.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
name (str): The name to check.
|
|
43
|
+
extract_names (List[str]): A list of regular expressions to match against.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
bool: True if the name matches any of the regular expressions, False otherwise.
|
|
47
|
+
"""
|
|
48
|
+
for extract_name in extract_names:
|
|
49
|
+
# extract_name is a regular expression
|
|
50
|
+
if re.match(extract_name, name):
|
|
51
|
+
return True
|
|
52
|
+
return False
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _total_parameters(state) -> int:
|
|
56
|
+
"""
|
|
57
|
+
Calculate the total number of parameters in the given state.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
state: The state to calculate the parameters for. Can be a Tensor or a dictionary of Tensors.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
int: The total number of parameters.
|
|
64
|
+
|
|
65
|
+
Raises:
|
|
66
|
+
ValueError: If the state is not a Tensor or a dictionary of Tensors.
|
|
67
|
+
"""
|
|
68
|
+
if isinstance(state, Tensor):
|
|
69
|
+
return state.numel()
|
|
70
|
+
elif isinstance(state, dict):
|
|
71
|
+
return sum(_total_parameters(v) for v in state.values())
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError(f"Unsupported type: {type(state)}")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class SingularProjectionMergingAlgorithm(ModelFusionAlgorithm, SimpleProfilerMixin):
    """
    A model fusion algorithm that projects parameter differences into the SVD subspace of a pretrained model.

    This algorithm is experimental and aims to investigate the location of task-specific knowledge.

    Only the first model listed in ``modelpool.model_names`` is merged with ``_pretrained_``.
    Configuration values read from ``self.config``: ``model_path``, ``device``, ``k``,
    ``rank`` and ``full_matrices``.
    """

    @torch.no_grad()
    def run(self, modelpool: ModelPool) -> nn.Module:
        """
        Run the singular projection merging algorithm on the given model pool.

        If ``config.model_path`` is set and the file exists, the cached merged model is
        loaded from it and returned without recomputing. Otherwise the ``_pretrained_``
        model and the first fine-tuned model in the pool are merged. The result is then
        saved to ``config.model_path`` when that option is set.

        Args:
            modelpool (ModelPool): The pool of models to merge.

        Returns:
            nn.Module: The merged model.
        """
        modelpool = to_modelpool(modelpool)
        model_path = self.config.model_path

        if model_path is not None and os.path.exists(model_path):
            log.info(f"loading merged model from {model_path}")
            # The cache holds a whole pickled nn.Module (written by torch.save below), so
            # weights-only loading cannot be used. Only point model_path at trusted files.
            model = torch.load(model_path, weights_only=False)
            self.print_profile_summary()
            return model

        with self.profile("load pretrained model"):
            pretrained_model = modelpool.load_model("_pretrained_").to(
                self.config.device
            )
        with self.profile("load fine-tuned model"):
            # Only the first fine-tuned model in the pool takes part in the merge.
            finetuned_model = modelpool.load_model(modelpool.model_names[0]).to(
                self.config.device
            )

        with self.profile("merge model"):
            model = self.merge(pretrained_model, finetuned_model)

        if model_path is not None:
            # os.path.dirname returns "" for a bare filename, and os.makedirs("") raises.
            save_dir = os.path.dirname(model_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model, model_path)

        self.print_profile_summary()
        return model

    def merge(
        self,
        pretrained_model: nn.Module,
        finetuned_model: nn.Module,
        in_place: bool = True,
    ) -> nn.Module:
        """
        Merges the pretrained model with the fine-tuned model by projecting parameter differences
        into the SVD subspace of the pretrained model.

        Every ``nn.Linear`` submodule of the result is replaced by the matching layer of
        ``pretrained_model``. That layer's weights are updated in place by
        ``projection_merge_linear``, so ``pretrained_model`` is modified as well.

        Args:
            pretrained_model (nn.Module): The pretrained model.
            finetuned_model (nn.Module): The fine-tuned model.
            in_place (bool): If True, modifies the fine-tuned model in place. Otherwise, creates a copy.

        Returns:
            nn.Module: The merged model.
        """
        if in_place:
            model = finetuned_model
        else:
            model = deepcopy(finetuned_model)

        # Materialise the module list first because set_attr replaces submodules while looping.
        for name, module in tqdm(
            tuple(model.named_modules()),
            "Projection merging in SVD subspace of pretrained model",
        ):
            if isinstance(module, nn.Linear):
                name_list = name.split(".")
                set_attr(
                    model,
                    name_list,
                    self.projection_merge_linear(
                        get_attr(pretrained_model, name_list),
                        get_attr(finetuned_model, name_list),
                        k=self.config.k,
                    ),
                )
        return model

    def projection_merge_linear(
        self, pretrained_model: nn.Linear, finetuned_model: nn.Linear, k: int
    ) -> nn.Linear:
        """
        Projects the parameter differences of linear layers into the SVD subspace of the pretrained model.

        With ``config.rank == "low"`` the top-``k`` singular directions of the pretrained
        weight are used. For any other value, the directions from index ``k`` onward are used.

        Args:
            pretrained_model (nn.Linear): The linear layer of the pretrained model. Its weight
                and bias are overwritten in place.
            finetuned_model (nn.Linear): The linear layer of the fine-tuned model.
            k (int): The number of singular values to keep. If negative, it is determined based on the sum of singular values.

        Returns:
            nn.Linear: ``pretrained_model`` itself, now holding the projected parameter differences.
        """
        w = pretrained_model.weight
        w_ft = finetuned_model.weight

        u, s, v = svd(w, full_matrices=self.config.full_matrices)
        if k < 0:
            # smallest k whose cumulative singular-value sum reaches 50% of the total sum
            cumsum = s.cumsum(0)
            k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1

        if self.config.rank == "low":
            # keep the leading (largest) k singular directions
            u = u[:, :k]
            v = v[:, :k]
        else:
            # keep the complementary directions, index k onward
            u = u[:, k:]
            v = v[:, k:]

        w_diff = w_ft - w
        # Coordinates of the weight delta in the selected left/right singular bases,
        # mapped back to weight space and added to the pretrained weight.
        w_diff_proj = u.T @ w_diff @ v
        w.data = w + u @ w_diff_proj @ v.T
        if pretrained_model.bias is not None:
            pretrained_model.bias.data = finetuned_model.bias.data
        return pretrained_model
|