fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +3 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +64 -11
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +12 -1
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +12 -5
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +14 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml
ADDED
@@ -0,0 +1,21 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+
+pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
+
+models:
+  _pretrained_:
+    _target_: transformers.AutoModelForCausalLM.from_pretrained
+    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
+    torch_dtype: bfloat16
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
+
+train_datasets:
+  alpaca-cleaned:
+    _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
+    tokenizer: ${...tokenizer}
+    path: "yahma/alpaca-cleaned"
+    split: train
+    cache_path: null
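These CausalLMPool configs are ordinary Hydra/OmegaConf files: each `_target_` names a callable, the sibling keys become its arguments, and the relative interpolations (`${..pretrained_model_name_or_path}`, `${...tokenizer}`) climb back up to the top-level keys. A minimal sketch of that mechanism using only the `tokenizer` node (the full pool wiring inside fusion_bench may differ):

```python
# Minimal sketch: how a Hydra-style node such as `tokenizer` above is built.
# `_target_` names the callable, the remaining keys become keyword arguments,
# and `${..pretrained_model_name_or_path}` resolves one level up, at the root.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
"""
)

# Equivalent to AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer = instantiate(cfg.tokenizer)
```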
fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml
ADDED
@@ -0,0 +1,21 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+
+pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
+
+models:
+  _pretrained_:
+    _target_: transformers.AutoModelForCausalLM.from_pretrained
+    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
+    torch_dtype: bfloat16
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
+
+train_datasets:
+  codealpaca:
+    _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
+    tokenizer: ${...tokenizer}
+    path: sahil2801/CodeAlpaca-20k
+    split: train
+    cache_path: null
fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml
ADDED
@@ -0,0 +1,19 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+
+pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
+
+models:
+  _pretrained_:
+    _target_: transformers.AutoModelForCausalLM.from_pretrained
+    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
+    torch_dtype: bfloat16
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
+
+train_datasets:
+  metamathqa:
+    _target_: fusion_bench.dataset.llama.metamathqa.load_tokenized_metamathqa
+    tokenizer: ${...tokenizer}
+    cache_path: null
fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml
ADDED
@@ -0,0 +1,18 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+
+pretrained_model_name_or_path: meta-llama/Llama-3-1B-Instruct
+
+models:
+  _pretrained_:
+    _target_: transformers.AutoModelForCausalLM.from_pretrained
+    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
+    torch_dtype: bfloat16
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
+
+train_datasets:
+  ultrachat-200k:
+    _target_: fusion_bench.dataset.llama.ultrachat.load_tokenized_ultrachat_200k
+    tokenizer: ${...tokenizer}
fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml
ADDED
@@ -0,0 +1,23 @@
+_target_: fusion_bench.modelpool.SeqenceClassificationModelPool
+
+pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
+
+models:
+  _pretrained_:
+    _target_: fusion_bench.modelpool.seq_classification_lm.create_reward_model_from_pretrained
+    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
+    torch_dtype: bfloat16
+    use_flash_attention_2: true
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
+  pad_token: <|end_of_text|> # do not use eos token (<|eos_id|>) as padding token because it is used as the end of each content
+
+train_datasets:
+  preference_700k:
+    _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_rlhf
+    tokenizer: ${...tokenizer}
+    path: hendrydong/preference_700K
+    split: train
+    cache_path: null
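The `pad_token` comment matters for reward modeling: the score is typically read at the last non-padding position, so padding with the turn-terminating eos token would blur where the content ends. A small illustration of the intended setup (model name from the config above; the scoring index is an assumption about the training code):

```python
# Sketch: give the Llama 3 tokenizer a pad token that never occurs inside a
# chat transcript, so attention_mask cleanly separates content from padding.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.pad_token = "<|end_of_text|>"  # per the config comment; not the eos used between turns

batch = tokenizer(
    ["a short chosen response", "a noticeably longer rejected response to pad against"],
    padding=True,
    return_tensors="pt",
)
# A reward head can now score each sequence at its final real token,
# e.g. index = batch["attention_mask"].sum(dim=1) - 1.
print(batch["input_ids"].shape, batch["attention_mask"][0])
```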
fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml
ADDED
@@ -0,0 +1,14 @@
+_target_: fusion_bench.modelpool.SeqenceClassificationModelPool
+
+pretrained_model_name_or_path: fusion-bench/Llama-3.2-1B-Instruct_Bradly-Terry-RM_Preference-700k
+
+models:
+  _pretrained_:
+    _target_: transformers.AutoModelForSequenceClassification.from_pretrained
+    pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
+    torch_dtype: bfloat16
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
+  pad_token: <|end_of_text|> # do not use eos token (<|eos_id|>) as padding token because it is used as the end of each content
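Since this pool loads the reward model as a plain sequence-classification checkpoint, it can also be used directly with transformers. A hedged usage sketch (the scalar-logit convention, i.e. num_labels=1, is an assumption based on the Bradley-Terry setup):

```python
# Sketch: scoring text with the sequence-classification reward model above.
# Assumes the checkpoint was saved with num_labels=1 (a scalar reward score).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "fusion-bench/Llama-3.2-1B-Instruct_Bradly-Terry-RM_Preference-700k"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name, torch_dtype=torch.bfloat16)

inputs = tokenizer("Q: How do I sort a list in Python?\nA: Use sorted(xs).", return_tensors="pt")
with torch.no_grad():
    reward = model(**inputs).logits[0, 0].item()  # one scalar per sequence
print(reward)
```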
fusion_bench_config/nyuv2_config.yaml
@@ -1,13 +1,17 @@
 defaults:
   - hydra: default
+  - fabric: auto
   - modelpool: nyuv2_modelpool
   - method: simple_average
   - taskpool: nyuv2_taskpool
   - _self_
+
+_target_: fusion_bench.programs.FabricModelFusionProgram
+_recursive_: false
+
 fast_dev_run: false # Run a single batch of data to test the model or method
 use_lightning: true # Use the fabric to run the experiment
 print_config: true # Print the configuration to the console
 save_report: false # path to save the result report
-fabric: null
 trainer:
   devices: 1
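The notable addition here is `_target_: fusion_bench.programs.FabricModelFusionProgram` with `_recursive_: false`, which tells Hydra to hand nested config nodes to the program unresolved instead of instantiating them eagerly. A standalone sketch of what that flag changes (the targets below are stand-ins from the standard library, not fusion_bench classes):

```python
# Sketch: with _recursive_: false, nested `_target_` nodes stay as DictConfig,
# so the outer object can decide when to build them (e.g. one model at a time).
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "collections.OrderedDict",
        "_recursive_": False,
        "inner": {"_target_": "collections.Counter"},
    }
)
obj = instantiate(cfg)
print(type(obj["inner"]))  # DictConfig: not instantiated; flip the flag to get a Counter
```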
fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml
ADDED
@@ -0,0 +1,27 @@
+type: clip_vit_classification
+name: clip-vit-robustness_clean
+# corrption can be one of:
+#   contrast, gaussian_noise, impulse_noise, jpeg_compression, motion_blur, pixelate, spatter
+corruption: ${corruption}
+dataset_type: huggingface_image_classification
+tasks:
+  - name: stanford_cars
+    dataset:
+      name: tanganke/stanford_cars
+      split: ${taskpool.corruption}
+  - name: eurosat
+    dataset:
+      name: tanganke/eurosat
+      split: ${taskpool.corruption}
+  - name: resisc45
+    dataset:
+      name: tanganke/resisc45
+      split: ${taskpool.corruption}
+  - name: gtsrb
+    dataset:
+      name: tanganke/gtsrb
+      split: ${taskpool.corruption}
+clip_model: openai/clip-vit-base-patch32
+batch_size: 128
+num_workers: 16
+fast_dev_run: ${fast_dev_run}
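The `${corruption}` / `${taskpool.corruption}` indirection means a single top-level override (e.g. `corruption=gaussian_noise` on the command line) switches the dataset split of every task at once. A minimal OmegaConf reproduction of that fan-out:

```python
# Sketch: one top-level key fans out to every task's dataset split.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "corruption": "gaussian_noise",  # set via a CLI override in practice
        "taskpool": {
            "corruption": "${corruption}",
            "tasks": [{"name": "gtsrb", "dataset": {"split": "${taskpool.corruption}"}}],
        },
    }
)
print(cfg.taskpool.tasks[0].dataset.split)  # -> gaussian_noise
```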
fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml
ADDED
@@ -0,0 +1,19 @@
+defaults:
+  - CLIPVisionModelTaskPool@: _template
+  - /dataset/image_classification/test@test_datasets:
+      # eight tasks in the task arithmetic paper
+      - sun397
+      - stanford-cars
+      - resisc45
+      - eurosat
+      - svhn
+      - gtsrb
+      - mnist
+      - dtd
+      # additional 6 tasks in the TALL mask paper (TALL 14)
+      - oxford_flowers102
+      - pcam
+      - fer2013
+      - oxford-iiit-pet
+      - stl10
+      - cifar100
fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml
ADDED
@@ -0,0 +1,26 @@
+defaults:
+  - CLIPVisionModelTaskPool@: _template
+  - /dataset/image_classification/test@test_datasets:
+      # eight tasks in the task arithmetic paper
+      - sun397
+      - stanford-cars
+      - resisc45
+      - eurosat
+      - svhn
+      - gtsrb
+      - mnist
+      - dtd
+      # additional 6 tasks in the TALL mask paper (TALL 14)
+      - oxford_flowers102
+      - pcam
+      - fer2013
+      - oxford-iiit-pet
+      - stl10
+      - cifar100
+      # additional 6 tasks in the TALL mask paper (TALL 20)
+      - cifar10
+      - food101
+      - fashion_mnist
+      - emnist_letters
+      - kmnist
+      - rendered-sst2
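The `@test_datasets` package directive in these defaults lists relocates each selected dataset config under a `test_datasets` key. A sketch of composing one of them with Hydra's compose API (the relative `config_path` is an assumption about where this is run from, and exact key names depend on the dataset configs):

```python
# Sketch: compose the TALL20 taskpool config and inspect the datasets that
# the `@test_datasets` package directive relocated under cfg.test_datasets.
from hydra import compose, initialize

with initialize(version_base=None, config_path="fusion_bench_config"):
    cfg = compose(
        config_name="taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20"
    )
print(list(cfg.test_datasets.keys()))  # the entries selected in the defaults list
```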
fusion_bench_config/taskpool/reward_model_evaluation.yaml
ADDED
@@ -0,0 +1,18 @@
+_target_: fusion_bench.taskpool.llama.reward_model.RewardModelEvaluationTaskPool
+
+test_datasets:
+  preference_700k:
+    _target_: fusion_bench.dataset.llama.preference_700k.load_tokenized_preference_700k_for_rlhf
+    tokenizer: ${...tokenizer}
+    path: hendrydong/preference_700K
+    split: train
+    cache_path: null
+
+dataloader_kwargs:
+  shuffle: False
+  batch_size: 16
+
+tokenizer: ${..modelpool.tokenizer}
+
+max_num_samples: 1000
+seed: 42
fusion_bench_config/llama_weighted_average.yaml
DELETED
@@ -1,26 +0,0 @@
-defaults:
-  - example_config
-  - override method: weighted_average_for_llama
-  - override modelpool: llama_for_causallm
-  - _self_
-modelpool:
-  models:
-    # the pre-trained model (base model) is optional
-    # if not provided, the first model will be used as the base model
-    - name: _pretrained_
-      path: meta-llama/Meta-Llama-3-8B
-    - name: expert_1
-      path: meta-llama/Meta-Llama-3-8B
-    - name: expert_2
-      path: meta-llama/Meta-Llama-3-8B-Instruct
-method:
-  normalize: true # if true, the weights will be normalized before merging
-  weights: # List of weights for each model
-    - 0.5
-    - 0.5
-  # if true, only the backbone of the model will be merged and the head will be keeped as the pre-trained model (if the pre-trained model is provided, otherwise the head of the first model will be used)
-  # if false, the whole model will be merged
-  backbone_only: true
-merged_model_save_path: null
-save_tokenizer: true
-push_to_hub: false
{fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE
File without changes
{fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL
File without changes
{fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt
File without changes
{fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt
File without changes