fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +3 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +64 -11
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +12 -1
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +12 -5
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +14 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -11,7 +11,7 @@ _target_: fusion_bench.programs.FabricModelFusionProgram
|
|
|
11
11
|
_recursive_: false
|
|
12
12
|
fast_dev_run: false # Run a single batch of data to test the model or method
|
|
13
13
|
# Run the script without actually running the experiment, use with `print_config=true`.
|
|
14
|
-
# You can also use `--cfg` or `-c` to
|
|
14
|
+
# You can also use `--cfg` or `-c` to show the configuration instead of running.
|
|
15
15
|
dry_run: false
|
|
16
16
|
print_config: true # Print the configuration to the console
|
|
17
17
|
merged_model_save_path: null # path to save the merged model, use "{log_dir}" to refer to the logger directory, for example `merged_model_save_path=\{log_dir\}/merged_model`
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
# Image Classification Dataset Configurations
|
|
2
|
+
|
|
3
|
+
This folder contains the dataset configuration for image classification tasks.
|
|
4
|
+
|
|
5
|
+
- Each dataset should have 'image' and 'label' columns.
|
|
6
|
+
- If a dataset has no test split, we will use the validation split as the test split and create the validation set from the training set.
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# The 14 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# eight tasks in the task arithmetic paper
|
|
6
|
+
- sun397
|
|
7
|
+
- stanford-cars
|
|
8
|
+
- resisc45
|
|
9
|
+
- eurosat
|
|
10
|
+
- svhn
|
|
11
|
+
- gtsrb
|
|
12
|
+
- mnist
|
|
13
|
+
- dtd
|
|
14
|
+
# additional 6 tasks in the TALL mask paper
|
|
15
|
+
- oxford_flowers102
|
|
16
|
+
- pcam
|
|
17
|
+
- fer2013
|
|
18
|
+
- oxford-iiit-pet
|
|
19
|
+
- stl10
|
|
20
|
+
- cifar100
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# eight tasks in the task arithmetic paper
|
|
6
|
+
- sun397
|
|
7
|
+
- stanford-cars
|
|
8
|
+
- resisc45
|
|
9
|
+
- eurosat
|
|
10
|
+
- svhn
|
|
11
|
+
- gtsrb
|
|
12
|
+
- mnist
|
|
13
|
+
- dtd
|
|
14
|
+
# additional 6 tasks in the TALL mask paper (TALL 14)
|
|
15
|
+
- oxford_flowers102
|
|
16
|
+
- pcam
|
|
17
|
+
- fer2013
|
|
18
|
+
- oxford-iiit-pet
|
|
19
|
+
- stl10
|
|
20
|
+
- cifar100
|
|
21
|
+
# additional 6 tasks in the TALL mask paper (TALL 20)
|
|
22
|
+
- cifar10
|
|
23
|
+
- food101
|
|
24
|
+
- fashion_mnist
|
|
25
|
+
- emnist_letters
|
|
26
|
+
- kmnist
|
|
27
|
+
- rendered-sst2
|
|
28
|
+
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# The 14 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# eight tasks in the task arithmetic paper
|
|
6
|
+
- sun397
|
|
7
|
+
- stanford-cars
|
|
8
|
+
- resisc45
|
|
9
|
+
- eurosat
|
|
10
|
+
- svhn
|
|
11
|
+
- gtsrb
|
|
12
|
+
- mnist
|
|
13
|
+
- dtd
|
|
14
|
+
# additional 6 tasks in the TALL mask paper
|
|
15
|
+
- oxford_flowers102
|
|
16
|
+
- pcam
|
|
17
|
+
- fer2013
|
|
18
|
+
- oxford-iiit-pet
|
|
19
|
+
- stl10
|
|
20
|
+
- cifar100
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# eight tasks in the task arithmetic paper
|
|
6
|
+
- sun397
|
|
7
|
+
- stanford-cars
|
|
8
|
+
- resisc45
|
|
9
|
+
- eurosat
|
|
10
|
+
- svhn
|
|
11
|
+
- gtsrb
|
|
12
|
+
- mnist
|
|
13
|
+
- dtd
|
|
14
|
+
# additional 6 tasks in the TALL mask paper (TALL 14)
|
|
15
|
+
- oxford_flowers102
|
|
16
|
+
- pcam
|
|
17
|
+
- fer2013
|
|
18
|
+
- oxford-iiit-pet
|
|
19
|
+
- stl10
|
|
20
|
+
- cifar100
|
|
21
|
+
# additional 6 tasks in the TALL mask paper (TALL 20)
|
|
22
|
+
- cifar10
|
|
23
|
+
- food101
|
|
24
|
+
- fashion_mnist
|
|
25
|
+
- emnist_letters
|
|
26
|
+
- kmnist
|
|
27
|
+
- rendered-sst2
|
|
28
|
+
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- loggers: tensorboard_logger
|
|
3
|
+
- strategy: llama_peft_fsdp
|
|
4
|
+
- _self_
|
|
5
|
+
|
|
6
|
+
_target_: lightning.Fabric
|
|
7
|
+
_recursive_: true
|
|
8
|
+
# Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
|
|
9
|
+
# The value applies per node.
|
|
10
|
+
devices: auto
|
|
11
|
+
# The hardware to run on. Possible choices are:
|
|
12
|
+
# ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
|
|
13
|
+
# for example: fabric.accelerator=cpu
|
|
14
|
+
accelerator: auto
|
|
15
|
+
# reference to the precision policy: https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision
|
|
16
|
+
precision: bf16-true
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# https://lightning.ai/docs/fabric/2.4.0/api/generated/lightning.fabric.strategies.DeepSpeedStrategy.html#deepspeedstrategy
|
|
2
|
+
_target_: lightning.fabric.strategies.DeepSpeedStrategy
|
|
3
|
+
|
|
4
|
+
accelerator: null
|
|
5
|
+
zero_optimization: true
|
|
6
|
+
stage: 2
|
|
7
|
+
offload_optimizer: false
|
|
8
|
+
offload_parameters: false
|
|
9
|
+
offload_params_device: "cpu"
|
|
10
|
+
offload_optimizer_device: "cpu"
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
_target_: lightning.fabric.strategies.FSDPStrategy
|
|
2
|
+
sharding_strategy: FULL_SHARD
|
|
3
|
+
state_dict_type: full # Save a single, consolidated checkpoint file
|
|
4
|
+
cpu_offload: false
|
|
5
|
+
auto_wrap_policy:
|
|
6
|
+
_target_: fusion_bench.mixins.lightning_fabric.get_size_based_auto_wrap_policy
|
|
7
|
+
activation_checkpointing_policy: ${.auto_wrap_policy}
|
|
8
|
+
# limit_all_gathers: true
|
|
9
|
+
|
|
@@ -11,7 +11,7 @@ _target_: fusion_bench.programs.FabricModelFusionProgram
|
|
|
11
11
|
_recursive_: false
|
|
12
12
|
fast_dev_run: false # Run a single batch of data to test the model or method
|
|
13
13
|
# Run the script without actually running the experiment, use with `print_config=true`.
|
|
14
|
-
# You can also use `--cfg` or `-c` to
|
|
14
|
+
# You can also use `--cfg` or `-c` to show the configuration instead of running.
|
|
15
15
|
dry_run: false
|
|
16
16
|
print_config: true # Print the configuration to the console
|
|
17
17
|
merged_model_save_path: null # path to save the merged model, use "{log_dir}" to refer to the logger directory, for example `merged_model_save_path=\{log_dir\}/merged_model`
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- hydra: default
|
|
3
|
+
- fabric: llama_fsdp
|
|
4
|
+
# --- Model, Method, Task ---
|
|
5
|
+
- method: lm_finetune/fullfinetune_sft.yaml
|
|
6
|
+
- modelpool: CausalLMPool/llama_alpaca_cleaned.yaml
|
|
7
|
+
- taskpool: dummy
|
|
8
|
+
- _self_
|
|
9
|
+
|
|
10
|
+
_target_: fusion_bench.programs.FabricModelFusionProgram
|
|
11
|
+
_recursive_: false
|
|
12
|
+
|
|
13
|
+
fast_dev_run: false # Run a single batch of data to test the model or method
|
|
14
|
+
# Run the script without actually running the experiment, use with `print_config=true`.
|
|
15
|
+
# You can also use `--cfg` or `-c` to show the configuration instead of running.
|
|
16
|
+
dry_run: false
|
|
17
|
+
print_config: true # Print the configuration to the console
|
|
18
|
+
report_save_path: null # path to save the result report
|
|
19
|
+
print_function_call: true # set to false if you don't want to print the details of instantiate calls
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
_target_: fusion_bench.method.BradleyTerryRewardModeling
|
|
2
|
+
_recursive_: False
|
|
3
|
+
|
|
4
|
+
optimizer:
|
|
5
|
+
_target_: torch.optim.AdamW
|
|
6
|
+
lr: 1e-5
|
|
7
|
+
weight_decay: 0.01
|
|
8
|
+
fused: null
|
|
9
|
+
|
|
10
|
+
lr_scheduler:
|
|
11
|
+
_target_: fusion_bench.optim.lr_scheduler.CosineDecayWithWarmup
|
|
12
|
+
T_max: _T_max_ # this will be replaced by the expected number of training steps
|
|
13
|
+
init_lr: 0
|
|
14
|
+
warmup_steps: 100
|
|
15
|
+
max_lr: ${..optimizer.lr}
|
|
16
|
+
min_lr: 1e-6
|
|
17
|
+
|
|
18
|
+
dataloader_kwargs:
|
|
19
|
+
# per-gpu batch size
|
|
20
|
+
batch_size: 1
|
|
21
|
+
num_workers: 0
|
|
22
|
+
pin_memory: True
|
|
23
|
+
|
|
24
|
+
# Training hyperparameters
|
|
25
|
+
# if max_epochs=-1, max_steps will be used to determine the number of training steps
|
|
26
|
+
max_epochs: 3
|
|
27
|
+
max_steps: -1
|
|
28
|
+
max_steps_per_epoch: -1
|
|
29
|
+
accumulate_grad_batches: 1
|
|
30
|
+
lr_scheduler_interval: step
|
|
31
|
+
lr_scheduler_frequency: 1
|
|
32
|
+
# Checkpointing may be done by epoch or step, and at the end of training
|
|
33
|
+
# `checkpoint_save_interval` can be 'epoch' or 'step'
|
|
34
|
+
checkpoint_save_interval: epoch
|
|
35
|
+
checkpoint_save_frequency: 1
|
|
36
|
+
# Whether to use gradient clipping, and if so, the value and algorithm
|
|
37
|
+
gradient_clip_val: null
|
|
38
|
+
gradient_clip_algorithm: norm
|
|
39
|
+
save_optimizer_state: false
|
|
40
|
+
# save_full_model must be true when using shared FSDP
|
|
41
|
+
save_full_model: true
|
|
42
|
+
# save_ckpt_type can be 'hf' or 'lightning'
|
|
43
|
+
save_ckpt_type: lightning
|
|
44
|
+
# Path to checkpoint to load from, used for resuming training
|
|
45
|
+
ckpt_path: null
|
|
46
|
+
max_length: 4096
|
|
47
|
+
fix_token_embedding: true
|
|
@@ -3,14 +3,17 @@ _recursive_: False
|
|
|
3
3
|
|
|
4
4
|
optimizer:
|
|
5
5
|
_target_: torch.optim.AdamW
|
|
6
|
-
|
|
6
|
+
lr: 1e-5
|
|
7
7
|
weight_decay: 0.01
|
|
8
|
-
|
|
8
|
+
fused: null
|
|
9
9
|
|
|
10
10
|
lr_scheduler:
|
|
11
|
-
_target_:
|
|
11
|
+
_target_: fusion_bench.optim.lr_scheduler.CosineDecayWithWarmup
|
|
12
12
|
T_max: _T_max_ # this will be replaced by the expected number of training steps
|
|
13
|
-
|
|
13
|
+
init_lr: 0
|
|
14
|
+
warmup_steps: 100
|
|
15
|
+
max_lr: ${..optimizer.lr}
|
|
16
|
+
min_lr: 1e-6
|
|
14
17
|
|
|
15
18
|
dataloader_kwargs:
|
|
16
19
|
# per-gpu batch size
|
|
@@ -36,5 +39,9 @@ gradient_clip_algorithm: norm
|
|
|
36
39
|
save_optimizer_state: false
|
|
37
40
|
# save_full_model must be true when using shared FSDP
|
|
38
41
|
save_full_model: true
|
|
42
|
+
# save_ckpt_type can be 'hf' or 'lightning'
|
|
43
|
+
save_ckpt_type: lightning
|
|
39
44
|
# Path to checkpoint to load from, used for resuming training
|
|
40
45
|
ckpt_path: null
|
|
46
|
+
max_length: 4096
|
|
47
|
+
fix_token_embedding: true
|
|
@@ -3,9 +3,9 @@ _recursive_: False
|
|
|
3
3
|
|
|
4
4
|
optimizer:
|
|
5
5
|
_target_: torch.optim.AdamW
|
|
6
|
-
|
|
6
|
+
lr: 1e-4
|
|
7
7
|
weight_decay: 0.01
|
|
8
|
-
|
|
8
|
+
fused: null
|
|
9
9
|
|
|
10
10
|
lr_scheduler:
|
|
11
11
|
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
|
|
@@ -56,6 +56,8 @@ gradient_clip_algorithm: norm
|
|
|
56
56
|
save_optimizer_state: false
|
|
57
57
|
# save_full_model must be true when using shared FSDP
|
|
58
58
|
save_full_model: false
|
|
59
|
+
# save_ckpt_type can be 'peft' or 'lightning'
|
|
60
|
+
save_ckpt_type: lightning
|
|
59
61
|
# Path to checkpoint to load from, used for resuming training
|
|
60
62
|
ckpt_path: null
|
|
61
63
|
max_length: 4096
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# this option can be "clip_task_wise_adamerging"
|
|
2
|
+
name: clip_layer_wise_adamerging_surgery
|
|
3
|
+
# this weights can be a list of float, or a string that points to a *.np, *.pt file containing the weights
|
|
4
|
+
# if weights is specified, skip the test-time adaptation training
|
|
5
|
+
weights: null
|
|
6
|
+
# learning rate
|
|
7
|
+
optimizer: adam
|
|
8
|
+
lr: 1e-3
|
|
9
|
+
init_values: 0.3
|
|
10
|
+
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
|
|
11
|
+
clamp_weights: false
|
|
12
|
+
# arguments of `functional_call`
|
|
13
|
+
tie_weights: true
|
|
14
|
+
strict: false
|
|
15
|
+
# this is overridden by `fabric.devices` if launched from the `fusion_bench` CLI.
|
|
16
|
+
devices: 1
|
|
17
|
+
batch_size: 16
|
|
18
|
+
num_workers: 8
|
|
19
|
+
max_steps: 1000
|
|
20
|
+
fast_dev_run: ${fast_dev_run}
|
|
21
|
+
# the path for saving the merging weights
|
|
22
|
+
save_merging_weights: 'merging_weights.pt'
|
|
23
|
+
cache_dir: outputs
|
|
24
|
+
|
|
25
|
+
# parameters of Surgery
|
|
26
|
+
eval_iterations: 200
|
|
27
|
+
surgery_steps: 1000
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
This folder contains the configuration for the CLIP-ViT models (managed by `fusion_bench.modelpool.CLIPVisionModelPool`).
|
|
2
|
+
|
|
3
|
+
## Expected Configuration
|
|
4
|
+
|
|
5
|
+
### Detailed Configuration
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
```yaml
|
|
9
|
+
${name_of_model}:
|
|
10
|
+
_target_: ${function_to_load_model}
|
|
11
|
+
... # arguments to pass to the function
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
For example, to load the pre-trained CLIP-ViT-B/16 model, you can use the following configuration:
|
|
15
|
+
|
|
16
|
+
```yaml
|
|
17
|
+
_pretrained_: # `_pretrained_` is a special key in FusionBench that indicates the model is pre-trained
|
|
18
|
+
_target_: transformers.CLIPVisionModel.from_pretrained
|
|
19
|
+
pretrained_model_name_or_path: openai/clip-vit-base-patch16
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
In this case, calling `modelpool.load_model("_pretrained_")` will return a `transformers.CLIPVisionModel` instance, which is equivalent to call `transformers.CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")`.
|
|
23
|
+
|
|
24
|
+
The detailed configuration is more flexible and can be used when you need to pass additional arguments to the `from_pretrained` function or call custom functions to load and preprocess the model.
|
|
25
|
+
|
|
26
|
+
### Simplified Configuration
|
|
27
|
+
|
|
28
|
+
```yaml
|
|
29
|
+
${name_of_model}: ${pretrained_model_name_or_path}
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
This is a simplified configuration that is equivalent to the detailed configuration.
|
|
33
|
+
|
|
34
|
+
For example, to load the pre-trained CLIP-ViT-B/16 model, you can use the following configuration:
|
|
35
|
+
|
|
36
|
+
```yaml
|
|
37
|
+
_pretrained_: openai/clip-vit-base-patch16
|
|
38
|
+
```
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# The 14 task used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# pre-trained model
|
|
6
|
+
- clip-vit-base-patch16
|
|
7
|
+
# eight tasks in the task arithmetic paper
|
|
8
|
+
- clip-vit-base-patch16_sun397
|
|
9
|
+
- clip-vit-base-patch16_stanford-cars
|
|
10
|
+
- clip-vit-base-patch16_resisc45
|
|
11
|
+
- clip-vit-base-patch16_eurosat
|
|
12
|
+
- clip-vit-base-patch16_svhn
|
|
13
|
+
- clip-vit-base-patch16_gtsrb
|
|
14
|
+
- clip-vit-base-patch16_mnist
|
|
15
|
+
- clip-vit-base-patch16_dtd
|
|
16
|
+
# additional 6 tasks in the TALL mask paper
|
|
17
|
+
- clip-vit-base-patch16_oxford_flowers102
|
|
18
|
+
- clip-vit-base-patch16_pcam
|
|
19
|
+
- clip-vit-base-patch16_fer2013
|
|
20
|
+
- clip-vit-base-patch16_oxford-iiit-pet
|
|
21
|
+
- clip-vit-base-patch16_stl10
|
|
22
|
+
- clip-vit-base-patch16_cifar100
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# The 20 task used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# pre-trained model
|
|
6
|
+
- clip-vit-base-patch16
|
|
7
|
+
# eight tasks in the task arithmetic paper
|
|
8
|
+
- clip-vit-base-patch16_sun397
|
|
9
|
+
- clip-vit-base-patch16_stanford-cars
|
|
10
|
+
- clip-vit-base-patch16_resisc45
|
|
11
|
+
- clip-vit-base-patch16_eurosat
|
|
12
|
+
- clip-vit-base-patch16_svhn
|
|
13
|
+
- clip-vit-base-patch16_gtsrb
|
|
14
|
+
- clip-vit-base-patch16_mnist
|
|
15
|
+
- clip-vit-base-patch16_dtd
|
|
16
|
+
# additional 6 tasks in the TALL mask paper (TALL 14)
|
|
17
|
+
- clip-vit-base-patch16_oxford_flowers102
|
|
18
|
+
- clip-vit-base-patch16_pcam
|
|
19
|
+
- clip-vit-base-patch16_fer2013
|
|
20
|
+
- clip-vit-base-patch16_oxford-iiit-pet
|
|
21
|
+
- clip-vit-base-patch16_stl10
|
|
22
|
+
- clip-vit-base-patch16_cifar100
|
|
23
|
+
# additional 6 tasks in the TALL mask paper (TALL 20)
|
|
24
|
+
- clip-vit-base-patch16_cifar10
|
|
25
|
+
- clip-vit-base-patch16_food101
|
|
26
|
+
- clip-vit-base-patch16_fashion_mnist
|
|
27
|
+
- clip-vit-base-patch16_emnist_letters
|
|
28
|
+
- clip-vit-base-patch16_kmnist
|
|
29
|
+
- clip-vit-base-patch16_rendered-sst2
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
cifar10: tanganke/clip-vit-base-patch16_cifar10
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
cifar100: tanganke/clip-vit-base-patch16_cifar100
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
emnist_letters: tanganke/clip-vit-base-patch16_emnist_letters
|