fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,359 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import Dict, List, Literal, Optional, Tuple, cast # noqa: F401
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.utils.hooks
|
|
7
|
+
from torch import Tensor, nn
|
|
8
|
+
from tqdm.auto import tqdm
|
|
9
|
+
from transformers import LlamaForCausalLM
|
|
10
|
+
|
|
11
|
+
from fusion_bench.method import BaseAlgorithm
|
|
12
|
+
from fusion_bench.method.pruning.wanda_utils.data import get_loaders
|
|
13
|
+
from fusion_bench.method.pruning.wanda_utils.prune import prepare_calibration_input
|
|
14
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
15
|
+
from fusion_bench.modelpool import CausalLMPool
|
|
16
|
+
from fusion_bench.utils import timeit_context
|
|
17
|
+
from fusion_bench.utils.cache_utils import cache_to_disk
|
|
18
|
+
|
|
19
|
+
from .prune_utils import (
|
|
20
|
+
PruningType,
|
|
21
|
+
compute_sparsity,
|
|
22
|
+
find_linear_layers,
|
|
23
|
+
semistructured_magnitude_prune_,
|
|
24
|
+
unstructured_magnitude_prune_,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
log = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BaseLoSparseHookFn:
    """
    Common interface for forward-hook objects attached to a linear layer.

    An instance keeps a reference to the layer it observes, is invoked as a
    forward hook (`__call__`), and later reports importance scores through
    `compute`.

    Note: the class does not derive from `abc.ABC`, so the `@abstractmethod`
    markers below are declarative only; Python does not prevent instantiation.
    """

    def __init__(self, linear):
        # the layer whose inputs/outputs this hook observes
        self.linear = linear

    @abstractmethod
    def compute(self) -> Tensor:
        """
        Return the importance scores gathered so far.

        Subclasses override this; the base implementation returns None.
        """

    @abstractmethod
    def __call__(self, linear, inp: Tuple[Tensor], out: Tensor):
        """
        Forward-hook entry point, invoked with the module, its positional
        inputs and its output.

        Subclasses override this; the base implementation does nothing.
        """
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class WandaHookFn(BaseLoSparseHookFn):
    R"""
    Forward hook that accumulates per-input-feature activation statistics.

    `scalar_row` has one entry per input feature of the linear layer. After
    the hook has seen $N$ samples (leading-dimension entries of the inputs),
    entry $j$ equals the sum of squared activations of feature $j$ over all
    rows seen, divided by $N$:

    $$\frac{1}{N}\sum_{i} X_{ij}^2$$

    `compute` combines this with the weight magnitude:
    $|W_{kj}| \cdot \sqrt{\text{scalar\_row}_j}$.
    """

    def __init__(self, linear: nn.Linear):
        super().__init__(linear)

        in_features = linear.weight.size(1)
        # running statistic, one value per input feature, on the weight's device
        self.scalar_row = torch.zeros((in_features,), device=linear.weight.device)
        # number of samples folded into `scalar_row` so far
        self.nsamples = 0

    def compute(self):
        """Return `|weight| * sqrt(scalar_row)`, broadcast over output rows."""
        weight_magnitude = torch.abs(self.linear.weight)
        activation_scale = torch.sqrt(self.scalar_row.reshape(1, -1))
        return weight_magnitude * activation_scale

    def __call__(self, linear: nn.Linear, inps: Tuple[Tensor], out: Tensor):
        """
        Update the running statistic with the input of one forward call.

        Args:
            linear (nn.Linear): The module the hook fired on (unused).
            inps (Tuple[Tensor]): Positional inputs of the module; exactly one is expected.
            out (Tensor): Output of the module (unused).
        """
        assert len(inps) == 1
        x = inps[0]
        # treat a 2-D input as a single sample
        if x.dim() == 2:
            x = x.unsqueeze(0)

        num_new = x.shape[0]
        # flatten the leading dimensions: (N, L, C) -> (N*L, C)
        if x.dim() == 3:
            x = x.reshape((-1, x.shape[-1]))
        # (N*L, C) -> (C, N*L)
        x = x.t()

        # rescale the old average so that old and new samples share one denominator
        seen = self.nsamples
        self.scalar_row *= seen / (seen + num_new)
        self.nsamples = seen + num_new

        x = x.type(torch.float32)
        self.scalar_row += torch.norm(x, p=2, dim=1) ** 2 / self.nsamples
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class WandaPruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
    """
    Wanda pruning for Llama models.

    The decoder layers are processed one after another. For each layer the
    calibration inputs are passed through it while forward hooks
    (`WandaHookFn`) collect activation statistics on every linear sub-layer;
    the resulting importance scores drive unstructured or semi-structured
    pruning of the weights, and the outputs of the pruned layer become the
    inputs of the next layer.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "nsamples": "nsamples",
        "seed": "seed",
        "use_variant": "use_variant",
        "prune_type": "prune_type",
        "device": "device",
        "dtype": "dtype",
        "sparsity_ratio": "sparsity_ratio",
        "n": "n",
        "m": "m",
    }

    def __init__(
        self,
        *,
        nsamples: int,
        seed: int,
        use_variant: bool,
        prune_type: PruningType,
        device: str,
        dtype: str,
        sparsity_ratio: float,
        n: int,
        m: int,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the WandaPruningForLlama class.

        Args:
            nsamples (int): Number of samples for calibration.
            seed (int): Random seed.
            use_variant (bool): Whether to use a variant of the pruning method.
            prune_type (PruningType): Type of pruning to perform.
            device (str): Device to use for computation.
            dtype (str): Data type to use for computation.
            sparsity_ratio (float): Target sparsity for unstructured pruning.
            n (int): Number of weights zeroed in each group of `m` for
                semi-structured pruning (`check_sparsity` expects a resulting
                sparsity of `n / m`).
            m (int): Number of elements in a group for semi-structured pruning.
            model_save_path (Optional[str]): Path to save the pruned model.
                Nothing is saved when it is None.
            **kwargs: Additional arguments forwarded to the base class.
        """
        super().__init__(**kwargs)
        self.nsamples = nsamples
        self.seed = seed
        self.use_variant = use_variant
        self.prune_type = prune_type
        self.device = device
        self.dtype = dtype
        self.sparsity_ratio = sparsity_ratio
        self.n = n
        self.m = m
        self.model_save_path = model_save_path

    def run(self, modelpool: CausalLMPool):
        """
        Run the pruning algorithm on the model pool.

        Args:
            modelpool (CausalLMPool): Pool of causal language models.

        Returns:
            LlamaForCausalLM: Pruned model (pruned in place).
        """

        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model = modelpool.load_pretrained_or_first_model()
            # `seqlen` is read back in `_prepare_calibration_data`
            model.seqlen = model.config.max_position_embeddings
            tokenizer = modelpool.load_tokenizer(use_fast=False)

        if not isinstance(model, (LlamaForCausalLM,)):
            log.warning(f"Model type {type(model)} may not supported.")

        inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
            model, tokenizer
        )

        self.prune_using_calibration_data_(
            model,
            inps=inps,
            outs=outs,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        if self.model_save_path is not None:
            with timeit_context(f"Saving pruned model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)
        return model

    def _prepare_calibration_data(self, model, tokenizer):
        """
        Prepare calibration data for pruning (uncached).

        Args:
            model (LlamaForCausalLM): Model to be pruned; its `seqlen`
                attribute must already be set.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """
        with timeit_context("loading calibration data"):
            dataloader, _ = get_loaders(
                "c4",
                nsamples=self.nsamples,
                seed=self.seed,
                seqlen=model.seqlen,
                tokenizer=tokenizer,
            )

        with torch.no_grad():
            # collect input to the first layer
            inps, outs, attention_mask, position_ids = prepare_calibration_input(
                model, dataloader, self.device
            )
        return inps, outs, attention_mask, position_ids

    def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):
        """
        Prepare calibration data for pruning with caching.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """

        # NOTE(review): the cache path depends only on the model name, not on
        # `nsamples` or `seed`; presumably a stale cache has to be removed by
        # hand when those settings change -- confirm against `cache_to_disk`.
        @cache_to_disk(
            f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
        )
        def _prepare_calibration_data(model, tokenizer):
            return self._prepare_calibration_data(model, tokenizer)

        return _prepare_calibration_data(model, tokenizer)

    def prune_using_calibration_data_(
        self,
        model: LlamaForCausalLM,
        *,
        inps,
        outs,
        attention_mask,
        position_ids,
    ):
        """
        Prune the model in place, layer by layer, using calibration data.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            inps: Calibration inputs to the first decoder layer.
            outs: Buffer receiving the layer outputs.
            attention_mask: Attention mask for calibration data.
            position_ids: Position IDs for calibration data.

        Raises:
            ValueError: If `self.prune_type` is not a supported pruning type.
        """
        layers = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            if (
                hasattr(model, "hf_device_map")
                and f"model.layers.{layer_idx}" in model.hf_device_map
            ):
                # handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
                dev = model.hf_device_map[f"model.layers.{layer_idx}"]
                inps, outs, attention_mask, position_ids = (
                    inps.to(dev),
                    outs.to(dev),
                    attention_mask.to(dev) if attention_mask is not None else None,
                    position_ids.to(dev) if position_ids is not None else None,
                )

            # the linear sub-layers of this decoder layer, keyed by name
            linear_layers = cast(
                Dict[str, nn.Linear],
                find_linear_layers(layer, layers=[nn.Linear]),
            )

            # register hooks to collect the importance scores
            hooks: Dict[str, WandaHookFn] = {}
            handles: List[torch.utils.hooks.RemovableHandle] = []
            try:
                for name, linear in linear_layers.items():
                    hook_fn = WandaHookFn(linear)
                    hooks[name] = hook_fn
                    handles.append(linear.register_forward_hook(hook_fn))

                # forward pass through the unpruned layer; the hooks
                # accumulate the activation statistics as a side effect
                with torch.no_grad():
                    for j in range(self.nsamples):
                        outs[j] = layer(
                            inps[j].unsqueeze(0),
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                        )[0]

                # compute the importance scores
                metrics = {name: hook.compute() for name, hook in hooks.items()}
            finally:
                # always detach the hooks, even if the forward pass failed,
                # so that no stale hook stays attached to the model
                for h in handles:
                    h.remove()

            # prune the weights based on the importance scores
            if self.prune_type == PruningType.UNSTRUCTURED:
                for name, linear in linear_layers.items():
                    log.info(f"Pruning {name}")
                    unstructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        sparsity_ratio=self.sparsity_ratio,
                    )
                    self.check_sparsity(linear.weight)
            elif self.prune_type == PruningType.SEMISTRUCTURED:
                for name, linear in linear_layers.items():
                    log.info(f"Pruning {name}")
                    semistructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        n=self.n,
                        m=self.m,
                    )
                    self.check_sparsity(linear.weight)
            else:
                raise ValueError(f"Invalid pruning type: {self.prune_type}")

            # compute the input to the next layer (with the pruned weights)
            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]
            # the outputs of this layer feed the next one; reuse the old
            # input buffer as the next output buffer
            inps, outs = outs, inps

    @torch.no_grad()
    def check_sparsity(self, weight: Tensor, tol: float = 0.01):
        """
        Check that the sparsity of the weight tensor matches the target.

        The target is `sparsity_ratio` for unstructured pruning and `n / m`
        for semi-structured pruning.

        Args:
            weight (Tensor): Weight tensor.
            tol (float): Tolerance for sparsity check.

        Raises:
            ValueError: If the pruning type is invalid.
            AssertionError: If the measured sparsity deviates from the target
                by `tol` or more.
        """
        if self.prune_type == PruningType.UNSTRUCTURED:
            expected = self.sparsity_ratio
        elif self.prune_type == PruningType.SEMISTRUCTURED:
            expected = self.n / self.m
        else:
            raise ValueError(f"Invalid pruning type: {self.prune_type}")

        sparsity = compute_sparsity(weight)
        assert (sparsity - expected).abs() < tol, (
            f"Sparsity check failed: measured {sparsity}, "
            f"expected {expected} (tol={tol})"
        )
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from typing import Dict, List, Literal, Optional, Union # noqa: F401
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor, nn
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
|
|
11
|
+
from fusion_bench.method import BaseAlgorithm
|
|
12
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
13
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
14
|
+
|
|
15
|
+
from .prune_utils import unstructured_magnitude_prune_
|
|
16
|
+
|
|
17
|
+
log = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _is_name_matched(name: str, extract_names: List[str]):
|
|
21
|
+
"""
|
|
22
|
+
Check if the parameter name matches any of the provided regular expressions.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
name (str): The name of the parameter.
|
|
26
|
+
extract_names (List[str]): List of regular expressions to match the parameter names.
|
|
27
|
+
|
|
28
|
+
Returns:
|
|
29
|
+
bool: True if the name matches any of the regular expressions, False otherwise.
|
|
30
|
+
"""
|
|
31
|
+
for extract_name in extract_names:
|
|
32
|
+
# extract_name is a regular expression
|
|
33
|
+
if re.match(extract_name, name):
|
|
34
|
+
return True
|
|
35
|
+
return False
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MagnitudeDiffPruningAlgorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
):
    """
    Magnitude-prune the difference between a fine-tuned and a pretrained model.

    For every selected parameter, the element-wise difference
    ``finetuned - pretrained`` is computed, a fraction of its entries is set
    to zero according to their absolute value, the remainder is optionally
    rescaled, and the result is added back onto the pretrained parameter.

    Methods:
        run(modelpool) -> nn.Module:
            Load both models from the pool, prune, and return the result.
        magnitude_prune(pretrained_model, finetuned_model, in_place=True) -> nn.Module:
            Perform the pruning on two already-loaded models.
    """

    # NOTE(review): `rescale` and `prune_type` are not listed in this mapping;
    # confirm whether they are meant to be exported to the config as well.
    _config_mapping = BaseAlgorithm._config_mapping | {
        "prune_ratio": "prune_ratio",
        "extract_names": "extract_names",
    }

    def __init__(
        self,
        prune_ratio: float,
        rescale: Optional[Union[bool, float]] = None,
        extract_names: Optional[List[str]] = None,
        prune_type: Literal["minor", "major"] = "minor",
        **kwargs,
    ):
        """
        Store the pruning configuration.

        Args:
            prune_ratio (float): Fraction of the entries of each difference
                tensor to set to zero.
            rescale (bool | float, optional): How to rescale the difference
                that survives pruning. ``None`` or ``False`` leaves it
                untouched, ``True`` multiplies it by ``1 / prune_ratio``, and
                a number is used directly as the multiplier.
            extract_names (List[str], optional): Regular expressions selecting
                the parameters to prune (matched with ``re.match``). When
                ``None``, the ``weight`` of every ``nn.Linear`` is pruned.
            prune_type ("minor" | "major"): ``"minor"`` zeroes the entries
                with the smallest magnitude; any other value zeroes the
                entries with the largest magnitude.
            **kwargs: Forwarded to the parent class initializer.
        """
        self.prune_ratio = prune_ratio
        self.rescale = rescale
        self.extract_names = extract_names
        self.prune_type = prune_type
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Load the models from the pool and prune their parameter difference.

        Args:
            modelpool (BaseModelPool): Pool holding the ``"_pretrained_"``
                model and exactly one fine-tuned model. Anything that is not
                a ``BaseModelPool`` is wrapped in one first.

        Returns:
            nn.Module: The pretrained model with the pruned difference added
                to its parameters.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        assert (
            len(modelpool.model_names) == 1
        ), "Only one fine-tuned model is allowed in the model pool."
        with self.profile("load pretrained model"):
            pretrained_model = modelpool.load_model("_pretrained_")
        with self.profile("load fine-tuned model"):
            finetuned_model = modelpool.load_model(modelpool.model_names[0])

        with self.profile("prune model"):
            model = self.magnitude_prune(pretrained_model, finetuned_model)

        self.print_profile_summary()
        return model

    def _rescale_factor(self) -> Optional[float]:
        """Return the multiplier for the pruned difference, or None for no rescaling."""
        # Identity checks are deliberate: `1.0 == True` holds in Python, so an
        # equality test would mistake an explicit factor of 1.0 for "auto",
        # and `False` must not be used as a multiplier (it would zero the diff).
        if self.rescale is None or self.rescale is False:
            return None
        if self.rescale is True:
            # NOTE(review): dropout-style rescaling would normally use
            # 1 / (1 - prune_ratio); confirm that 1 / prune_ratio is intended.
            return 1 / self.prune_ratio
        return self.rescale

    @torch.no_grad()
    def magnitude_prune(
        self,
        pretrained_model: nn.Module,
        finetuned_model: nn.Module,
        in_place: bool = True,
    ):
        """
        Add the magnitude-pruned parameter difference to the pretrained model.

        Only parameters with ``requires_grad=True`` whose name is selected
        (see ``extract_names`` on ``__init__``) are modified; all others keep
        their pretrained value.

        Args:
            pretrained_model (nn.Module): The pretrained model.
            finetuned_model (nn.Module): The fine-tuned model; its
                ``state_dict`` must contain every selected parameter name.
            in_place (bool, optional): When True (default) the pretrained
                model itself is modified and returned; otherwise a deep copy
                is modified.

        Returns:
            nn.Module: The model carrying the pruned update.
        """
        model = pretrained_model if in_place else deepcopy(pretrained_model)

        if self.extract_names is not None:
            # user-supplied regular expressions for the parameter names
            patterns: List[str] = self.extract_names

            def is_selected(param_name: str) -> bool:
                return _is_name_matched(param_name, patterns)

        else:
            # Default: the weight matrix of every linear layer. Exact name
            # lookup is used instead of regex matching so that dots are taken
            # literally and a name cannot select a longer one by prefix. The
            # root module has an empty name, hence the bare "weight".
            linear_weight_names = {
                f"{module_name}.weight" if module_name else "weight"
                for module_name, module in model.named_modules()
                if isinstance(module, nn.Linear)
            }
            is_selected = linear_weight_names.__contains__

        # Lower metric values are pruned first, so negating the magnitude
        # flips the selection to the largest entries.
        if self.prune_type == "minor":
            metric_fn = torch.abs
        else:

            def metric_fn(x: torch.Tensor) -> torch.Tensor:
                return -torch.abs(x)

        ft_state_dict = finetuned_model.state_dict()
        named_params = list(model.named_parameters())
        for name, param in tqdm(
            named_params,
            "Magnitude Pruning On Parameter Difference",
            total=len(named_params),
        ):
            if not param.requires_grad:
                continue
            if not is_selected(name):
                continue

            # Prune the difference, not the weight itself.
            w_diff = ft_state_dict[name] - param
            w_diff = unstructured_magnitude_prune_(
                w_diff,
                metric_fn,
                sparsity_ratio=self.prune_ratio,
            )
            factor = self._rescale_factor()
            if factor is not None:
                w_diff = w_diff * factor
            param.data = param + w_diff

        return model
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from typing import Callable, Dict, Union # noqa: F401
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
# strEnum only available for python >= 3.11
|
|
8
|
+
# for older version, load from fusion_bench.utils.strenum
|
|
9
|
+
from enum import StrEnum
|
|
10
|
+
except ImportError:
|
|
11
|
+
from fusion_bench.utils.strenum import StrEnum
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PruningType(StrEnum):
|
|
15
|
+
"""
|
|
16
|
+
Enum class for different types of pruning.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
UNSTRUCTURED = "unstructured"
|
|
20
|
+
SEMISTRUCTURED = "semistructured" # N:M structured
|
|
21
|
+
STRUCTURED = "structured"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def find_linear_layers(module: nn.Module, layers=(nn.Linear,), prefix=""):
    """
    Collect every submodule of the given type(s), keyed by its qualified name.

    The search uses ``module.named_modules``, so it covers ``module`` itself
    and all of its descendants.

    Args:
        module (nn.Module): PyTorch module to search.
        layers (type | Iterable[type]): Layer class, or collection of layer
            classes, to look for. Defaults to ``nn.Linear`` only.
        prefix (str): Prefix prepended to every reported name.

    Returns:
        dict: Mapping from qualified module name to the matching submodule,
            in ``named_modules`` traversal order.
    """
    # Normalise once, outside the loop; `isinstance` needs a class or a tuple.
    layer_types = (layers,) if isinstance(layers, type) else tuple(layers)
    return {
        name: submodule
        for name, submodule in module.named_modules(prefix=prefix)
        if isinstance(submodule, layer_types)
    }
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def unstructured_magnitude_prune_(
    weight: torch.Tensor,
    metric_function_or_scores: Union[
        Callable[[torch.Tensor], torch.Tensor], torch.Tensor
    ],
    sparsity_ratio: float,
    dtype: torch.dtype = None,
    device: torch.device = None,
    return_pruned_weight: bool = False,
):
    """
    Zero, in place, the lowest-scoring entries of ``weight`` along its last dimension.

    Scores are ranked independently within each slice along the last
    dimension (each row of a matrix); the ``int(size * sparsity_ratio)``
    entries with the smallest score in every slice are set to zero. Ties are
    broken by position because a stable sort is used.

    Args:
        weight (torch.Tensor): Tensor to prune; modified in place. Must have
            at least one dimension.
        metric_function_or_scores (Union[Callable[[torch.Tensor], torch.Tensor], torch.Tensor]):
            Either a function mapping ``weight`` to a score tensor of the same
            shape, or the precomputed score tensor itself. Lower scores are
            pruned first.
        sparsity_ratio (float): Fraction, in ``[0, 1]``, of the entries of
            each slice to prune.
        dtype (torch.dtype, optional): Data type used to compute the scores.
            Defaults to None (keep as is).
        device (torch.device, optional): Device used to compute the scores.
            Defaults to None (keep as is).
        return_pruned_weight (bool, optional): Also return the values that
            were removed. Defaults to False.

    Returns:
        torch.Tensor: ``weight`` itself, after pruning.
        torch.Tensor (optional): Only when ``return_pruned_weight`` is True; a
            tensor holding the removed values at their original positions and
            zero everywhere else.

    Raises:
        ValueError: If ``sparsity_ratio`` is outside ``[0, 1]`` or
            ``metric_function_or_scores`` is neither callable nor a tensor.
    """
    # A negative ratio would turn the slice below into "all but the last k".
    if not 0 <= sparsity_ratio <= 1:
        raise ValueError(
            f"sparsity_ratio should be in the range [0, 1], got {sparsity_ratio}"
        )

    original_device = weight.device
    if callable(metric_function_or_scores):
        W_metric = metric_function_or_scores(weight.to(dtype=dtype, device=device))
    elif isinstance(metric_function_or_scores, torch.Tensor):
        W_metric = metric_function_or_scores.to(dtype=dtype, device=device)
    else:
        raise ValueError(
            "metric_function_or_scores should be either a callable or a tensor"
        )

    # Create a mask for the weights to prune. Sorting, slicing and scattering
    # all use the last dimension so that tensors of any rank >= 1 are handled.
    num_pruned = int(W_metric.shape[-1] * sparsity_ratio)
    W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
    ascending_indices = torch.sort(W_metric, dim=-1, stable=True)[1]
    W_mask.scatter_(-1, ascending_indices[..., :num_pruned], True)
    W_mask = W_mask.to(device=original_device)

    if not return_pruned_weight:
        weight.masked_fill_(W_mask, 0)
        return weight

    # Snapshot before zeroing so the removed values can be reported.
    pruned_weight = weight.clone()
    weight.masked_fill_(W_mask, 0)
    pruned_weight.masked_fill_(~W_mask, 0)
    return weight, pruned_weight
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def semistructured_magnitude_prune_(
    weight: torch.Tensor,
    metric_function_or_scores: Union[
        Callable[[torch.Tensor], torch.Tensor], torch.Tensor
    ],
    n: int,
    m: int,
    dtype: torch.dtype = None,
    device: torch.device = None,
    return_pruned_weight: bool = False,
):
    """
    Perform semi-structured (N:M) pruning on ``weight``, in place.

    The columns (dimension 1) are split into consecutive groups of ``m``. In
    every row of every group, the ``n`` entries with the lowest score are set
    to zero, so ``m - n`` entries per group survive.

    Args:
        weight (torch.Tensor): Tensor to prune; modified in place. It is
            indexed along dimension 1, so it needs at least two dimensions.
        metric_function_or_scores (Union[Callable[[torch.Tensor], torch.Tensor], torch.Tensor]):
            Either a function mapping ``weight`` to a score tensor of the same
            shape, or the precomputed score tensor itself. Lower scores are
            pruned first.
        n (int): The number of weights to zero in each group.
        m (int): The size of each group.
        dtype (torch.dtype, optional): Data type used to compute the scores.
            Defaults to None (keep as is).
        device (torch.device, optional): Device used to compute the scores.
            Defaults to None (keep as is).
        return_pruned_weight (bool, optional): Also return the values that
            were removed. Defaults to False.

    Returns:
        torch.Tensor: ``weight`` itself, after pruning.
        torch.Tensor (optional): Only when ``return_pruned_weight`` is True; a
            tensor holding the removed values at their original positions and
            zero everywhere else.

    Raises:
        ValueError: If ``m`` is not positive, ``n`` is not in ``[0, m]``, or
            ``metric_function_or_scores`` is neither callable nor a tensor.

    Note:
        When the size of dimension 1 is not a multiple of ``m`` the last group
        is narrower; ``torch.topk`` raises if it has fewer than ``n`` columns.
    """
    # Fail early with a clear message instead of an opaque topk error (n > m)
    # or a silent no-op (negative m makes the column range empty).
    if m <= 0 or not 0 <= n <= m:
        raise ValueError(
            f"n and m should satisfy 0 <= n <= m and m > 0, got n={n}, m={m}"
        )

    original_device = weight.device
    if callable(metric_function_or_scores):
        W_metric = metric_function_or_scores(weight.to(dtype=dtype, device=device))
    elif isinstance(metric_function_or_scores, torch.Tensor):
        W_metric = metric_function_or_scores.to(dtype=dtype, device=device)
    else:
        raise ValueError(
            "metric_function_or_scores should be either a callable or a tensor"
        )

    # Create a mask for the weights to prune
    W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
    for col_idx in range(0, W_metric.shape[1], m):
        group_scores = W_metric[:, col_idx : (col_idx + m)].float()  # noqa: E203
        # topk indices are relative to the group; shift them to absolute columns
        lowest = torch.topk(group_scores, n, dim=1, largest=False)[1]
        W_mask.scatter_(1, col_idx + lowest, True)
    W_mask = W_mask.to(device=original_device)

    if not return_pruned_weight:
        weight.masked_fill_(W_mask, 0)
        return weight

    # Snapshot before zeroing so the removed values can be reported.
    pruned_weight = weight.clone()
    weight.masked_fill_(W_mask, 0)
    pruned_weight.masked_fill_(~W_mask, 0)
    return weight, pruned_weight
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def compute_sparsity(weight: torch.Tensor):
    """
    Measure the fraction of entries of ``weight`` that are exactly zero.

    Args:
        weight (torch.Tensor): The tensor to inspect.

    Returns:
        torch.Tensor: A 0-dim tensor holding ``zero_count / weight.numel()``
            (call ``.item()`` for a Python float).
    """
    zero_count = weight.eq(0).sum()
    total = weight.numel()
    return zero_count / total
|