fusion-bench 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +20 -0
- fusion_bench/__main__.py +4 -0
- fusion_bench/compat/__init__.py +0 -0
- fusion_bench/compat/method/__init__.py +109 -0
- fusion_bench/compat/method/base_algorithm.py +58 -0
- fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py +34 -0
- fusion_bench/compat/modelpool/__init__.py +116 -0
- fusion_bench/compat/modelpool/base_pool.py +328 -0
- fusion_bench/compat/modelpool/huggingface_clip_vision.py +178 -0
- fusion_bench/compat/taskpool/__init__.py +95 -0
- fusion_bench/compat/taskpool/base_pool.py +111 -0
- fusion_bench/compat/taskpool/clip_image_classification.py +210 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +175 -0
- fusion_bench/constants/__init__.py +2 -0
- fusion_bench/constants/paths.py +18 -0
- fusion_bench/dataset/__init__.py +29 -0
- fusion_bench/dataset/arc_agi/__init__.py +6 -0
- fusion_bench/dataset/arc_agi/arc.py +308 -0
- fusion_bench/dataset/arc_agi/arc_agi.py +365 -0
- fusion_bench/dataset/arc_agi/augmenters.py +1036 -0
- fusion_bench/dataset/arc_agi/messagers.py +1355 -0
- fusion_bench/dataset/arc_agi/np_cache.py +168 -0
- fusion_bench/dataset/arc_agi/preprocess.py +298 -0
- fusion_bench/dataset/arc_agi/representers.py +1019 -0
- fusion_bench/dataset/clip_dataset.py +71 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/gpt2_glue.py +300 -0
- fusion_bench/dataset/gsm8k.py +60 -0
- fusion_bench/dataset/image_dataset.py +55 -0
- fusion_bench/dataset/imdb.py +11 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +232 -0
- fusion_bench/dataset/llama/collate.py +120 -0
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/openai.py +160 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/sharegpt.py +141 -0
- fusion_bench/dataset/llama/squad.py +125 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/dataset/llama/wikitext.py +89 -0
- fusion_bench/dataset/nyuv2.py +119 -0
- fusion_bench/method/__init__.py +177 -0
- fusion_bench/method/ada_svd/__init__.py +2 -0
- fusion_bench/method/ada_svd/clip_vision.py +319 -0
- fusion_bench/method/adamerging/__init__.py +6 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +187 -0
- fusion_bench/method/adamerging/entropy_loss.py +25 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +332 -0
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +351 -0
- fusion_bench/method/adamerging/layer_wise_adamerging.py +252 -0
- fusion_bench/method/adamerging/llama_adamerging.py +335 -0
- fusion_bench/method/adamerging/min_norm_solvers.py +227 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +174 -0
- fusion_bench/method/adamerging/utils.py +15 -0
- fusion_bench/method/analysis/__init__.py +2 -0
- fusion_bench/method/analysis/task_vector_cos_similarity.py +172 -0
- fusion_bench/method/analysis/task_vector_violin_plot.py +205 -0
- fusion_bench/method/base_algorithm.py +44 -0
- fusion_bench/method/classification/__init__.py +3 -0
- fusion_bench/method/classification/clip_finetune.py +444 -0
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/concrete_subspace/__init__.py +6 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +595 -0
- fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py +263 -0
- fusion_bench/method/dare/__init__.py +4 -0
- fusion_bench/method/dare/simple_average.py +31 -0
- fusion_bench/method/dare/task_arithmetic.py +82 -0
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/dare/utils.py +87 -0
- fusion_bench/method/dawe/__init__.py +2 -0
- fusion_bench/method/dawe/dawe_for_clip.py +274 -0
- fusion_bench/method/dawe/warppers/__init__.py +13 -0
- fusion_bench/method/dawe/warppers/dawe_model.py +256 -0
- fusion_bench/method/depth_upscaling/__init__.py +3 -0
- fusion_bench/method/depth_upscaling/depth_upscaling.py +89 -0
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +57 -0
- fusion_bench/method/dummy.py +35 -0
- fusion_bench/method/ensemble.py +98 -0
- fusion_bench/method/fisher_merging/__init__.py +4 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +191 -0
- fusion_bench/method/fisher_merging/fisher_merging.py +484 -0
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +193 -0
- fusion_bench/method/linear/__init__.py +6 -0
- fusion_bench/method/linear/expo.py +118 -0
- fusion_bench/method/linear/linear_interpolation.py +60 -0
- fusion_bench/method/linear/llama_expo.py +229 -0
- fusion_bench/method/linear/simple_average_for_llama.py +54 -0
- fusion_bench/method/linear/task_arithmetic_for_llama.py +57 -0
- fusion_bench/method/lm_finetune/__init__.py +3 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/causal_lm_pretrain.py +7 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +375 -0
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +370 -0
- fusion_bench/method/mixture_of_experts/__init__.py +7 -0
- fusion_bench/method/mixture_of_experts/mixtral_merging.py +112 -0
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +329 -0
- fusion_bench/method/model_recombination.py +121 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/pruning/__init__.py +5 -0
- fusion_bench/method/pruning/llama_magnitude_prune.py +202 -0
- fusion_bench/method/pruning/llama_random_prune.py +143 -0
- fusion_bench/method/pruning/llama_wanda_prune.py +359 -0
- fusion_bench/method/pruning/magnitude_diff_pruning.py +180 -0
- fusion_bench/method/pruning/prune_utils.py +165 -0
- fusion_bench/method/pruning/wanda_utils/__init__.py +7 -0
- fusion_bench/method/pruning/wanda_utils/ablate.py +188 -0
- fusion_bench/method/pruning/wanda_utils/data.py +135 -0
- fusion_bench/method/pruning/wanda_utils/eval.py +245 -0
- fusion_bench/method/pruning/wanda_utils/layerwrapper.py +61 -0
- fusion_bench/method/pruning/wanda_utils/prune.py +581 -0
- fusion_bench/method/pruning/wanda_utils/prune_opt.py +539 -0
- fusion_bench/method/pruning/wanda_utils/sparsegpt.py +165 -0
- fusion_bench/method/pwe_moe/__init__.py +5 -0
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +315 -0
- fusion_bench/method/pwe_moe/module.py +316 -0
- fusion_bench/method/pwe_moe/phn/__init__.py +2 -0
- fusion_bench/method/pwe_moe/phn/solvers.py +195 -0
- fusion_bench/method/pwe_moe/utils.py +43 -0
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/regmean/__init__.py +4 -0
- fusion_bench/method/regmean/clip_regmean.py +131 -0
- fusion_bench/method/regmean/gpt2_regmean.py +147 -0
- fusion_bench/method/regmean/regmean.py +375 -0
- fusion_bench/method/simple_average.py +112 -0
- fusion_bench/method/slerp/__init__.py +2 -0
- fusion_bench/method/slerp/slerp.py +101 -0
- fusion_bench/method/slerp/slerp_utils.py +107 -0
- fusion_bench/method/smile_upscaling/__init__.py +3 -0
- fusion_bench/method/smile_upscaling/singular_projection_merging.py +198 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +331 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +573 -0
- fusion_bench/method/sparse_we_moe/__init__.py +2 -0
- fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py +248 -0
- fusion_bench/method/sparse_we_moe/sparse_we_moe.py +301 -0
- fusion_bench/method/sparselo/__init__.py +2 -0
- fusion_bench/method/sparselo/sparselo.py +955 -0
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_arithmetic/__init__.py +2 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +151 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +640 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/__init__.py +2 -0
- fusion_bench/method/ties_merging/ties_merging.py +117 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +331 -0
- fusion_bench/method/trust_region/__init__.py +2 -0
- fusion_bench/method/trust_region/clip_task_arithmetic.py +205 -0
- fusion_bench/method/trust_region/utils.py +58 -0
- fusion_bench/method/we_moe/__init__.py +2 -0
- fusion_bench/method/we_moe/clip_we_moe.py +161 -0
- fusion_bench/method/we_moe/we_moe.py +247 -0
- fusion_bench/method/weighted_average/__init__.py +3 -0
- fusion_bench/method/weighted_average/llama.py +113 -0
- fusion_bench/method/weighted_average/weighted_average.py +102 -0
- fusion_bench/metrics/__init__.py +0 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/metrics/nyuv2/__init__.py +11 -0
- fusion_bench/metrics/nyuv2/depth.py +45 -0
- fusion_bench/metrics/nyuv2/loss.py +31 -0
- fusion_bench/metrics/nyuv2/noise.py +16 -0
- fusion_bench/metrics/nyuv2/normal.py +48 -0
- fusion_bench/metrics/nyuv2/segmentation.py +43 -0
- fusion_bench/metrics/text_to_image_generation/__init__.py +9 -0
- fusion_bench/metrics/text_to_image_generation/aesthetic_scorer.py +123 -0
- fusion_bench/metrics/text_to_image_generation/compressibility.py +49 -0
- fusion_bench/metrics/text_to_image_generation/pickscore_scorer.py +95 -0
- fusion_bench/mixins/__init__.py +28 -0
- fusion_bench/mixins/clip_classification.py +252 -0
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +174 -0
- fusion_bench/mixins/optim/__init__.py +0 -0
- fusion_bench/mixins/optim/adamw_with_warmup.py +42 -0
- fusion_bench/mixins/rich_live.py +21 -0
- fusion_bench/mixins/serialization.py +132 -0
- fusion_bench/mixins/simple_profiler.py +79 -0
- fusion_bench/modelpool/PeftModelForSeq2SeqLM.py +49 -0
- fusion_bench/modelpool/__init__.py +42 -0
- fusion_bench/modelpool/base_pool.py +268 -0
- fusion_bench/modelpool/causal_lm/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +139 -0
- fusion_bench/modelpool/clip_vision/__init__.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +145 -0
- fusion_bench/modelpool/huggingface_automodel.py +20 -0
- fusion_bench/modelpool/huggingface_gpt2_classification.py +63 -0
- fusion_bench/modelpool/nyuv2_modelpool.py +40 -0
- fusion_bench/modelpool/seq2seq_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +65 -0
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/__init__.py +3 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +199 -0
- fusion_bench/models/linearized/__init__.py +0 -0
- fusion_bench/models/linearized/linearized_model_utils.py +91 -0
- fusion_bench/models/linearized/vision_model.py +122 -0
- fusion_bench/models/llama/__init__.py +16 -0
- fusion_bench/models/llama/model_utils/__init__.py +0 -0
- fusion_bench/models/llama/model_utils/embedding.py +87 -0
- fusion_bench/models/llama/model_utils/liger_kernel.py +86 -0
- fusion_bench/models/llama/model_utils/misc.py +112 -0
- fusion_bench/models/llama/model_utils/mod.py +52 -0
- fusion_bench/models/llama/model_utils/visual.py +241 -0
- fusion_bench/models/llama/patcher.py +78 -0
- fusion_bench/models/llama/tokenizer_loader.py +153 -0
- fusion_bench/models/masks/__init__.py +2 -0
- fusion_bench/models/masks/mask_model.py +160 -0
- fusion_bench/models/modeling_losparse_llama/__init__.py +4 -0
- fusion_bench/models/modeling_losparse_llama/configuration_losparse_llama.py +205 -0
- fusion_bench/models/modeling_losparse_llama/losparse_linear.py +67 -0
- fusion_bench/models/modeling_losparse_llama/modeling_losparse_llama.py +1825 -0
- fusion_bench/models/modeling_losparse_llama/register.py +8 -0
- fusion_bench/models/modeling_losparse_llama/utils.py +60 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +48 -0
- fusion_bench/models/modeling_smile_mistral/configuration_smile_mistral.py +21 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +1034 -0
- fusion_bench/models/modeling_smile_mistral/register.py +8 -0
- fusion_bench/models/nyuv2/__init__.py +0 -0
- fusion_bench/models/nyuv2/aspp.py +82 -0
- fusion_bench/models/nyuv2/lightning_module.py +176 -0
- fusion_bench/models/nyuv2/resnet.py +405 -0
- fusion_bench/models/nyuv2/resnet_dilated.py +99 -0
- fusion_bench/models/parameter_dict.py +75 -0
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/separate_io.py +105 -0
- fusion_bench/models/smile_moe/__init__.py +0 -0
- fusion_bench/models/smile_moe/linear.py +256 -0
- fusion_bench/models/sparse_we_moe.py +459 -0
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +80 -0
- fusion_bench/models/we_moe.py +247 -0
- fusion_bench/models/wrappers/__init__.py +0 -0
- fusion_bench/models/wrappers/ensemble.py +183 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +336 -0
- fusion_bench/models/wrappers/task_wise_fusion.py +249 -0
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +118 -0
- fusion_bench/programs/__init__.py +20 -0
- fusion_bench/programs/base_program.py +9 -0
- fusion_bench/programs/fabric_fusion_program.py +299 -0
- fusion_bench/scripts/__init__.py +0 -0
- fusion_bench/scripts/cli.py +43 -0
- fusion_bench/scripts/clip/__init__.py +0 -0
- fusion_bench/scripts/clip/convert_checkpoint.py +39 -0
- fusion_bench/scripts/imgui.py +218 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +137 -0
- fusion_bench/scripts/webui.py +405 -0
- fusion_bench/taskpool/__init__.py +39 -0
- fusion_bench/taskpool/base_pool.py +35 -0
- fusion_bench/taskpool/clip_vision/__init__.py +4 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +120 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +392 -0
- fusion_bench/taskpool/dummy.py +58 -0
- fusion_bench/taskpool/gpt2_text_classification.py +149 -0
- fusion_bench/taskpool/llama/__init__.py +1 -0
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/llama/test_generation.py +185 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +65 -0
- fusion_bench/tasks/__init__.py +2 -0
- fusion_bench/tasks/base_task.py +18 -0
- fusion_bench/tasks/classification.py +75 -0
- fusion_bench/tasks/clip_classification/__init__.py +183 -0
- fusion_bench/tasks/clip_classification/cifar10.py +33 -0
- fusion_bench/tasks/clip_classification/cifar100.py +146 -0
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -0
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/dtd.py +60 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/eurosat.py +18 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/flower102.py +106 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/gtsrb.py +51 -0
- fusion_bench/tasks/clip_classification/imagenet.py +2103 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mnist.py +5 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/oxford_iiit_pet.py +41 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/tasks/clip_classification/rendered_sst2.py +3 -0
- fusion_bench/tasks/clip_classification/resisc45.py +68 -0
- fusion_bench/tasks/clip_classification/stanford_cars.py +209 -0
- fusion_bench/tasks/clip_classification/stl10.py +17 -0
- fusion_bench/tasks/clip_classification/sun397.py +404 -0
- fusion_bench/tasks/clip_classification/svhn.py +5 -0
- fusion_bench/tasks/clip_classification/tiny_imagenet.py +208 -0
- fusion_bench/tasks/flan_t5_text_generation/__init__.py +0 -0
- fusion_bench/tasks/flan_t5_text_generation/datasets_preprocess.py +71 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_evaluation.py +132 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +64 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_preprocessors.py +379 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py +52 -0
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/auto.py +31 -0
- fusion_bench/utils/cache_utils.py +58 -0
- fusion_bench/utils/data.py +165 -0
- fusion_bench/utils/devices.py +231 -0
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/dtype.py +146 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/functools.py +37 -0
- fusion_bench/utils/hydra_utils.py +28 -0
- fusion_bench/utils/instantiate.py +450 -0
- fusion_bench/utils/json.py +93 -0
- fusion_bench/utils/lazy_imports.py +74 -0
- fusion_bench/utils/misc.py +18 -0
- fusion_bench/utils/packages.py +84 -0
- fusion_bench/utils/parameters.py +323 -0
- fusion_bench/utils/path.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/pylogger.py +55 -0
- fusion_bench/utils/rich_utils.py +201 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/state_dict_arithmetic.py +297 -0
- fusion_bench/utils/strenum/__init__.py +326 -0
- fusion_bench/utils/strenum/_name_mangler.py +127 -0
- fusion_bench/utils/strenum/_version.py +556 -0
- fusion_bench/utils/tensorboard.py +51 -0
- fusion_bench/utils/timer.py +49 -0
- fusion_bench/utils/type.py +34 -0
- fusion_bench-0.2.9.dist-info/LICENSE +21 -0
- fusion_bench-0.2.9.dist-info/METADATA +258 -0
- fusion_bench-0.2.9.dist-info/RECORD +727 -0
- fusion_bench-0.2.9.dist-info/WHEEL +5 -0
- fusion_bench-0.2.9.dist-info/entry_points.txt +3 -0
- fusion_bench-0.2.9.dist-info/top_level.txt +1 -0
- fusion_bench_config/README.md +12 -0
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +23 -0
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/test/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/test/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/dtd.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/eurosat.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/gtsrb.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/resisc45.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stanford-cars.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/sun397.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/svhn.yaml +6 -0
- fusion_bench_config/dataset/image_classification/train/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/image_classification/train/tiny-imagenet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/val/dtd.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/eurosat.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/gtsrb.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/mnist.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/resisc45.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/stanford-cars.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/sun397.yaml +10 -0
- fusion_bench_config/dataset/image_classification/val/svhn.yaml +12 -0
- fusion_bench_config/dataset/image_classification/val/the_eight_tasks.yaml +9 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/dataset/question_answering/search_qa.yaml +6 -0
- fusion_bench_config/dataset/question_answering/test/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/train/MetaMathQA.yaml +4 -0
- fusion_bench_config/dataset/question_answering/train/search_qa.yaml +7 -0
- fusion_bench_config/dataset/question_answering/val/search_qa.yaml +7 -0
- fusion_bench_config/dataset/summarization/test/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/train/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/val/xsum.yaml +4 -0
- fusion_bench_config/dataset/summarization/xsum.yaml +3 -0
- fusion_bench_config/dataset/text_generation/test/gsm-hard.yaml +4 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/test/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/dataset/text_generation/train/CodeAlpaca-20k.yaml +4 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k.yaml +5 -0
- fusion_bench_config/dataset/text_generation/train/gsm8k_question_label.yaml +3 -0
- fusion_bench_config/fabric/auto.yaml +16 -0
- fusion_bench_config/fabric/llama_ddp.yaml +18 -0
- fusion_bench_config/fabric/llama_fsdp.yaml +16 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +11 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_fsdp.yaml +8 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +20 -0
- fusion_bench_config/hydra/default.yaml +8 -0
- fusion_bench_config/hydra/help/fusion_bench_help.yaml +47 -0
- fusion_bench_config/hydra/job_logging/rich_logging.yaml +20 -0
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/llama_magnitude_pruning.yaml +16 -0
- fusion_bench_config/llama_model_fusion.yaml +17 -0
- fusion_bench_config/method/ada_svd/clip_vision.yaml +9 -0
- fusion_bench_config/method/adamerging/clip.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +23 -0
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +23 -0
- fusion_bench_config/method/adamerging/llama_sft.yaml +33 -0
- fusion_bench_config/method/adamerging.yaml +23 -0
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +6 -0
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +6 -0
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +26 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml +25 -0
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +27 -0
- fusion_bench_config/method/dare/simple_average.yaml +5 -0
- fusion_bench_config/method/dare/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +32 -0
- fusion_bench_config/method/depth_upscaling.yaml +5 -0
- fusion_bench_config/method/dummy.yaml +1 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +2 -0
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +6 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +13 -0
- fusion_bench_config/method/fisher_merging/fisher_merging.yaml +9 -0
- fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml +12 -0
- fusion_bench_config/method/linear/expo.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +3 -0
- fusion_bench_config/method/linear/llama_expo.yaml +19 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +19 -0
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +4 -0
- fusion_bench_config/method/linear/weighted_average.yaml +6 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +12 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +47 -0
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +63 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +4 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +7 -0
- fusion_bench_config/method/model_recombination.yaml +4 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/pruning/llama_magnitude_pruning.yaml +14 -0
- fusion_bench_config/method/pruning/llama_random_pruning.yaml +9 -0
- fusion_bench_config/method/pruning/llama_wanda_pruning.yaml +16 -0
- fusion_bench_config/method/pruning/magnitude_diff_pruning.yaml +5 -0
- fusion_bench_config/method/pwe_moe_ls_for_clip.yaml +22 -0
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +11 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +12 -0
- fusion_bench_config/method/regmean/regmean.yaml +4 -0
- fusion_bench_config/method/simple_average.yaml +1 -0
- fusion_bench_config/method/slerp/slerp.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +8 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +10 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +14 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +19 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_arithmetic.yaml +2 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/method/ties_merging.yaml +8 -0
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +7 -0
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +39 -0
- fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +11 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eight_tasks.yaml +10 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -0
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/model/clip-vit/generate_vit_model_config.sh +23 -0
- fusion_bench_config/model/flan-t5/flan-t5-base.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-base_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large.yaml +3 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-cola_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-mrpc_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qnli_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-qqp_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-rte_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-sst2_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/flan-t5-large_glue-stsb_lora-16.yaml +4 -0
- fusion_bench_config/model/flan-t5/generate_flan-t5.sh +38 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +12 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +53 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +19 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +14 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +3 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml +24 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +13 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +5 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +17 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml +8 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +13 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +41 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +68 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +7 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +45 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/modelpool/automodelpool.yaml +12 -0
- fusion_bench_config/modelpool/gpt-2_glue.yaml +64 -0
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +14 -0
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +6 -0
- fusion_bench_config/modelpool/nyuv2_modelpool.yaml +26 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v1.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v2.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v3.yaml +9 -0
- fusion_bench_config/modelpool/smile_mistral_exp_v4.yaml +13 -0
- fusion_bench_config/nyuv2_config.yaml +17 -0
- fusion_bench_config/nyuv2_mtl_train.yaml +32 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8.yaml +11 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +31 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_L14.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_val.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_with_control_task.yaml +12 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_clean.yaml +24 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +22 -0
- fusion_bench_config/taskpool/dummy.yaml +2 -0
- fusion_bench_config/taskpool/flan-t5_glue_text_generation.yaml +44 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +39 -0
- fusion_bench_config/taskpool/nyuv2_taskpool.yaml +9 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
fusion_bench/modelpool/seq2seq_lm/modelpool.py
@@ -0,0 +1,65 @@
+import logging
+from copy import deepcopy
+from typing import Optional
+
+from omegaconf import DictConfig, flag_override
+from peft import PeftModel
+from transformers import AutoModelForSeq2SeqLM
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import instantiate, parse_dtype
+
+log = logging.getLogger(__name__)
+
+
+def load_lora_model(
+    base_model_path: str,
+    peft_model_path: str,
+    is_trainable: bool = True,
+    merge_and_unload: bool = True,
+):
+    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path)
+    model = PeftModel.from_pretrained(
+        base_model,
+        peft_model_path,
+        is_trainable=is_trainable,
+    )
+    if merge_and_unload:
+        model = model.merge_and_unload()
+    return model
+
+
+class Seq2SeqLMPool(BaseModelPool):
+    _config_mapping = BaseModelPool._config_mapping | {
+        "_tokenizer": "tokenizer",
+        "_model_kwargs": "model_kwargs",
+    }
+
+    def __init__(
+        self,
+        models: DictConfig,
+        *,
+        tokenizer: Optional[DictConfig],
+        model_kwargs: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        self._tokenizer = tokenizer
+        self._model_kwargs = model_kwargs
+        if self._model_kwargs is None:
+            self._model_kwargs = DictConfig({})
+        with flag_override(self._model_kwargs, "allow_objects", True):
+            if hasattr(self._model_kwargs, "torch_dtype"):
+                self._model_kwargs.torch_dtype = parse_dtype(
+                    self._model_kwargs.torch_dtype
+                )
+
+    def load_model(self, model_name_or_config: str | DictConfig, *args, **kwargs):
+        model_kwargs = deepcopy(self._model_kwargs)
+        model_kwargs.update(kwargs)
+        return super().load_model(model_name_or_config, *args, **model_kwargs)
+
+    def load_tokenizer(self, *args, **kwargs):
+        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+        return tokenizer
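
A minimal sketch of how the `Seq2SeqLMPool` above might be constructed outside of Hydra, assuming `BaseModelPool` resolves `_target_` entries via Hydra-style instantiation; the checkpoint name, entry name, and config keys are illustrative assumptions, not values taken from the shipped YAML files.

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.seq2seq_lm.modelpool import Seq2SeqLMPool

# Hypothetical pool definition mirroring a Seq2SeqLMPool YAML config.
pool = Seq2SeqLMPool(
    models=DictConfig({
        "_pretrained_": {
            "_target_": "transformers.AutoModelForSeq2SeqLM.from_pretrained",
            "pretrained_model_name_or_path": "google/flan-t5-base",
        },
    }),
    tokenizer=DictConfig({
        "_target_": "transformers.AutoTokenizer.from_pretrained",
        "pretrained_model_name_or_path": "google/flan-t5-base",
    }),
    # torch_dtype is given as a string and normalized by parse_dtype in __init__.
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),
)

model = pool.load_model("_pretrained_")  # per-call kwargs are merged into model_kwargs
tokenizer = pool.load_tokenizer()
```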
fusion_bench/modelpool/seq_classification_lm/reward_model.py
@@ -0,0 +1,15 @@
+from transformers import AutoModelForSequenceClassification
+
+
+def create_reward_model_from_pretrained(pretrained_model_name_or_path: str, **kwargs):
+    """
+    Create a reward model for reward modeling (RLHF).
+
+    Args:
+        pretrained_model_name_or_path (str): The name or path of the pretrained model.
+        **kwargs: Additional keyword arguments passed to the model class.
+    """
+    model = AutoModelForSequenceClassification.from_pretrained(
+        pretrained_model_name_or_path, num_labels=1, **kwargs
+    )
+    return model
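
A short usage sketch for `create_reward_model_from_pretrained`; the base checkpoint is a hypothetical choice for illustration.

```python
import torch

from fusion_bench.modelpool.seq_classification_lm.reward_model import (
    create_reward_model_from_pretrained,
)

# num_labels=1 is set inside the helper, so the model produces a single
# scalar score per sequence -- the usual shape for an RLHF reward model.
reward_model = create_reward_model_from_pretrained(
    "meta-llama/Llama-3.2-1B",  # hypothetical base checkpoint
    torch_dtype=torch.bfloat16,
)
```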
fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
@@ -0,0 +1,98 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union, cast  # noqa: F401
+
+from omegaconf import DictConfig, flag_override
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing_extensions import override
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.dtype import parse_dtype
+
+if TYPE_CHECKING:
+    from transformers import LlamaForSequenceClassification
+
+log = logging.getLogger(__name__)
+
+
+class SeqenceClassificationModelPool(BaseModelPool):
+
+    def __init__(
+        self,
+        models,
+        *,
+        tokenizer: Optional[DictConfig],
+        model_kwargs: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        # process `model_kwargs`
+        self._tokenizer = tokenizer
+        self._model_kwargs = model_kwargs
+        if self._model_kwargs is None:
+            self._model_kwargs = DictConfig({})
+        with flag_override(self._model_kwargs, "allow_objects", True):
+            if hasattr(self._model_kwargs, "torch_dtype"):
+                self._model_kwargs.torch_dtype = parse_dtype(
+                    self._model_kwargs.torch_dtype
+                )
+
+    @override
+    def load_model(
+        self,
+        model_name_or_config: str | DictConfig,
+        *args,
+        **kwargs,
+    ) -> Union[PreTrainedModel, "LlamaForSequenceClassification"]:
+        model_kwargs = deepcopy(self._model_kwargs)
+        model_kwargs.update(kwargs)
+        if isinstance(model_name_or_config, str):
+            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
+        return super().load_model(model_name_or_config, *args, **model_kwargs)
+
+    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
+        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+        log.info("Loading tokenizer.", stacklevel=2)
+        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+        return tokenizer
+
+    @override
+    def save_model(
+        self,
+        model: PreTrainedModel,
+        path: str,
+        push_to_hub: bool = False,
+        model_dtype: Optional[str] = None,
+        save_tokenizer: bool = False,
+        tokenizer_kwargs=None,
+        **kwargs,
+    ):
+        """
+        Save the model to the specified path.
+
+        Args:
+            model (PreTrainedModel): The model to be saved.
+            path (str): The path where the model will be saved.
+            push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
+            save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
+            **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
+        """
+        path = os.path.expanduser(path)
+        if save_tokenizer:
+            if tokenizer_kwargs is None:
+                tokenizer_kwargs = {}
+            # load the tokenizer
+            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+            tokenizer.save_pretrained(
+                path,
+                push_to_hub=push_to_hub,
+            )
+        if model_dtype is not None:
+            model.to(dtype=parse_dtype(model_dtype))
+        model.save_pretrained(
+            path,
+            push_to_hub=push_to_hub,
+            **kwargs,
+        )
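
A brief sketch of the save path implemented above, assuming `pool` is an already-configured `SeqenceClassificationModelPool` whose config defines a `_pretrained_` entry and a tokenizer; the entry name, output path, and dtype are assumptions.

```python
# `pool` is assumed to be a configured SeqenceClassificationModelPool.
rm = pool.load_model("_pretrained_")      # hypothetical entry name
pool.save_model(
    rm,
    path="~/checkpoints/reward_model",    # expanded via os.path.expanduser
    model_dtype="float16",                # cast with parse_dtype before saving
    save_tokenizer=True,                  # also serializes the pool's tokenizer
)
```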
fusion_bench/models/chat_templates/__init__.py
@@ -0,0 +1 @@
+from .load_tokenizer import chat_template_mapping, load_tokenizer_with_chat_template
@@ -0,0 +1 @@
|
|
|
1
|
+CHAT_TEMPLATE = '{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now("%d %b %Y") %}\n    {%- else %}\n        {%- set date_string = "26 Jul 2024" %}\n    {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n    {%- set system_message = messages[0][\'content\']|trim %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if tools is not none %}\n    {{- "Environment: ipython\\n" }}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n    {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n    {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n    {{- "Do not use variables.\\n\\n" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- "\\n\\n" }}\n    {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0][\'content\']|trim %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n    {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n    {{- "Given the following functions, please respond with a JSON for a function call " }}\n    {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n    {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n    {{- "Do not use variables.\\n\\n" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- "\\n\\n" }}\n    {%- endfor %}\n    {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n        {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\'+ message[\'content\'] | trim + \'<|eot_id|>\' }}\n    {%- elif \'tool_calls\' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception("This model only supports single tool-calls at once!") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n        {{- \'{"name": "\' + tool_call.name + \'", \' }}\n        {{- \'"parameters": \' }}\n        {{- tool_call.arguments | tojson }}\n        {{- "}" }}\n        {{- "<|eot_id|>" }}\n    {%- elif message.role == "tool" or message.role == "ipython" %}\n        {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n        {%- if message.content is mapping or message.content is iterable %}\n            {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- "<|eot_id|>" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n'
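
For reference, a tokenizer whose `chat_template` is set to the string above can render conversations through the standard `transformers` API. A minimal sketch, assuming `CHAT_TEMPLATE` is imported from this module and using an illustrative checkpoint name:

    from transformers import AutoTokenizer

    # Hypothetical checkpoint; any Llama-3-Instruct-style tokenizer works here.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    tokenizer.chat_template = CHAT_TEMPLATE  # the Jinja template defined above

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is model merging?"},
    ]
    # Renders the conversation with <|start_header_id|>/<|eot_id|> markers and a
    # trailing assistant header, ready to be passed to the model for generation.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
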
@@ -0,0 +1,43 @@
+import logging
+
+from transformers import AutoTokenizer
+
+from .llama_3_Instruct import CHAT_TEMPLATE as LLAMA_3_INSTRUCT_CHAT_TEMPLATE
+
+chat_template_mapping = {"llama_3_instruct": LLAMA_3_INSTRUCT_CHAT_TEMPLATE}
+
+log = logging.getLogger(__name__)
+
+
+def load_tokenizer_with_chat_template(
+    pretrained_model_name_or_path: str,
+    model_family: str,
+    overwrite_chat_template: bool = True,
+    **kwargs,
+):
+    """
+    Load a tokenizer and attach the chat template registered for the given model family.
+
+    Args:
+        pretrained_model_name_or_path (str): The name or path of the pretrained model.
+        model_family (str): The model family, e.g. "llama_3_instruct".
+        overwrite_chat_template (bool): Whether to replace an existing chat template with the registered one.
+        **kwargs: Additional keyword arguments passed to the tokenizer class.
+    """
+    assert (
+        model_family in chat_template_mapping
+    ), f"Model family {model_family} not found. Available model families: {chat_template_mapping.keys()}"
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        **kwargs,
+    )
+
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = chat_template_mapping[model_family]
+    else:
+        if overwrite_chat_template:
+            log.warning("Overwriting the existing chat template with the default chat template.")
+            tokenizer.chat_template = chat_template_mapping[model_family]
+        else:
+            log.warning("Chat template already exists. Skipping overwrite.")
+    return tokenizer
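
A usage sketch for the loader above (the checkpoint name is illustrative):

    tokenizer = load_tokenizer_with_chat_template(
        "meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical checkpoint
        model_family="llama_3_instruct",
    )
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello!"}],
        tokenize=False,
        add_generation_prompt=True,
    )
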
@@ -0,0 +1,199 @@
+import logging
+from typing import TYPE_CHECKING, Callable, Iterable, List  # noqa: F401
+
+import torch
+from torch import Tensor, nn
+from transformers import CLIPModel, CLIPProcessor
+from transformers.models.clip.modeling_clip import BaseModelOutputWithPooling
+
+from fusion_bench.utils.devices import get_device
+
+if TYPE_CHECKING:
+    from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+log = logging.getLogger(__name__)
+
+default_templates = [
+    lambda c: f"a photo of a {c}",
+]
+
+
+class HFCLIPClassifier(nn.Module):
+    """
+    A classifier based on the CLIP (Contrastive Language-Image Pre-training) model.
+
+    This class wraps a CLIP model and provides functionality for image classification
+    using zero-shot learning. It allows setting a classification task with custom
+    class names and text templates.
+
+    Attributes:
+        clip_model (CLIPModel): The underlying CLIP model.
+        processor (CLIPProcessor): The CLIP processor for preparing inputs.
+        zeroshot_weights (Tensor): Computed text embeddings for zero-shot classification.
+        classnames (List[str]): List of class names for the current classification task.
+        templates (List[Callable[[str], str]]): List of template functions for generating text prompts.
+    """
+
+    def __init__(
+        self,
+        clip_model: CLIPModel,
+        processor: CLIPProcessor,
+        extra_module=None,
+    ):
+        """
+        Initialize the HFCLIPClassifier.
+
+        Args:
+            clip_model (CLIPModel): The CLIP model to use for classification.
+            processor (CLIPProcessor): The CLIP processor for preparing inputs.
+            extra_module: An optional extra module stored on the classifier.
+        """
+        super().__init__()
+        # we only fine-tune the vision model
+        clip_model.visual_projection.requires_grad_(False)
+        clip_model.text_model.requires_grad_(False)
+        clip_model.text_projection.requires_grad_(False)
+        clip_model.logit_scale.requires_grad_(False)
+
+        self.clip_model = clip_model
+        self.processor = processor
+        self.register_buffer(
+            "zeroshot_weights",
+            None,
+            persistent=False,
+        )
+
+        self.extra_module = extra_module
+
+    @property
+    def text_model(self):
+        """Get the text model component of CLIP."""
+        return self.clip_model.text_model
+
+    @property
+    def vision_model(self):
+        """Get the vision model component of CLIP."""
+        return self.clip_model.vision_model
+
+    def set_classification_task(
+        self,
+        classnames: List[str],
+        templates: List[Callable[[str], str]] = default_templates,
+    ):
+        """
+        Set up the zero-shot classification task.
+
+        This method computes text embeddings for the given class names using the
+        provided templates. These embeddings are then used for classification.
+
+        Args:
+            classnames (List[str]): List of class names for the classification task.
+            templates (List[Callable[[str], str]], optional): List of template functions
+                for generating text prompts. Defaults to `default_templates`, i.e.
+                ["a photo of a {classname}"].
+        """
+        processor = self.processor
+
+        self.classnames = classnames
+        self.templates = templates
+
+        with torch.no_grad():
+            zeroshot_weights = []
+            for classname in classnames:
+                text = [template(classname) for template in templates]
+                inputs = processor(text=text, return_tensors="pt", padding=True)
+                inputs = {
+                    k: v.to(get_device(self.text_model)) for k, v in inputs.items()
+                }
+                embeddings = self.text_model(**inputs)[1]
+                embeddings = self.clip_model.text_projection(embeddings)
+
+                # normalize embeddings
+                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
+
+                embeddings = embeddings.mean(dim=0)
+                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
+
+                zeroshot_weights.append(embeddings)
+
+            zeroshot_weights = torch.stack(zeroshot_weights, dim=0)
+
+        self.zeroshot_weights = zeroshot_weights
+
+    def forward(
+        self,
+        images: Tensor,
+        return_image_embeds=False,
+        return_dict=False,
+        task_name=None,
+    ):
+        """
+        Perform forward pass for zero-shot image classification.
+
+        This method computes image embeddings for the input images and calculates
+        the similarity with the pre-computed text embeddings to produce classification logits.
+
+        Args:
+            images (Tensor): Input images to classify.
+            return_image_embeds (bool): Whether to return the image embeddings.
+            return_dict (bool): Whether to return a dictionary with logits and image embeddings.
+            task_name (Optional[str]): The name of the task.
+
+        Returns:
+            Tensor: Classification logits for each input image.
+
+        Raises:
+            ValueError: If the classification task hasn't been set using set_classification_task.
+        """
+        if self.zeroshot_weights is None:
+            raise ValueError("Must set classification task before forward pass")
+        text_embeds = self.zeroshot_weights
+
+        image_embeds = self.get_image_features(images)
+        # normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        if (
+            hasattr(self.vision_model, "is_surgery_model")
+            and self.vision_model.is_surgery_model
+        ):
+            # Dealing with the surgery model, for more details, please refer to:
+            # (ICML 2024) Yang, et al. Representation Surgery for Multi-Task Model Merging
+            # https://arxiv.org/abs/2402.02705
+            self.vision_model: "SurgeryModelWrapper" = self.vision_model
+            image_embeds, _, _ = self.vision_model.compute_surgery_features(
+                image_embeds, dataset_name=task_name
+            )
+
+        # cosine similarity
+        logit_scale = self.clip_model.logit_scale.exp()
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+        logits_per_image = logits_per_text.t()
+
+        if return_dict:
+            ret = {"logits": logits_per_image}
+            if return_image_embeds:
+                ret.update({"image_embeds": image_embeds})
+            return ret
+        else:
+            if return_image_embeds:
+                return logits_per_image, image_embeds
+            else:
+                return logits_per_image
+
+    def get_image_features(self, images: Tensor) -> Tensor:
+        """
+        Compute the image embeddings.
+
+        Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+                applying the projection layer to the pooled output of [`CLIPVisionModel`].
+        """
+        image_embeds = self.vision_model(images)
+        if isinstance(image_embeds, Tensor):
+            pass
+        elif isinstance(image_embeds, BaseModelOutputWithPooling):
+            image_embeds = image_embeds[1]
+        image_embeds = self.clip_model.visual_projection(image_embeds)
+        return image_embeds
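
A minimal zero-shot classification sketch using the wrapper above; the CLIP checkpoint, class names, and random pixel values are illustrative:

    import torch
    from transformers import CLIPModel, CLIPProcessor

    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    classifier = HFCLIPClassifier(clip_model, processor)
    classifier.set_classification_task(["cat", "dog"])

    # In practice `images` are preprocessed pixel values, e.g.
    # processor(images=..., return_tensors="pt").pixel_values
    images = torch.randn(2, 3, 224, 224)
    logits = classifier(images)          # shape (2, 2): image-to-class logits
    predictions = logits.argmax(dim=-1)
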
File without changes

@@ -0,0 +1,91 @@
+import logging
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Optional
+
+import torch.nn as nn
+from torch.func import functional_call, jvp
+
+log = logging.getLogger(__name__)
+
+
+def dict_params_to_tuple(dict_params: dict):
+    return tuple(v for k, v in dict_params.items())
+
+
+class LinearizedModelWraper(nn.Module):
+    def __init__(self, model: nn.Module, init_model: Optional[nn.Module] = None):
+        """
+        Initializes a linearized model.
+
+        Args:
+            model (nn.Module): The underlying PyTorch model to be linearized.
+            init_model (nn.Module): The initial PyTorch model used to compute the linearization parameters (default: None).
+        """
+        super().__init__()
+        self.model = model
+        if init_model is None:
+            init_model = model
+        assert not hasattr(self, "params0")
+        params0 = deepcopy([(k, v.detach()) for k, v in init_model.named_parameters()])
+        self.params0_keys = [k for k, v in params0]
+        self.params0_values = nn.ParameterList([v for k, v in params0])
+        for p in self.params0_values:
+            p.requires_grad_(False)
+
+    def tuple_params_to_dict(self, tuple_params):
+        """
+        Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
+
+        Args:
+            tuple_params (Tuple[Tensor, ...]): A tuple of parameters.
+
+        Returns:
+            Dict[str, Tensor]: A dictionary with keys corresponding to the parameter names and values corresponding to the
+                parameter values.
+        """
+        assert len(tuple_params) == len(self.params0_keys)
+        state_dict = {}
+        for k, p in zip(self.params0_keys, tuple_params):
+            state_dict[k] = p
+        return state_dict
+
+    def forward(self, *args, **kwargs):
+        """
+        Computes the linearized model output using a first-order Taylor decomposition.
+
+        Args:
+            *args: Positional arguments to be passed to the model.
+            **kwargs: Keyword arguments to be passed to the model.
+
+        Returns:
+            torch.Tensor: The output of the linearized model, computed using a first-order Taylor decomposition.
+        """
+        params0 = tuple(self.params0_values)
+        params = dict_params_to_tuple(OrderedDict(self.model.named_parameters()))
+        dparams = tuple(p - p0 for p, p0 in zip(params, params0))
+        out, dp = jvp(
+            lambda *param: functional_call(
+                self.model, self.tuple_params_to_dict(param), args, kwargs
+            ),
+            params0,
+            dparams,
+        )
+        return out + dp
+
+    @staticmethod
+    def unload_linearized_modules_(module: nn.Module):
+        """
+        Recursively replaces `LinearizedModelWraper` children of `module` with their
+        wrapped original modules, in place.
+
+        Args:
+            module (nn.Module): The module whose linearized children should be unloaded.
+        """
+        for name, model in module.named_children():
+            if isinstance(model, LinearizedModelWraper):
+                setattr(module, name, model.model)
+            else:
+                LinearizedModelWraper.unload_linearized_modules_(model)
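
A small sketch of how the wrapper might be used: it evaluates the wrapped module as a first-order Taylor expansion around the parameters captured at wrapping time, i.e. f(theta; x) is approximated by f(theta0; x) plus the Jacobian-vector product with (theta - theta0). The toy layer below is illustrative:

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    linearized = LinearizedModelWraper(layer)

    x = torch.randn(2, 4)
    y = linearized(x)  # equals layer(x) while the parameters are still at theta0

    # Restore the original modules in place on a parent container.
    container = nn.Sequential(linearized)
    LinearizedModelWraper.unload_linearized_modules_(container)
    assert container[0] is layer
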
@@ -0,0 +1,122 @@
+import os
+from typing import Tuple, Union
+
+from huggingface_hub import hf_hub_download
+from peft import LoraConfig, PeftModel, get_peft_model
+from peft.tuners.lora import LoraLayer
+from safetensors.torch import load_file
+from transformers import CLIPVisionModel
+from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+
+from .linearized_model_utils import LinearizedModelWraper
+
+
+def get_file_path(peft_name, filename):
+    if os.path.isdir(peft_name):
+        # If peft_name is a local directory path
+        return os.path.join(peft_name, filename)
+    else:
+        # If peft_name is a Hugging Face model name
+        return hf_hub_download(peft_name, filename)
+
+
+def _get_submodules(model, key) -> Tuple:
+    """
+    Retrieves the parent module, target module, and target module name for a given key in a PyTorch model.
+    """
+    parent = model.get_submodule(".".join(key.split(".")[:-1]))
+    target_name = key.split(".")[-1]
+    target = model.get_submodule(key)
+    return parent, target, target_name
+
+
+def linearize_lora_model_(model):
+    """
+    Linearizes the LoraLayer modules in a PyTorch model according to the PETA paper.
+    """
+    for key, module in model.named_modules():
+        # if isinstance(module, LoraLayer) and isinstance(module, nn.Linear):
+        if isinstance(module, LoraLayer):
+            parent, target, target_name = _get_submodules(model, key)
+            setattr(parent, target_name, LinearizedModelWraper(target))
+    return model
+
+
+def load_fft_vision_model_hf(
+    model_name: str, return_vison_model=True
+) -> Union[CLIPVisionTransformer, CLIPVisionModel]:
+    """
+    Load a CLIP vision model from Hugging Face.
+
+    Args:
+        model_name (str): The name of the CLIP vision model to load from Hugging Face.
+        return_vison_model (bool, optional): If True, only the vision model (`CLIPVisionTransformer`) is returned. If False, the full `CLIPVisionModel` is returned. Defaults to True.
+
+    Returns:
+        Union[CLIPVisionTransformer, CLIPVisionModel]: The vision model.
+    """
+    model = CLIPVisionModel.from_pretrained(model_name)
+
+    if return_vison_model:
+        return model.vision_model
+    else:
+        return model
+
+
+def load_lora_vision_model_hf(
+    base_model_name: str,
+    peft_name: str,
+    merge_and_unload: bool = False,
+    return_vison_model=True,
+):
+    """
+    Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
+
+    This function loads a vision model and applies a LoRA adaptation to it. The model can be optionally merged and unloaded.
+
+    Args:
+        base_model_name (str): The name of the base vision model to load from Hugging Face.
+        peft_name (str): The name of the LoRA adaptation to apply to the base model.
+        merge_and_unload (bool, optional): If True, the LoRA adaptation is merged into the base model and the LoRA layers are removed. Defaults to False.
+        return_vison_model (bool, optional): If True, only the vision model (`CLIPVisionTransformer`) is returned. If False, the full `CLIPVisionModel` is returned. Defaults to True.
+
+    Returns:
+        PeftModel: The adapted vision model, optionally merged and unloaded.
+    """
+    model = CLIPVisionModel.from_pretrained(base_model_name)
+
+    # Load the Peft model
+    # note that we apply LoRA on type `CLIPVisionTransformer` instead of `CLIPVisionModel`
+    vision_model = model.vision_model
+    peft_model = PeftModel.from_pretrained(vision_model, peft_name, is_trainable=True)
+    if merge_and_unload:
+        vision_model = peft_model.merge_and_unload()
+    else:
+        vision_model = peft_model
+
+    # Return the vision model
+    if return_vison_model:
+        return vision_model
+    else:
+        model.vision_model = vision_model
+        return model
+
+
+def load_l_lora_vision_model_hf(base_model_name: str, peft_name: str):
+    """
+    Load a linearized L-LoRA model from a base model and a Peft model (Hugging Face).
+    """
+    base_model = CLIPVisionModel.from_pretrained(base_model_name).vision_model
+    peft_config = LoraConfig.from_pretrained(peft_name)
+    peft_config.inference_mode = False  # This is important: make the model trainable
+    model = get_peft_model(base_model, peft_config)
+    linearize_lora_model_(model)
+    for filename in ["linearized_adapter_model.safetensors"]:
+        path = get_file_path(peft_name, filename)
+        state_dict = load_file(path)
+        for name, param in state_dict.items():
+            model.get_parameter(name).data = param
+
+    return model
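
A usage sketch for the loaders above; the base checkpoint is a standard CLIP repository, while the adapter identifier is a placeholder, not a real repository:

    # Hypothetical identifiers; substitute real Hugging Face repos or local paths.
    base = "openai/clip-vit-base-patch32"
    adapter = "path/or/repo/of/lora-adapter"

    # LoRA-adapted vision transformer (a PeftModel), optionally merged back
    # into dense weights with merge_and_unload=True.
    vision_model = load_lora_vision_model_hf(base, adapter, merge_and_unload=True)

    # Linearized L-LoRA variant, with each LoraLayer wrapped by LinearizedModelWraper
    # and the linearized adapter weights loaded from the adapter location.
    l_lora_model = load_l_lora_vision_model_hf(base, adapter)
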
@@ -0,0 +1,16 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .model_utils.misc import find_all_linear_modules
+from .tokenizer_loader import load_config, load_tokenizer

File without changes