fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
Overview of Ties-Merging:
|
|
3
|
+
|
|
4
|
+
1. Trim: For each task t, we trim the redundant parameters from the task vector $\tau_t$ to create $\hat{\tau}_t$ by keeping the top-k% values according to their magnitude and resetting the remaining bottom $(100 - k)\%$ of the parameters to 0. The trimmed vector can be decomposed as $\hat{\tau}_t = \hat{\gamma}_t \odot \hat{\mu}_t$, where $\hat{\gamma}_t = \text{sgn}(\hat{\tau}_t)$ is its sign vector and $\hat{\mu}_t = |\hat{\tau}_t|$ is its magnitude vector.
|
|
5
|
+
|
|
6
|
+
2. Elect: Next, we create an aggregate elected sign vector $\gamma_m$ for the merged model that resolves the disagreements in the sign for each parameter p across different models. To create the elected sign vector, we choose the sign with the highest total magnitude across all relevant models. For each parameter $p \in \{1, 2, \ldots, d\}$, we separate the values $\{\hat{\tau}_t^p\}_{t=1}^n$ based on their sign ($+1$ or $-1$) and take their sum to calculate the total mass (i.e., total magnitude) in the positive and the negative direction. We then assign $\gamma_m^p$ as the sign with greater total movement. This can be efficiently computed using $\gamma_m^p = \text{sgn}(\sum_{t=1}^n \hat{\tau}_t^p)$.
|
|
7
|
+
|
|
8
|
+
3. Disjoint Merge: Then, for each parameter p, we compute a disjoint mean by only keeping the parameter values from the models whose signs are the same as the aggregated elected sign and calculate their mean. Formally, let $A_p = \{t \in [n] \mid \hat{\gamma}_t^p = \gamma_m^p\}$, then $\tau_m^p = \frac{1}{|A_p|}\sum_{t\in A_p} \hat{\tau}_t^p$. Note that the disjoint mean always ignores the zero values.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from typing import Dict, List, Literal, Mapping, Union # noqa: F401
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from torch import Tensor, nn
|
|
16
|
+
|
|
17
|
+
from fusion_bench.compat.modelpool import to_modelpool
|
|
18
|
+
from fusion_bench.method import BaseAlgorithm
|
|
19
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
20
|
+
from fusion_bench.utils.type import StateDictType
|
|
21
|
+
|
|
22
|
+
from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_state_dict
|
|
23
|
+
|
|
24
|
+
# Module-level logger; TiesMergingAlgorithm.run uses it for progress messages.
log = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TiesMergingAlgorithm(BaseAlgorithm):
    """
    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.

    Attributes:
        scaling_factor (float): The scaling factor applied to the merged task vector
            before it is added back to the pretrained weights.
        threshold (float): Passed to ``ties_merging`` as ``reset_thresh``; the fraction
            of largest-magnitude entries kept in each task vector.
        remove_keys (List[str]): State-dict keys excluded from merging. Entries that are
            not re-created by ``vector_to_state_dict`` keep their pretrained values.
        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "scaling_factor": "scaling_factor",
        "threshold": "threshold",
        "remove_keys": "remove_keys",
        "merge_func": "merge_func",
    }

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: List[str],
        merge_func: Literal["sum", "mean", "max"],
        **kwargs,
    ):
        """
        Initialize the TiesMergingAlgorithm with the given parameters.

        Args:
            scaling_factor (float): The scaling factor to apply to the merged task vector.
            threshold (float): The fraction of largest-magnitude task-vector entries to keep.
            remove_keys (List[str]): List of keys to exclude from merging.
            merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
            **kwargs: Additional keyword arguments for the base class.
        """
        self.scaling_factor = scaling_factor
        self.threshold = threshold
        self.remove_keys = remove_keys
        self.merge_func = merge_func
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
        """
        Run the TIES merging algorithm to fuse models in the model pool.

        Each fine-tuned model's task vector (fine-tuned minus pretrained weights,
        flattened) is trimmed, sign-elected and disjointly merged by
        ``ties_merging``. The result is scaled by ``scaling_factor`` and added to
        the pretrained weights.

        Args:
            modelpool (BaseModelPool | Dict[str, nn.Module]): The model pool containing the
                models to fuse. It must provide a ``"_pretrained_"`` model.

        Returns:
            nn.Module: The pretrained model, with the merged weights loaded into it.
        """
        log.info("Fusing models using ties merging.")
        modelpool = to_modelpool(modelpool)
        # Read every option from the attributes set in __init__ (previously only
        # scaling_factor/threshold were). Fall back to the defaults [] / "sum" when
        # a value was left unset (None), instead of crashing later.
        remove_keys = self.remove_keys if self.remove_keys is not None else []
        merge_func = self.merge_func if self.merge_func is not None else "sum"
        scaling_factor = self.scaling_factor
        threshold = self.threshold

        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # Load the state dicts of the fine-tuned models
        ft_checks: List[StateDictType] = [
            modelpool.load_model(model_name).state_dict(keep_vars=True)
            for model_name in modelpool.model_names
        ]
        ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)

        # Compute the task vectors: one row per fine-tuned model
        flat_ft: Tensor = torch.vstack(
            [state_dict_to_vector(check, remove_keys) for check in ft_checks]
        )
        flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
        tv_flat_checks = flat_ft - flat_ptm

        # Perform TIES Merging
        merged_tv = ties_merging(
            tv_flat_checks,
            reset_thresh=threshold,
            merge_func=merge_func,
        )
        merged_check = flat_ptm + scaling_factor * merged_tv
        merged_state_dict = vector_to_state_dict(
            merged_check, ptm_check, remove_keys=remove_keys
        )

        # vector_to_state_dict leaves removed keys out of its result (it only
        # re-adds them when "transformer.shared.weight" exists). Put the
        # pretrained values back so the strict load_state_dict below does not
        # fail with missing keys.
        for key in remove_keys:
            if key not in merged_state_dict and key in ptm_check:
                merged_state_dict[key] = ptm_check[key]

        # Load the merged state dict into the pretrained model
        pretrained_model.load_state_dict(merged_state_dict)
        return pretrained_model
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module is adapted from https://github.com/EnnengYang/AdaMerging/blob/main/src/ties_merging_utils.py
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
from collections import OrderedDict
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import Tensor, nn
|
|
11
|
+
|
|
12
|
+
from fusion_bench.utils.type import StateDictType
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Model conversion utils
|
|
16
|
+
def state_dict_to_vector(state_dict, remove_keys=None):
    """
    Flatten a state dictionary into a single 1-D vector, skipping specified keys.

    Entries are concatenated in sorted key order. ``vector_to_state_dict`` uses
    the same order to split a vector back up.

    Args:
        state_dict (dict): The state dictionary to convert.
        remove_keys (list, optional): Keys to leave out of the vector. Keys that are
            not in ``state_dict`` are ignored. Defaults to ``None`` (keep every key).

    Returns:
        Tensor: A vector representation of the state dictionary.
    """
    excluded = set(remove_keys) if remove_keys is not None else set()
    # Building the vector only reads the tensors (concatenation allocates a new
    # one), so the dict is filtered rather than deep-copied. A deep copy would
    # duplicate every weight of the model just to flatten it.
    return nn.utils.parameters_to_vector(
        [state_dict[key].reshape(-1) for key in sorted(state_dict) if key not in excluded]
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def vector_to_state_dict(vector, state_dict, remove_keys=None):
    """
    Split a flat vector back into a state dictionary shaped like ``state_dict``.

    This is the inverse of ``state_dict_to_vector``: the values of ``vector`` are
    written into the entries of ``state_dict`` in sorted key order, skipping
    ``remove_keys``.

    Args:
        vector (Tensor): The vector to convert.
        state_dict (dict): The reference state dictionary. It provides the keys and
            shapes and is not modified.
        remove_keys (list, optional): Keys that were left out of ``vector``.
            Defaults to ``None`` (no keys removed).

    Returns:
        OrderedDict: A new state dictionary, in sorted key order, holding the values
        from ``vector``.
    """
    if remove_keys is None:
        remove_keys = []

    # Create a reference dict to define the order of the vector. It is a deep
    # copy because vector_to_parameters rebinds each tensor's ``.data``, which
    # would otherwise overwrite the caller's tensors.
    reference_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in reference_dict:
            del reference_dict[key]
    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

    # Fill the copied tensors from the vector, in sorted key order
    nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

    # Add back the encoder and decoder embedding weights: when the reference has
    # "transformer.shared.weight", every removed key is pointed at that tensor.
    # Presumably the removed keys are then the tied (T5-style) embedding keys;
    # confirm before using other remove_keys with such models.
    if "transformer.shared.weight" in sorted_reference_dict:
        for key in remove_keys:
            sorted_reference_dict[key] = sorted_reference_dict[
                "transformer.shared.weight"
            ]
    return sorted_reference_dict
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def add_ptm_to_tv(tv_dict, ptm_dict):
    """
    Add the values of one state dictionary to another, key by key.

    Args:
        tv_dict (dict): The target state dictionary (e.g. a task vector). It is not modified.
        ptm_dict (dict): The state dictionary to add (e.g. pretrained weights).

    Returns:
        dict: A new dictionary of the same type and key order as ``tv_dict``, where
        each value is ``tv_dict[k] + ptm_dict[k]``.

    Raises:
        AssertionError: If the two dictionaries do not have the same keys.
    """
    assert set(tv_dict.keys()) == set(
        ptm_dict.keys()
    ), "Differing parameter names in models."
    # A shallow copy keeps the container type and key order of tv_dict. Every
    # value is replaced below, so deep-copying the tensors would be wasted work.
    final_dict = copy.copy(tv_dict)
    for key, value in ptm_dict.items():
        final_dict[key] = tv_dict[key] + value
    return final_dict
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
    """
    Check if the parameter names match across multiple checkpoints.

    An empty list, or a single checkpoint, trivially passes.

    Args:
        checkpoints (list): List of state dictionaries to check.

    Raises:
        ValueError: If any checkpoint's keys differ from those of the first one.
            The message lists the symmetric difference of the two key sets.
    """
    if not checkpoints:
        return

    parameter_names = set(checkpoints[0].keys())
    for checkpoint in checkpoints[1:]:
        current_parameterNames = set(checkpoint.keys())
        if current_parameterNames != parameter_names:
            raise ValueError(
                "Differing parameter names in models. "
                f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
            )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def check_state_dicts_equal(
    state_dict1: StateDictType, state_dict2: StateDictType
) -> bool:
    """
    Report whether two state dictionaries hold the same keys and identical tensors.

    Args:
        state_dict1 (dict): The first state dictionary.
        state_dict2 (dict): The second state dictionary.

    Returns:
        bool: ``False`` if the key sets differ or any pair of tensors is not
        ``torch.equal``; ``True`` otherwise. Tensors are compared in the key
        order of ``state_dict1``, stopping at the first mismatch.
    """
    names = state_dict1.keys()
    same_keys = set(names) == set(state_dict2.keys())
    if not same_keys:
        return False
    return all(torch.equal(state_dict1[name], state_dict2[name]) for name in names)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# TIES MERGING UTILS
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def topk_values_mask(M, K=0.7, return_mask=False):
    """
    Keep the largest-magnitude entries of each row of ``M`` and zero the rest.

    The threshold for a row of length ``d`` is its ``(d - int(d * K))``-th smallest
    magnitude, and every entry at or above it is kept. A row can therefore keep
    one entry more than ``int(d * K)``, and more when magnitudes tie.

    Args:
        M (Tensor): A 2-D tensor (one row per task vector) or a 1-D tensor, which
            is treated as a single row.
        K (float): The proportion of top values to keep per row. Values greater
            than 1 are interpreted as percentages (e.g. ``20`` means 20%).
        return_mask (bool): Whether to return the mask tensor.

    Returns:
        tuple: ``(masked, keep_fraction)``, or ``(masked, keep_fraction, mask)`` when
        ``return_mask`` is true. ``masked`` and ``mask`` have the same shape as ``M``.
        ``keep_fraction`` holds the kept fraction of each row: shape ``(rows,)``,
        or ``(1,)`` for 1-D input.
    """
    if K > 1:
        K /= 100

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    _, d = M.shape
    # 1-based rank (ascending by magnitude) of the threshold entry. Clamp it to 1
    # so that K == 1 ("keep everything") does not request kthvalue(0).
    k = max(d - int(d * K), 1)

    magnitudes = M.abs()
    # Find the k-th smallest element by magnitude for each row
    kth_values, _ = magnitudes.kthvalue(k, dim=1, keepdim=True)
    # True for the top entries in each row
    mask = magnitudes >= kth_values
    masked = M * mask
    # Computed on the 2-D mask so that 1-D input works too
    keep_fraction = mask.float().mean(dim=1)

    if len(original_shape) == 1:
        # Restore the caller's 1-D layout
        masked = masked.reshape(original_shape)
        mask = mask.reshape(original_shape)

    if return_mask:
        return masked, keep_fraction, mask
    return masked, keep_fraction
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def resolve_zero_signs(sign_to_mult, method="majority"):
    """
    Fill the zero entries of a sign tensor in place.

    The overall sign ``sgn(sum(sign_to_mult))`` is written into every zero entry
    for ``method="majority"``, and its negation for ``method="minority"``. Any
    other ``method`` leaves the tensor untouched. If the overall sign is itself
    zero, the zero entries stay zero.

    Args:
        sign_to_mult (Tensor): The tensor with signs to resolve; modified in place.
        method (str): "majority" or "minority".

    Returns:
        Tensor: ``sign_to_mult`` itself, after the update.
    """
    overall_sign = torch.sign(sign_to_mult.sum())

    if method == "majority":
        fill_value = overall_sign
    elif method == "minority":
        fill_value = -1 * overall_sign
    else:
        return sign_to_mult

    sign_to_mult[sign_to_mult == 0] = fill_value
    return sign_to_mult
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def resolve_sign(v: Tensor):
    """
    Elect a sign for every position of ``v`` by summing over dim 0.

    Each position takes the sign of ``v.sum(dim=0)``. Positions whose sum is
    exactly zero are then filled by ``resolve_zero_signs`` using the "majority" rule.

    Args:
        v (Tensor): The input tensor, reduced over dim 0.

    Returns:
        Tensor: The elected signs.
    """
    return resolve_zero_signs(torch.sign(v.sum(dim=0)), method="majority")
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def disjoint_merge(v: Tensor, merge_func: str, sign_to_mult):
    """
    Perform disjoint merging of a tensor using a specified merge function.

    Each row of ``v`` is one task vector. Only the entries whose sign agrees with
    ``sign_to_mult`` (or, when ``sign_to_mult`` is None, all non-zero entries)
    take part in the aggregation over dim 0.

    Args:
        v (Tensor): The input tensor, one row per task vector.
        merge_func (str): "mean", "sum" or "max". A prefix separated by "-" is
            ignored (e.g. "dis-mean" is read as "mean").
        sign_to_mult (Tensor | None): The elected sign per column, or None.

    Returns:
        Tensor: The merged tensor (``v`` reduced over dim 0).

    Raises:
        ValueError: If ``merge_func`` is not one of the supported names.
    """
    merge_func = merge_func.split("-")[-1]

    if sign_to_mult is not None:
        # Keep only the entries whose sign matches the elected sign
        rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
    else:
        # Without an elected sign, keep all non-zero entries
        rows_to_keep = v != 0
    selected_entries = v * rows_to_keep

    if merge_func == "mean":
        # Average only over the entries that were kept (zeros are ignored)
        non_zero_counts = (selected_entries != 0).sum(dim=0).float()
        disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
            non_zero_counts, min=1
        )
    elif merge_func == "sum":
        disjoint_aggs = torch.sum(selected_entries, dim=0)
    elif merge_func == "max":
        if sign_to_mult is not None:
            disjoint_aggs = selected_entries.abs().max(dim=0)[0]
            disjoint_aggs *= sign_to_mult
        else:
            # No elected sign to multiply by: take the signed entry with the
            # largest magnitude in each column instead.
            max_idx = selected_entries.abs().argmax(dim=0, keepdim=True)
            disjoint_aggs = selected_entries.gather(0, max_idx).squeeze(0)
    else:
        raise ValueError(f"Merge method {merge_func} is not defined.")

    return disjoint_aggs
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def ties_merging(
    flat_task_checks,
    reset_thresh=None,
    merge_func="",
):
    """
    Run the TIES pipeline: trim, elect signs, then disjointly merge.

    Args:
        flat_task_checks (Tensor): Stacked, flattened task vectors. It is
            cloned before trimming, so the caller's tensor is not modified by
            this function.
        reset_thresh (float): Forwarded as ``K`` to ``topk_values_mask``; its
            exact interpretation is defined by that helper.
        merge_func (str): Aggregation name forwarded to ``disjoint_merge``.

    Returns:
        Tensor: The merged task vector produced by ``disjoint_merge``.
    """
    trimmed, *_ = topk_values_mask(
        flat_task_checks.clone(), K=reset_thresh, return_mask=False
    )

    print("RESOLVING SIGN")
    elected_signs = resolve_sign(trimmed)
    assert elected_signs is not None

    print(f"Disjoint AGGREGATION: {merge_func}")
    return disjoint_merge(trimmed, merge_func, elected_signs)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def disjoint_merge_split(v: Tensor, merge_func: str, sign_to_mult):
    """
    Sum ``v`` over dimension 0 after masking, also returning the masked entries.

    Args:
        v (Tensor): Stacked values; dimension 0 is the axis that is summed.
        merge_func (str): Aggregation name; only the part after the last
            ``"-"`` is used, and ``"sum"`` is the only supported value.
        sign_to_mult (Tensor or None): Elected sign per position. Where it is
            positive only positive entries of ``v`` are kept, elsewhere only
            negative ones; ``None`` keeps every non-zero entry.

    Returns:
        tuple: ``(selected_entries, summed)`` where ``selected_entries`` is
        ``v`` with rejected entries zeroed and ``summed`` is its sum over
        dimension 0.

    Raises:
        ValueError: If the aggregation name is not ``"sum"``.
    """
    agg_name = merge_func.split("-")[-1]

    if sign_to_mult is None:
        keep = v != 0
    else:
        keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
    selected_entries = v * keep

    if agg_name != "sum":
        raise ValueError(f"Merge method {agg_name} is not defined.")
    return selected_entries, selected_entries.sum(dim=0)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def ties_merging_split(
    flat_task_checks,
    reset_thresh=None,
    merge_func: str = "",
):
    """
    Run the TIES pipeline and also return the sign-consistent selected entries.

    Args:
        flat_task_checks (Tensor): Stacked, flattened task vectors. It is
            cloned before trimming, so the caller's tensor is not modified by
            this function.
        reset_thresh (float): Forwarded as ``K`` to ``topk_values_mask``; its
            exact interpretation is defined by that helper.
        merge_func (str): Aggregation name forwarded to
            ``disjoint_merge_split``.

    Returns:
        tuple: ``(selected_entries, merged_tv)`` as returned by
        ``disjoint_merge_split``.
    """
    trimmed, *_ = topk_values_mask(
        flat_task_checks.clone(), K=reset_thresh, return_mask=False
    )

    print("RESOLVING SIGN")
    elected_signs = resolve_sign(trimmed)
    assert elected_signs is not None

    print(f"Disjoint AGGREGATION: {merge_func}")
    return disjoint_merge_split(trimmed, merge_func, elected_signs)
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implementation of Task Arithmetic in Trust Region: A Training-Free Model Merging Approach to Navigate Knowledge Conflicts
|
|
3
|
+
https://openreview.net/forum?id=q3ztjJRQuJ
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from copy import deepcopy
|
|
9
|
+
from typing import Dict, Iterable, List, Union
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
from torch import Tensor, nn
|
|
14
|
+
from torch.utils.data import DataLoader
|
|
15
|
+
from tqdm.auto import tqdm
|
|
16
|
+
from typing_extensions import override
|
|
17
|
+
|
|
18
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
19
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
20
|
+
from fusion_bench.mixins import CLIPClassificationMixin, SimpleProfilerMixin
|
|
21
|
+
from fusion_bench.utils import first
|
|
22
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
|
|
23
|
+
from fusion_bench.utils.type import StateDictType
|
|
24
|
+
|
|
25
|
+
from .utils import state_dict_to_vector, vector_to_state_dict
|
|
26
|
+
|
|
27
|
+
# Module-level logger named after this module.
log = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def trainable_state_dict(module: nn.Module) -> StateDictType:
    """
    Collect the parameters of ``module`` that have ``requires_grad`` set.

    Args:
        module (nn.Module): Module whose parameters are inspected.

    Returns:
        Dict[str, Tensor]: Mapping from parameter name (as yielded by
        ``named_parameters``) to the parameter tensor itself (not a copy),
        restricted to parameters with ``requires_grad=True``.
    """
    trainable = {}
    for param_name, tensor in module.named_parameters():
        if not tensor.requires_grad:
            continue
        trainable[param_name] = tensor
    return trainable
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TaskArithmeticWithTrustRegionForCLIP(
    BaseAlgorithm,
    SimpleProfilerMixin,
    CLIPClassificationMixin,
):
    """
    Task Arithmetic in Trust Region for CLIP models.

    Steps performed by :meth:`run`:

    1. Compute task vectors (fine-tuned minus pre-trained trainable
       parameters) and flatten them.
    2. Obtain a per-task gradient-magnitude estimate: the average absolute
       per-sample gradient on each task's training data, or, when
       ``zero_shot`` is set, the absolute task vector itself.
    3. Score every flattened parameter with
       ``Omega = sum_i sum_{j != i} grad_i * |tau_j|`` and keep only entries
       whose score is strictly below the ``threshold_quantile`` quantile.
    4. Sum the masked task vectors and add them, scaled by
       ``scaling_factor``, to the pre-trained model.
    """

    def __init__(
        self,
        scaling_factor: Union[float, List[float]],
        threshold_quantile: float,
        max_samples: int,
        batch_size: int,
        zero_shot: bool,
        **kwargs,
    ):
        """
        Args:
            scaling_factor: A single coefficient, or an iterable of
                coefficients (one merged model is produced per value).
            threshold_quantile: Quantile of ``Omega`` used as the masking
                threshold; entries with ``Omega >= threshold`` are dropped.
            max_samples: Maximum number of training samples per dataset used
                to estimate gradients (unused when ``zero_shot`` is True).
            batch_size: Dataloader batch size for gradient estimation.
            zero_shot: If True, use ``|task vector|`` in place of
                data-driven gradients.
            **kwargs: Forwarded to the base class constructor.
        """
        self.scaling_factor = scaling_factor
        self.threshold_quantile = threshold_quantile
        self.max_samples = max_samples
        self.batch_size = batch_size
        self.zero_shot = zero_shot
        super().__init__(**kwargs)

    @override
    def run(self, modelpool: BaseModelPool):
        """
        Merge the models in ``modelpool``.

        Args:
            modelpool (BaseModelPool): Pool providing the pre-trained model,
                the fine-tuned models and (unless ``zero_shot``) training data.

        Returns:
            nn.Module if ``scaling_factor`` is a single number (the loaded
            pre-trained model, updated in place), otherwise a dict mapping each
            scaling factor to its own merged copy.

        Raises:
            ValueError: If ``scaling_factor`` is neither a number nor iterable.
        """
        self.modelpool = modelpool

        # Flat 1-D task vectors keyed by model name.
        pretrained_model, task_vectors = self.compute_vanilla_task_vectors()
        task_vectors = {
            name: state_dict_to_vector(task_vector)
            for name, task_vector in task_vectors.items()
        }

        if not self.zero_shot:
            all_avg_abs_grads = self.compute_avg_abs_grads(pretrained_model)
            all_avg_abs_grads = {
                n: state_dict_to_vector(grad) for n, grad in all_avg_abs_grads.items()
            }
        else:
            # The task vector magnitude is used as a stand-in for the gradient.
            all_avg_abs_grads = {name: tv.abs() for name, tv in task_vectors.items()}

        mask = self._compute_trust_region_mask(all_avg_abs_grads, task_vectors)

        # Drop high-conflict entries from every task vector, then sum them.
        task_vectors = {task: tv * mask for task, tv in task_vectors.items()}
        task_vector_sum = sum(task_vectors.values())
        task_vector_sum = vector_to_state_dict(
            task_vector_sum, trainable_state_dict(pretrained_model)
        )

        if isinstance(self.scaling_factor, (int, float)):
            return self._add_scaled_task_vector(
                pretrained_model, task_vector_sum, self.scaling_factor
            )
        elif isinstance(self.scaling_factor, Iterable):
            models = {}
            for scaling_factor in self.scaling_factor:
                # Each factor gets its own copy of the untouched pre-trained
                # model; the copy (not the original) is the one updated.
                model = deepcopy(pretrained_model)
                models[scaling_factor] = self._add_scaled_task_vector(
                    model, task_vector_sum, scaling_factor
                )
            return models
        else:
            raise ValueError(
                f"Incorrect type of `scaling_factor`: {type(self.scaling_factor)}. "
                "It should be a single real number or a list of real numbers."
            )

    def _compute_trust_region_mask(
        self,
        all_avg_abs_grads: Dict[str, Tensor],
        task_vectors: Dict[str, Tensor],
    ) -> Tensor:
        """
        Return a boolean mask keeping entries whose conflict score is low.

        The score is ``Omega = sum_i sum_{j != i} grad_i * |tau_j|`` over the
        keys of ``all_avg_abs_grads``; each key ``j`` must also be a key of
        ``task_vectors``. Entries with ``Omega`` strictly below the
        ``threshold_quantile`` quantile are kept.
        """
        Omega = torch.zeros_like(first(all_avg_abs_grads.values()))
        for i, grad_i in all_avg_abs_grads.items():
            for j in all_avg_abs_grads:
                if i != j:
                    # abs is computed per pair to avoid holding a second full
                    # copy of every task vector in memory.
                    Omega += grad_i * torch.abs(task_vectors[j])

        # Index of the quantile in ascending order, clamped to a valid position.
        index = max(
            0, min(int(Omega.numel() * self.threshold_quantile), Omega.numel() - 1)
        )
        # kthvalue is 1-based: the (index + 1)-th smallest equals sorted[index].
        threshold = Omega.kthvalue(index + 1).values
        return Omega < threshold

    @staticmethod
    def _add_scaled_task_vector(
        model: nn.Module, task_vector: StateDictType, scaling_factor
    ) -> nn.Module:
        """
        Add ``scaling_factor * task_vector`` in place to ``model`` and return it.

        Only parameters with ``requires_grad`` are updated, matching the
        trainable-only parameters from which the task vectors were built.
        """
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            param.data += task_vector[name] * scaling_factor
        return model

    def compute_avg_abs_grads(self, pretrained_model):
        """
        Estimate the average absolute per-sample gradient for each train dataset.

        For every name in ``modelpool.train_dataset_names``, at most
        ``max_samples`` samples are drawn. The cross-entropy loss of each sample
        is back-propagated separately through a copy of the pre-trained model.
        The absolute gradients of trainable parameters are then averaged over
        samples.

        Args:
            pretrained_model: Model to differentiate (deep-copied, so it is
                left untouched); if ``None``, one is loaded from the pool.

        Returns:
            Dict[str, StateDictType]: Keyed by train dataset name; each value
            maps parameter names to averaged absolute gradients on CPU.
        """
        modelpool = self.modelpool

        self.setup_zero_shot_classification_head()

        pretrained_model = (
            deepcopy(pretrained_model)
            if pretrained_model is not None
            else modelpool.load_pretrained_model()
        )
        pretrained_model = self.fabric.setup_module(pretrained_model)
        pretrained_model.train()

        all_avg_abs_grads: Dict[str, StateDictType] = {}
        for train_dataset_name in (
            pbar := tqdm(
                modelpool.train_dataset_names, desc="Train datasets", dynamic_ncols=True
            )
        ):
            pbar.set_description(f"Train dataset: {train_dataset_name}")
            dataset = modelpool.load_train_dataset(train_dataset_name)
            dataset = CLIPDataset(dataset, self.clip_processor)
            dataloader = DataLoader(dataset, shuffle=True, batch_size=self.batch_size)
            dataloader = self.fabric.setup_dataloaders(dataloader)

            grad: StateDictType = defaultdict(float)
            num_samples = 0
            for batch in dataloader:
                images, labels = batch
                batch_size = images.size(0)

                # Truncate the final batch so exactly max_samples are used.
                if num_samples + batch_size > self.max_samples:
                    batch_size = self.max_samples - num_samples
                    images = images[:batch_size]
                    labels = labels[:batch_size]

                logits = self.compute_logits(
                    pretrained_model, images, task=train_dataset_name
                )
                # One backward pass per sample so that |grad| is accumulated per
                # sample rather than taken of a batch-averaged gradient.
                for i in range(batch_size):
                    pretrained_model.zero_grad()
                    loss = F.cross_entropy(logits[i], labels[i])
                    # The forward graph is shared by all samples in the batch;
                    # keep it alive until the last backward pass.
                    self.fabric.backward(loss, retain_graph=i != batch_size - 1)
                    for name, param in pretrained_model.module.named_parameters():
                        if param.requires_grad:
                            grad[name] += torch.abs(param.grad).detach()

                num_samples += batch_size
                if num_samples >= self.max_samples:
                    break

            for name in grad:
                grad[name] = (grad[name] / num_samples).cpu()

            # Key by dataset name. `run` pairs these keys with the task-vector
            # (model) names, so dataset names are assumed to equal model names
            # in the pool — TODO confirm against the modelpool configuration.
            all_avg_abs_grads[train_dataset_name] = grad
        return all_avg_abs_grads

    @torch.no_grad()
    def compute_vanilla_task_vectors(self):
        """
        Compute task vectors for every model in the pool.

        Returns:
            tuple: ``(pretrained_model, task_vectors)`` where ``task_vectors``
            maps each model name to ``finetuned - pretrained`` over trainable
            parameters (via ``state_dict_sub``).
        """
        modelpool = self.modelpool

        pretrained_model = modelpool.load_pretrained_model()
        pretrained_sd = trainable_state_dict(pretrained_model)
        finetuned_sds = {
            name: trainable_state_dict(model)
            for name, model in modelpool.named_models()
        }

        task_vectors = {
            name: state_dict_sub(finetuned, pretrained_sd)
            for name, finetuned in finetuned_sds.items()
        }
        return pretrained_model, task_vectors
|