fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import pickle
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Union
|
|
7
|
+
|
|
8
|
+
__all__ = ["cache_to_disk"]


log = logging.getLogger(__name__)


def cache_to_disk(file_path: Union[str, Path]) -> Callable:
    """
    A decorator to cache the result of a function to a file. If the file exists,
    the result is loaded from the file. Otherwise, the function is executed and
    the result is saved to the file.

    The cache key is the file path only: the arguments passed to the decorated
    function are not part of the key, so once the file exists every call returns
    the same cached result.

    The result is first pickled to a temporary sibling file and then moved into
    place with `os.replace`, so a failure while pickling (e.g. an unpicklable
    result) never leaves a truncated cache file that would break later loads.

    ## Example usage

    ```python
    @cache_to_disk("path_to_file.pkl")
    def some_function(*args: Any, **kwargs: Any) -> Any:
        # Function implementation
        return "some result"
    ```

    Args:
        file_path (Union[str, Path]): The path to the file where the result should be cached.

    Returns:
        Callable: A decorator that wraps a function with the on-disk cache.

    Warning:
        The cache file is read back with `pickle`, which can execute arbitrary
        code. Only point this decorator at files you trust.
    """
    if isinstance(file_path, str):
        file_path = Path(file_path)
    assert isinstance(file_path, Path)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if os.path.exists(file_path):
                log.info(
                    f"Loading cached result of {func.__name__} from {file_path}",
                    stacklevel=2,
                )
                with open(file_path, "rb") as f:
                    return pickle.load(f)

            result = func(*args, **kwargs)
            file_path.parent.mkdir(parents=True, exist_ok=True)
            # Dump to a temporary file and atomically swap it in, so a failed or
            # interrupted dump cannot leave a corrupt cache behind.
            tmp_path = file_path.with_name(f"{file_path.name}.{os.getpid()}.tmp")
            try:
                with open(tmp_path, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_path, file_path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
            return result

        return wrapper

    return decorator
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import pickle
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torch.utils.data
|
|
8
|
+
from torch.utils.data import DataLoader, Dataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class InfiniteDataLoader:
    """
    Iterator that cycles over a data loader forever.

    Whenever the current pass over the wrapped loader is exhausted, a new
    iterator is obtained via ``iter(data_loader)`` and iteration resumes from
    its first element. If the wrapped loader yields nothing at all,
    ``StopIteration`` is raised, so iteration ends instead of spinning.
    """

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            pass
        # The current pass is finished: start a fresh one over the same loader.
        self.data_iter = iter(self.data_loader)
        return next(self.data_iter)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_tensor_from_file(file_path: Union[str, Path], device=None) -> torch.Tensor:
    """
    Loads a tensor from a file.

    The format is chosen by file extension:

    - ``.npy`` / ``.np``: a NumPy array written with ``numpy.save``; converted
      with ``torch.from_numpy``.
    - ``.pt`` / ``.pth``: an object written with ``torch.save``; loaded onto the CPU.
    - anything else: read as a pickle file.

    Warning:
        The ``.pt``/``.pth`` and pickle paths unpickle data, which can execute
        arbitrary code. Only load files from trusted sources.

    Args:
        file_path (Union[str, Path]): The path to the file to load.
        device: The device to move the tensor to. If ``None`` the tensor is left
            where it was loaded (the CPU for the NumPy and torch formats).

    Returns:
        torch.Tensor: The tensor loaded from the file.

    Raises:
        ValueError: If the file falls through to the pickle branch and cannot be
            unpickled.
        AssertionError: If the loaded object is not a ``torch.Tensor``.
    """
    # Normalise so that pathlib.Path inputs also support the suffix checks.
    path_str = str(file_path)
    if path_str.endswith((".npy", ".np")):
        tensor = torch.from_numpy(np.load(path_str)).detach_()
    elif path_str.endswith((".pt", ".pth")):
        tensor = torch.load(path_str, map_location="cpu").detach_()
    else:
        try:
            with open(path_str, "rb") as f:
                tensor = pickle.load(f)
        except Exception as err:
            raise ValueError(f"Unsupported file format: {file_path}") from err

    # Move tensor to device
    assert isinstance(tensor, torch.Tensor), f"Expected tensor, got {type(tensor)}"
    if device is not None:
        tensor = tensor.to(device=device)
    return tensor
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def train_validation_split(
    dataset: Dataset,
    validation_fraction: Optional[float] = 0.1,
    validation_size: Optional[int] = None,
    random_seed: Optional[int] = None,
    return_split: Literal["all", "both", "train", "val"] = "all",
):
    """
    Split a dataset into a training and validation set.

    Args:
        dataset (Dataset): The dataset to split.
        validation_fraction (Optional[float]): The fraction of the dataset to use for validation.
            Must be strictly between 0 and 1.
        validation_size (Optional[int]): The number of samples to use for validation. `validation_fraction` must be set to `None` if this is provided.
        random_seed (Optional[int]): The random seed to use for reproducibility.
        return_split (Literal["all", "both", "train", "val"]): The split to return.
            ``"all"`` (or its alias ``"both"``) returns ``(train, val)``;
            ``"train"`` / ``"val"`` return only that dataset.

    Returns:
        Tuple[Dataset, Dataset] | Dataset: The training and validation datasets,
        or the single requested split.

    Raises:
        AssertionError: If the size arguments are missing, conflicting or out of range.
        ValueError: If `return_split` is not one of the accepted values.
    """
    # Check the input arguments
    assert (
        validation_fraction is None or validation_size is None
    ), "Only one of validation_fraction and validation_size can be provided"
    assert (
        validation_fraction is not None or validation_size is not None
    ), "Either validation_fraction or validation_size must be provided"
    if return_split not in ("all", "both", "train", "val"):
        raise ValueError(f"Invalid return_split: {return_split}")

    # Compute the number of samples for training and validation
    num_samples = len(dataset)
    if validation_fraction is not None:
        assert (
            0 < validation_fraction < 1
        ), "Validation fraction must be between 0 and 1"
        num_validation_samples = int(num_samples * validation_fraction)
    else:
        assert validation_size >= 0, "Validation size must be non-negative"
        assert (
            validation_size < num_samples
        ), "Validation size must be less than num_samples"
        num_validation_samples = validation_size
    num_training_samples = num_samples - num_validation_samples

    # Split the dataset
    generator = (
        torch.Generator().manual_seed(random_seed) if random_seed is not None else None
    )
    training_dataset, validation_dataset = torch.utils.data.random_split(
        dataset, [num_training_samples, num_validation_samples], generator=generator
    )

    # return the split as requested
    if return_split in ("all", "both"):
        return training_dataset, validation_dataset
    elif return_split == "train":
        return training_dataset
    else:
        return validation_dataset
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def train_validation_test_split(
    dataset: Dataset,
    validation_fraction: float,
    test_fraction: float,
    random_seed: Optional[int] = None,
    return_spilt: Literal["all", "train", "val", "test"] = "all",
    *,
    return_split: Optional[Literal["all", "train", "val", "test"]] = None,
):
    """
    Split a dataset into a training, validation and test set.

    Args:
        dataset (Dataset): The dataset to split.
        validation_fraction (float): The fraction of the dataset to use for validation.
        test_fraction (float): The fraction of the dataset to use for test.
            `validation_fraction + test_fraction` must be less than 1.
        random_seed (Optional[int]): The random seed to use for reproducibility.
        return_spilt (Literal["all", "train", "val", "test"]): The split to return
            (historical misspelled name, kept for backward compatibility).
        return_split (Optional[Literal["all", "train", "val", "test"]]): Correctly
            spelled keyword alias; when given it takes precedence over `return_spilt`.

    Returns:
        Tuple[Dataset, Dataset, Dataset] | Dataset: The training, validation and
        test datasets, or the single requested split.

    Raises:
        AssertionError: If a fraction is out of range or the fractions sum to 1 or more.
        ValueError: If the requested split is not one of the accepted values.
    """
    num_samples = len(dataset)
    assert 0 < validation_fraction < 1, "Validation fraction must be between 0 and 1"
    assert 0 < test_fraction < 1, "Test fraction must be between 0 and 1"
    # Without this, the training length would go negative and random_split would
    # silently return overlapping subsets.
    assert (
        validation_fraction + test_fraction < 1
    ), "The sum of validation and test fractions must be less than 1"

    split = return_spilt if return_split is None else return_split
    if split not in ("all", "train", "val", "test"):
        raise ValueError(f"Invalid return_split: {split}")

    generator = (
        torch.Generator().manual_seed(random_seed) if random_seed is not None else None
    )

    num_validation_samples = int(num_samples * validation_fraction)
    num_test_samples = int(num_samples * test_fraction)
    num_training_samples = num_samples - num_validation_samples - num_test_samples
    training_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [num_training_samples, num_validation_samples, num_test_samples],
        generator=generator,
    )

    # return the split as requested
    if split == "all":
        return training_dataset, validation_dataset, test_dataset
    elif split == "train":
        return training_dataset
    elif split == "val":
        return validation_dataset
    else:
        return test_dataset
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
import os
|
|
3
|
+
from typing import List, Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from transformers.utils import (
|
|
7
|
+
is_torch_bf16_gpu_available,
|
|
8
|
+
is_torch_cuda_available,
|
|
9
|
+
is_torch_mps_available,
|
|
10
|
+
is_torch_npu_available,
|
|
11
|
+
is_torch_xpu_available,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"cuda_empty_cache",
|
|
16
|
+
"to_device",
|
|
17
|
+
"num_devices",
|
|
18
|
+
"get_device",
|
|
19
|
+
"get_current_device",
|
|
20
|
+
"get_device_memory_info",
|
|
21
|
+
"get_device_capabilities",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def cuda_empty_cache():
    """
    Free unreferenced memory.

    Runs Python's garbage collector and then ``torch.cuda.empty_cache()``.
    Collection runs first so that objects kept alive only by reference cycles
    are released before the CUDA cache is emptied.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def to_device(obj, device: Optional[torch.device], **kwargs):
    """
    Move a given object to the specified device.

    This function recursively moves tensors, modules, lists, tuples (including
    namedtuples), and dictionaries to the specified device. For unsupported
    types, the object is returned as is. Dictionaries are updated in place and
    the same dict object is returned.

    Args:
        obj: The object to be moved to the device. This can be a torch.Tensor, torch.nn.Module, list, tuple, or dict.
        device (torch.device): The target device to move the object to. This can be `None`.
        **kwargs: Additional keyword arguments to be passed to the `to` method of torch.Tensor or torch.nn.Module,
            including for tensors/modules nested inside containers. For example, `non_blocking=True`, `dtype=torch.float16`.

    Returns:
        The object moved to the specified device. The type of the returned object matches the type of the input object.

    Examples:
        >>> tensor = torch.tensor([1, 2, 3])
        >>> to_device(tensor, torch.device('cuda'))
        tensor([1, 2, 3], device='cuda:0')

        >>> model = torch.nn.Linear(2, 2)
        >>> to_device(model, torch.device('cuda'))
        Linear(in_features=2, out_features=2, bias=True)

        >>> data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
        >>> to_device(data, torch.device('cuda'))
        [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')]
    """
    if isinstance(obj, (torch.Tensor, torch.nn.Module)):
        return obj.to(device, **kwargs)
    elif isinstance(obj, list):
        return [to_device(o, device, **kwargs) for o in obj]
    elif isinstance(obj, tuple):
        moved = [to_device(o, device, **kwargs) for o in obj]
        if hasattr(obj, "_fields"):
            # namedtuple: its constructor takes positional fields, not an iterable.
            return type(obj)(*moved)
        return tuple(moved)
    elif isinstance(obj, dict):
        for key in obj:
            obj[key] = to_device(obj[key], device, **kwargs)
        return obj
    else:
        # the default behavior is to return the object as is
        return obj
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def num_devices(devices: Union[int, List[int], str]) -> int:
    """
    Return the number of devices.

    Args:
        devices: `devices` can be a single int to specify the number of devices,
            a sequence of device ids (e.g. ``[0, 1, 2, 3]`` or ``(0, 1)``), or a
            str of device ids, e.g. ``"0,1,2,3"`` and ``"[0, 1, 2]"``.

    Returns:
        The number of devices. For strings, empty entries are ignored, so ``""``
        and ``"[]"`` give 0 and a trailing comma does not count as a device.

    Raises:
        TypeError: If `devices` is not an int, str or sequence.
    """
    from collections.abc import Sequence

    if isinstance(devices, int):
        return devices
    elif isinstance(devices, str):
        tokens = devices.strip().strip("[]").split(",")
        return sum(1 for token in tokens if token.strip())
    elif isinstance(devices, Sequence):
        return len(devices)
    else:
        raise TypeError(
            f"devices must be an int, a sequence of ints or a str, but got {type(devices)}"
        )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_device(obj) -> torch.device:
    """
    Get the device of a given object.

    For modules, an explicit ``device`` attribute is used when present;
    otherwise the device of the first parameter is returned, falling back to
    the first buffer for modules without parameters.

    Args:
        obj: The object whose device is to be determined (a tensor, a module or a
            ``torch.device``).

    Returns:
        torch.device: The device of the given object.

    Raises:
        ValueError: If the object type is not supported, or the module has
            neither parameters nor buffers to infer a device from.
    """
    if isinstance(obj, torch.Tensor):
        return obj.device
    elif isinstance(obj, torch.nn.Module):
        if hasattr(obj, "device"):
            return obj.device
        tensor = next(obj.parameters(), None)
        if tensor is None:
            tensor = next(obj.buffers(), None)
        if tensor is None:
            raise ValueError(
                f"Cannot determine the device of a module without parameters or buffers: {type(obj)}"
            )
        return tensor.device
    elif isinstance(obj, torch.device):
        return obj
    else:
        raise ValueError(f"Unsupported object type: {type(obj)}")
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def get_current_device() -> torch.device:
    R"""
    Return the device PyTorch work in this process should run on.
    This is used for distributed training.

    Accelerator backends are probed in this fixed priority order, and the first
    one reported as available wins:

    1. XPU (Intel's AI accelerator)
    2. NPU (Neural Processing Unit)
    3. MPS (Metal Performance Shaders, for Apple devices)
    4. CUDA (NVIDIA's GPU)

    If none of them is available, the CPU is returned.

    Returns:
        torch.device: The selected device. For accelerators the index comes from
        the ``LOCAL_RANK`` environment variable.

    Environment Variables:
        LOCAL_RANK: Device index used for multi-device setups. Defaults to "0"
            when unset. It is ignored when falling back to the CPU.

    Example:
        >>> device = get_current_device()
        >>> print(device)
        xpu:0  # or npu:0, mps:0, cuda:0, cpu depending on availability
    """
    backends = (
        (is_torch_xpu_available, "xpu:{}"),
        (is_torch_npu_available, "npu:{}"),
        (is_torch_mps_available, "mps:{}"),
        (is_torch_cuda_available, "cuda:{}"),
    )
    for is_available, template in backends:
        if is_available():
            return torch.device(template.format(os.environ.get("LOCAL_RANK", "0")))
    return torch.device("cpu")
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def get_device_memory_info(device: torch.device, reset_stats: bool = True) -> dict:
    """
    Collect memory statistics for a CUDA device.

    Args:
        device (torch.device): Device to query. Only devices whose ``type`` is
            ``"cuda"`` are supported.
        reset_stats (bool, optional): When True (the default), the device's peak-memory
            counters are reset with ``torch.cuda.reset_peak_memory_stats`` *after* all
            values below have been read, so a later call reports peaks since this one.

    Returns:
        dict: Byte counts under the keys ``"total_memory"``, ``"reserved_memory"``,
        ``"allocated_memory"``, ``"peak_memory_active"``, ``"peak_memory_allocated"``
        and ``"peak_memory_reserved"``. ``"peak_memory_active"`` is 0 when the
        allocator does not report ``active_bytes.all.peak``.

    Raises:
        ValueError: If ``device.type`` is not ``"cuda"``.
    """
    if device.type != "cuda":
        raise ValueError(
            f"Memory information not available for device type: {device.type}"
        )

    # Values are read in a fixed order before any optional reset below.
    memory_info = {
        "total_memory": torch.cuda.get_device_properties(device).total_memory,
        "reserved_memory": torch.cuda.memory_reserved(device),
        "allocated_memory": torch.cuda.memory_allocated(device),
        "peak_memory_active": torch.cuda.memory_stats(device).get(
            "active_bytes.all.peak", 0
        ),
        "peak_memory_allocated": torch.cuda.max_memory_allocated(device),
        "peak_memory_reserved": torch.cuda.max_memory_reserved(device),
    }

    if reset_stats:
        torch.cuda.reset_peak_memory_stats(device)

    return memory_info
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def get_device_capabilities(device: torch.device) -> dict:
    """
    Describe the hardware capabilities of a CUDA device.

    Args:
        device (torch.device): Device to describe. Only devices whose ``type`` is
            ``"cuda"`` are supported.

    Returns:
        dict: With keys ``"name"`` (device name), ``"capability"`` (compute capability
        as returned by ``torch.cuda.get_device_capability``), ``"total_memory"``
        (bytes) and ``"multi_processor_count"``.

    Raises:
        ValueError: If ``device.type`` is not ``"cuda"``.
    """
    if device.type != "cuda":
        raise ValueError(
            f"Capabilities information not available for device type: {device.type}"
        )

    device_name = torch.cuda.get_device_name(device)
    compute_capability = torch.cuda.get_device_capability(device)
    # One properties lookup serves both memory size and SM count.
    properties = torch.cuda.get_device_properties(device)

    return {
        "name": device_name,
        "capability": compute_capability,
        "total_memory": properties.total_memory,
        "multi_processor_count": properties.multi_processor_count,
    }
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def cleanup_cuda():
    """
    Free unreferenced Python objects and, when CUDA is usable, release cached GPU memory.

    Steps:
        1. ``gc.collect()`` so tensors that are no longer referenced are destroyed.
        2. If ``torch.cuda.is_available()``: ``torch.cuda.empty_cache()`` to hand cached
           allocator blocks back to the driver, then ``torch.cuda.reset_peak_memory_stats()``
           on the current device.

    The CUDA steps are skipped on machines without a usable CUDA device.
    ``reset_peak_memory_stats()`` without an explicit device resolves the current
    device, which triggers CUDA initialization and raises when CUDA is unavailable.
    A cleanup helper should be a no-op there rather than an error.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from typing import Iterable, List, Tuple, Union
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def dict_get(d: dict, keys: Iterable[str], default=None):
    """
    Look up several keys in ``d`` at once.

    Args:
        d (dict): Dictionary to read from.
        keys (Iterable[str]): Keys to look up, in the order the results should appear.
        default: Value used for every key that is missing from ``d``. Defaults to None.

    Returns:
        list: ``d.get(key, default)`` for each key in ``keys``, in the same order.
    """
    lookup = d.get
    return [lookup(key, default) for key in keys]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def dict_map(f, d: dict, *, max_level: int = -1, skip_levels=0, inplace=False):
    """Recursively apply ``f`` to the non-dict values of a (possibly nested) dictionary.

    Nested dictionaries are descended into and are never passed to ``f`` themselves.
    The top-level dictionary is at depth 0, its nested dictionaries at depth 1, and so on.

    Args:
        f (callable): Function applied to each selected leaf value; it receives the value
            taken from ``d``.
        d (dict): Input dictionary.
        max_level (int, optional): Descent stops at this depth. Dictionaries found at depth
            ``max_level`` (and anything below them) are left untouched, so ``0`` transforms
            nothing. A negative value, including the default ``-1``, means no limit.
        skip_levels (int, optional): Leaf values at depths smaller than this are left
            unchanged. Defaults to 0.
        inplace (bool, optional): If True, ``d`` itself is modified and returned. Otherwise
            ``d`` is deep-copied first and left unchanged. Defaults to False.

    Returns:
        dict: The transformed dictionary (``d`` itself when ``inplace`` is True).

    Raises:
        TypeError: If ``d`` is not a dict.
    """
    if not isinstance(d, dict):
        raise TypeError("dict_map: d must be a dict")

    result = d if inplace else deepcopy(d)

    def _walk(source: dict, target: dict, depth: int) -> None:
        # ``source`` supplies the values handed to ``f``; ``target`` receives the results.
        if depth == max_level:
            return
        for key, value in source.items():
            if isinstance(value, dict):
                _walk(value, target[key], depth + 1)
            elif depth >= skip_levels:
                # Only existing keys are reassigned, so iterating ``source`` stays valid
                # even when it is the same object as ``target`` (inplace mode).
                target[key] = f(value)

    _walk(d, result, 0)
    return result
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
from typing import Dict, Generator, Iterable, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from transformers.utils import (
|
|
6
|
+
is_torch_bf16_gpu_available,
|
|
7
|
+
is_torch_cuda_available,
|
|
8
|
+
is_torch_mps_available,
|
|
9
|
+
is_torch_npu_available,
|
|
10
|
+
is_torch_xpu_available,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
# Accepted precision names -> torch dtype. Used by ``parse_dtype`` to translate
# config strings; several aliases map to the same dtype.
PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
    "fp16": torch.float16,
    "float16": torch.float16,
    # "half" mirrors torch's own alias (torch.half is torch.float16).
    "half": torch.float16,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
    "fp32": torch.float32,
    "float32": torch.float32,
    "double": torch.float64,
    "fp64": torch.float64,
    "float64": torch.float64,
}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def parse_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
    """
    Convert a data-type specification into the corresponding ``torch.dtype``.

    Args:
        dtype (Optional[str]): One of:
            - a ``torch.dtype``, returned unchanged;
            - ``None``, returned unchanged;
            - a string that is a key of ``PRECISION_STR_TO_DTYPE``, for example
              ``"fp16"``, ``"float16"``, ``"bf16"``, ``"bfloat16"``, ``"float"``,
              ``"fp32"``, ``"float32"``, ``"double"``, ``"fp64"`` or ``"float64"``.
              Surrounding whitespace and single/double quote characters are ignored,
              so an already-quoted value such as ``'"bf16"'`` is accepted.

    Returns:
        Optional[torch.dtype]: The matching dtype, the input dtype itself, or None.

    Raises:
        ValueError: If ``dtype`` is not None, not a ``torch.dtype``, and not a string
            naming a supported data type. The message includes the offending value.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    if dtype is None:
        return None

    if not isinstance(dtype, str):
        # Previously this fell through to ``.strip`` and raised AttributeError.
        raise ValueError(
            f"Unsupported dtype: {dtype!r} (type {type(dtype).__name__})"
        )

    # The value may arrive still wrapped in quotes and/or padded with whitespace.
    name = dtype.strip().strip("\"'")
    if name not in PRECISION_STR_TO_DTYPE:
        raise ValueError(
            f"Unsupported dtype: {dtype!r}. "
            f"Supported values: {sorted(PRECISION_STR_TO_DTYPE)}"
        )

    return PRECISION_STR_TO_DTYPE[name]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_dtype(obj) -> torch.dtype:
    """
    Get the data type (dtype) of a given object.

    Args:
        obj: One of:
            - ``torch.Tensor``: its ``dtype`` is returned;
            - ``torch.nn.Module``: its ``dtype`` attribute if it defines one, otherwise
              the dtype of its first parameter;
            - ``torch.dtype``: returned unchanged;
            - ``str``: parsed with ``parse_dtype`` (e.g. ``"fp16"``, ``"bfloat16"``).

    Returns:
        torch.dtype: The data type of the given object.

    Raises:
        ValueError: If the object type is not supported, if a module has neither a
            ``dtype`` attribute nor any parameters, or if a string does not name a
            supported dtype (raised by ``parse_dtype``).
    """
    if isinstance(obj, torch.Tensor):
        return obj.dtype

    if isinstance(obj, torch.nn.Module):
        if hasattr(obj, "dtype"):
            return obj.dtype
        first_param = next(iter(obj.parameters()), None)
        if first_param is None:
            # Avoid leaking StopIteration from an empty parameter iterator.
            raise ValueError(
                f"Cannot infer dtype of module {type(obj).__name__}: "
                "it has no 'dtype' attribute and no parameters"
            )
        return first_param.dtype

    # A torch.dtype is accepted directly. The previous check tested torch.device,
    # which has no dtype and crashed inside parse_dtype.
    if isinstance(obj, torch.dtype):
        return obj

    if isinstance(obj, str):
        return parse_dtype(obj)

    raise ValueError(f"Unsupported object type: {type(obj)}")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """
    Context manager that temporarily sets torch's default floating-point dtype.

    The previous default (from ``torch.get_default_dtype()``) is restored on exit,
    including when the body raises.

    Args:
        dtype (torch.dtype): The default dtype to use inside the ``with`` block.

    Yields:
        None

    Example:
        >>> with set_default_dtype(torch.float64):
        ...     x = torch.tensor([1.0, 2.0, 3.0])
        >>> x.dtype
        torch.float64

    Note:
        The default dtype only affects floating-point tensors. Integer literals such
        as ``torch.tensor([1, 2, 3])`` are still created as ``torch.int64``.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even if the body raised, so the global default never leaks out.
        torch.set_default_dtype(old_dtype)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
    r"""
    Choose a compute dtype from the model's dtype and the available hardware.

    Decision rules, in order:
        1. ``torch.bfloat16`` if the model is bfloat16 and bf16 is supported (bf16-capable
           GPU, or an NPU that reports bf16 support).
        2. ``torch.float16`` if an NPU or CUDA device is available.
        3. ``torch.float32`` otherwise.

    Any exception raised while probing bf16 support is treated as "not supported".

    Args:
        model_dtype (torch.dtype): The dtype the model is currently in.

    Returns:
        torch.dtype: The selected dtype.
    """
    fp16_supported = is_torch_npu_available() or is_torch_cuda_available()

    try:
        bf16_supported = is_torch_bf16_gpu_available() or (
            is_torch_npu_available() and torch.npu.is_bf16_supported()
        )
    except Exception:
        # Best effort: a failing probe simply disables the bf16 path.
        bf16_supported = False

    if bf16_supported and model_dtype == torch.bfloat16:
        return torch.bfloat16

    return torch.float16 if fp16_supported else torch.float32
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def validate_expected_param_dtype(
    named_params: Iterable[Tuple[str, torch.nn.Parameter]], dtype: torch.dtype
) -> None:
    """
    Check that every parameter has the expected dtype.

    Parameters are examined in iteration order, and the first mismatch aborts the check.

    Args:
        named_params (Iterable[Tuple[str, torch.nn.Parameter]]): ``(name, parameter)``
            pairs, e.g. from ``module.named_parameters()``.
        dtype (torch.dtype): The dtype every parameter must have.

    Raises:
        ValueError: Naming the first parameter whose dtype differs from ``dtype``.
    """
    offender = next(
        ((name, param) for name, param in named_params if param.dtype != dtype),
        None,
    )
    if offender is not None:
        bad_name, bad_param = offender
        raise ValueError(
            f"Parameter {bad_name} has dtype {bad_param.dtype}, but expected {dtype}"
        )
|