fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +3 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +64 -11
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +12 -1
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +12 -5
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +14 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0

fusion_bench/compat/method/__init__.py
CHANGED

@@ -21,6 +21,7 @@ class AlgorithmFactory:
         "clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
         "clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
         "singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
+        "clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
         # plug-and-play model merging methods
         "clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
         "clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",

fusion_bench/compat/method/base_algorithm.py
CHANGED

@@ -1,8 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from omegaconf import DictConfig
 
+if TYPE_CHECKING:
+    from fusion_bench.programs.base_program import BaseHydraProgram
+
 __all__ = ["ModelFusionAlgorithm"]
 
 
@@ -18,6 +21,9 @@ class ModelFusionAlgorithm(ABC):
         config (DictConfig): Configuration for the algorithm.
     """
 
+    _program: "BaseHydraProgram" = None
+    """A reference to the program that is running the algorithm."""
+
     def __init__(self, algorithm_config: Optional[DictConfig] = None):
         """
         Initialize the model fusion algorithm with the given configuration.

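The new `_program` attribute gives every compat algorithm a back-reference to the program driving the run. Purely as an illustrative sketch, here is how a program might wire it up before executing an algorithm; the surrounding program code and the `run` call are assumptions, not part of this diff:

    # Illustrative sketch only: the program hands itself to the algorithm before
    # running it, so the algorithm can reach program-level facilities via self._program.
    algorithm._program = program
    merged_model = algorithm.run(modelpool)
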
fusion_bench/compat/modelpool/__init__.py
CHANGED

@@ -22,7 +22,7 @@ class ModelPoolFactory:
     """
 
     _modelpool = {
-        "NYUv2ModelPool": ".nyuv2_modelpool.NYUv2ModelPool",
+        "NYUv2ModelPool": "fusion_bench.modelpool.nyuv2_modelpool.NYUv2ModelPool",
         "huggingface_clip_vision": HuggingFaceClipVisionPool,
         "HF_GPT2ForSequenceClassification": GPT2ForSequenceClassificationPool,
         "AutoModelPool": ".huggingface_automodel.AutoModelPool",

fusion_bench/compat/taskpool/__init__.py
CHANGED

@@ -20,7 +20,7 @@ class TaskPoolFactory:
         "dummy": DummyTaskPool,
         "clip_vit_classification": ".clip_image_classification.CLIPImageClassificationTaskPool",
         "FlanT5GLUETextGenerationTaskPool": ".flan_t5_glue_text_generation.FlanT5GLUETextGenerationTaskPool",
-        "NYUv2TaskPool": ".nyuv2_taskpool.NYUv2TaskPool",
+        "NYUv2TaskPool": "fusion_bench.taskpool.nyuv2_taskpool.NYUv2TaskPool",
     }
 
     @staticmethod

fusion_bench/dataset/arc_agi/arc.py
CHANGED

@@ -152,6 +152,11 @@ class Task:
         tasks = []
         for test_data in data["test"]:
             task = cls.deserialize(
+                {
+                    "train": data["train"],
+                    "test": [test_data],
+                    "name": data.get("name", ""),
+                },
                 {
                     "train": data["train"],
                     "test": [test_data],

fusion_bench/dataset/arc_agi/preprocess.py
CHANGED

@@ -143,7 +143,7 @@ def format_and_filter(
     data = {
         "input_ids": prompt_tokens + output_tokens,
         "attention_mask": [1] * len(prompt_tokens) + [1] * len(output_tokens),
-        "labels":
+        "labels": prompt_tokens + output_tokens,
         "task_id": task_id,
         "num_prompt_tokens": len(prompt_tokens),
         "num_output_tokens": len(output_tokens),

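The fixed `labels` entry simply mirrors `input_ids`, so the loss covers prompt and response alike. Purely as an illustration (not code from the package), a variant that restricts the loss to the response tokens would mask the prompt positions with -100, the default `ignore_index` of `torch.nn.CrossEntropyLoss`:

    # Illustrative sketch only: mask prompt positions so the loss ignores them.
    data = {
        "input_ids": prompt_tokens + output_tokens,
        "attention_mask": [1] * (len(prompt_tokens) + len(output_tokens)),
        "labels": [-100] * len(prompt_tokens) + output_tokens,
        "task_id": task_id,
    }
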
fusion_bench/dataset/clip_dataset.py
CHANGED

@@ -65,4 +65,7 @@ class CLIPDataset(torch.utils.data.Dataset):
         else:
             # if processor is None, return the raw image directly
             inputs = image
+        # convert boolean label to int, this is for the case when the label is a binary classification task
+        if isinstance(item["label"], bool):
+            item["label"] = 1 if item["label"] else 0
         return inputs, item["label"]

fusion_bench/dataset/fer2013.py
ADDED

@@ -0,0 +1,12 @@
+from datasets import load_dataset
+
+
+def load_fer2013(path: str = "clip-benchmark/wds_fer2013", split: str = "train"):
+    dataset = load_dataset(path, split=split)
+    dataset = dataset.remove_columns(["__key__", "__url__"])
+    dataset = dataset.rename_columns({"jpg": "image", "cls": "label"})
+    return dataset
+
+if __name__ == "__main__":
+    dataset = load_fer2013(split="test")
+    print(dataset)

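As an illustration of how the new FER-2013 loader composes with the CLIPDataset change above; the constructor signature and the processor checkpoint are assumptions, not part of this diff:

    from transformers import CLIPProcessor
    from fusion_bench.dataset.clip_dataset import CLIPDataset
    from fusion_bench.dataset.fer2013 import load_fer2013

    # Hypothetical usage sketch: CLIPDataset is assumed to take (dataset, processor)
    # and to yield (processed_image, label); boolean labels are now cast to int.
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    test_set = CLIPDataset(load_fer2013(split="test"), processor)
    inputs, label = test_set[0]
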
fusion_bench/dataset/llama/__init__.py
CHANGED

@@ -0,0 +1 @@
+from . import collate

fusion_bench/dataset/llama/alpaca.py
CHANGED

@@ -1,16 +1,100 @@
 import logging
 import os
+import warnings
 from typing import Any, Dict, List, Optional
 
 from datasets import Dataset, load_dataset, load_from_disk
+from lightning.fabric.utilities import rank_zero_only
+from tqdm.auto import tqdm
 from transformers import PreTrainedTokenizer
 
 import fusion_bench
+from fusion_bench.utils import timeit_context
 
 log = logging.getLogger(__name__)
 
 
-def
+def convert_alpaca_to_conversation(alpaca_data: List[Dict[str, str]]):
+    """
+    Convert Alpaca format data to conversation format.
+
+    Args:
+        alpaca_data (list): List of dictionaries in Alpaca format with
+            'instruction', 'input', and 'output' keys
+
+    Returns:
+        list: List of conversations in ChatML format
+    """
+    conversations = []
+
+    for item in tqdm(
+        alpaca_data,
+        "Converting Alpaca to conversations",
+        disable=not rank_zero_only.rank == 0,
+    ):
+        # Skip if required fields are missing
+        if not item.get("instruction") or not item.get("output"):
+            continue
+
+        conversation = []
+
+        # Create user message
+        user_content = item["instruction"]
+        if item.get("input") and item["input"].strip():
+            user_content += f"\n\n{item['input']}"
+
+        conversation.append({"role": "user", "content": user_content})
+
+        # Create assistant message
+        conversation.append({"role": "assistant", "content": item["output"]})
+
+        conversations.append(conversation)
+
+    return conversations
+
+
+def load_tokenized_alpaca_dataset(
+    tokenizer: PreTrainedTokenizer,
+    path: str = "yahma/alpaca-cleaned",
+    split: str = "train",
+    cache_path: Optional[str] = None,
+):
+    """
+    Load and tokenized Alpaca dataset and Alpaca-like dataset.
+
+    Args:
+        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the dataset.
+        path (str, optional): The path to the Alpaca dataset. Defaults to "yahma/alpaca-cleaned".
+        split (str, optional): The dataset split to load (e.g., "train", "test"). Defaults to "train".
+        cache_path (Optional[str], optional): The path to cache the tokenized dataset. If provided and the cache exists,
+            the dataset will be loaded from the cache. Defaults to None.
+
+    Returns:
+        Dataset: The tokenized dataset.
+    """
+    if cache_path is not None and os.path.exists(cache_path):
+        dataset = load_from_disk(cache_path)
+        if split is not None and split in dataset:
+            return dataset[split]
+        else:
+            return dataset
+
+    dataset = load_dataset(path, split=split)
+
+    alpaca_data = dataset.to_list()
+    conversations = convert_alpaca_to_conversation(alpaca_data)
+    with timeit_context("Tokenizing dataset"):
+        tokenized_dataset = tokenizer.apply_chat_template(
+            conversations, return_dict=True
+        )
+        tokenized_dataset = Dataset.from_dict(tokenized_dataset)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
+        tokenized_dataset.save_to_disk(cache_path)
+    return tokenized_dataset
+
+
+def _tokenize_alpaca_dataset_with_template(
     dataset: Dataset,
     tokenizer: PreTrainedTokenizer,
     max_length: int = 2048,

@@ -32,6 +116,10 @@ def tokenize_alpaca_dataset(
     Returns:
         Tokenized dataset
     """
+    warnings.warn(
+        "This function is deprecated. Use `apply_chat_template` from `transformers` instead.",
+        DeprecationWarning,
+    )
 
     def prepare_samples(samples: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
         # Format prompts based on whether input field exists

@@ -115,7 +203,7 @@ def tokenize_alpaca_dataset(
     return tokenized_dataset
 
 
-def
+def load_tokenized_alpaca_dataset_from_json_with_prompt(
     data_files: str,
     tokenizer: PreTrainedTokenizer,
     max_length: int,

@@ -138,5 +226,7 @@ def load_tokenized_alpaca_dataset_from_json(
     dataset = load_dataset("json", data_files=data_files)
     if split is not None:
         dataset = dataset[split]
-    dataset =
+    dataset = _tokenize_alpaca_dataset_with_template(
+        dataset, tokenizer, max_length=max_length
+    )
     return dataset

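A minimal usage sketch for the new loader; the checkpoint name is taken from the UltraChat example further down, and any chat-template-capable tokenizer should work:

    from transformers import AutoTokenizer
    from fusion_bench.dataset.llama.alpaca import load_tokenized_alpaca_dataset

    # Sketch: the returned Dataset holds the fields produced by
    # tokenizer.apply_chat_template(..., return_dict=True), i.e. input_ids
    # and attention_mask, one row per Alpaca example.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    train_set = load_tokenized_alpaca_dataset(
        tokenizer,
        path="yahma/alpaca-cleaned",
        split="train",
        cache_path=None,  # point this at a directory to cache the tokenized dataset
    )
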
fusion_bench/dataset/llama/collate.py
CHANGED

@@ -7,7 +7,7 @@ from torch.nn.utils.rnn import pad_sequence
 
 def padded_collate_sft(
     batch: List[Dict[str, List[int]]],
-
+    pad_token_id: int = 0,
     input_ids_key: str = "input_ids",
     attention_mask_key: Optional[str] = "attention_mask",
     labels_key: Optional[str] = "labels",

@@ -28,7 +28,7 @@ def padded_collate_sft(
     input_ids = pad_sequence(
         [torch.tensor(x[input_ids_key]) for x in batch],
         batch_first=True,
-        padding_value=
+        padding_value=pad_token_id,
     )
     if attention_mask_key is not None and attention_mask_key in batch[0]:
         attention_mask = pad_sequence(

@@ -38,6 +38,12 @@ def padded_collate_sft(
         )
     else:
         attention_mask = None
+
+    for i, item in enumerate(batch):
+        # if labels_key not in item, copy input_ids to labels_key
+        if labels_key not in item:
+            item[labels_key] = item[input_ids_key]
+
     labels = pad_sequence(
         [torch.tensor(x[labels_key]) for x in batch],
         batch_first=True,

@@ -58,3 +64,57 @@ def padded_collate_sft(
             collated_batch[key] = [x[key] for x in batch]
 
     return collated_batch
+
+
+def bradley_terry_rm_collate(
+    batch: List[Dict[str, List[int]]],
+    pad_token_id: int = 0,
+    padding_side="right",
+):
+    """
+    Collate function for Bradley-Terry reward modeling.
+
+    Args:
+        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
+        pad_token_id (int): Padding index for input ids. Defaults to 0.
+
+    Returns:
+        Dict[str, torch.Tensor]: Collated input and label tensors. The first half of the batch is the winner, and the second half is the loser.
+    """
+    converted_batch = []
+    for item in batch:
+        new_item = {
+            "input_ids": item["chosen_input_ids"],
+            "attention_mask": item["chosen_attention_mask"],
+        }
+        converted_batch.append(new_item)
+    for item in batch:
+        new_item = {
+            "input_ids": item["rejected_input_ids"],
+            "attention_mask": item["rejected_attention_mask"],
+        }
+        converted_batch.append(new_item)
+
+    input_ids = pad_sequence(
+        [torch.tensor(x["input_ids"]) for x in converted_batch],
+        batch_first=True,
+        padding_value=pad_token_id,
+        padding_side=padding_side,
+    )
+    attention_mask = pad_sequence(
+        [torch.tensor(x["attention_mask"]) for x in converted_batch],
+        batch_first=True,
+        padding_value=0,
+        padding_side=padding_side,
+    )
+
+    collated_batch = {"input_ids": input_ids, "attention_mask": attention_mask}
+    for key in batch[0]:
+        if key not in [
+            "chosen_input_ids",
+            "chosen_attention_mask",
+            "rejected_input_ids",
+            "rejected_attention_mask",
+        ]:
+            collated_batch[key] = [x[key] for x in batch]
+    return collated_batch

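To make the winner-first/loser-second contract of `bradley_terry_rm_collate` concrete, here is a toy sketch with made-up token ids; note that the `padding_side` keyword of `pad_sequence` used above requires a recent PyTorch release:

    from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate

    # Toy sketch: two preference pairs. Rows 0-1 of the collated batch are the
    # chosen (winner) sequences, rows 2-3 the rejected (loser) sequences.
    batch = [
        {"chosen_input_ids": [1, 2, 3], "chosen_attention_mask": [1, 1, 1],
         "rejected_input_ids": [1, 2], "rejected_attention_mask": [1, 1]},
        {"chosen_input_ids": [4, 5], "chosen_attention_mask": [1, 1],
         "rejected_input_ids": [4, 5, 6, 7], "rejected_attention_mask": [1, 1, 1, 1]},
    ]
    out = bradley_terry_rm_collate(batch, pad_token_id=0)
    print(out["input_ids"].shape)  # torch.Size([4, 4]): chosen first, rejected after
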
fusion_bench/dataset/llama/metamathqa.py
ADDED

@@ -0,0 +1,50 @@
+import os
+from typing import TYPE_CHECKING, Optional
+
+from datasets import Dataset, load_dataset, load_from_disk
+from lightning.fabric.utilities import rank_zero_only
+from tqdm.auto import tqdm
+
+from fusion_bench.utils import timeit_context
+
+from .alpaca import convert_alpaca_to_conversation
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+def load_tokenized_metamathqa(
+    tokenizer: "PreTrainedTokenizer",
+    path: str = "meta-math/MetaMathQA",
+    split: str = "train",
+    cache_path: Optional[str] = None,
+):
+    if cache_path is not None and os.path.exists(cache_path):
+        dataset = load_from_disk(cache_path)
+        if split is not None and split in dataset:
+            return dataset[split]
+        else:
+            return dataset
+
+    dataset = load_dataset(path, split=split)
+
+    # convert dataset to alpaca format and save to ../data/MetaMathQA.json
+    alpaca_dataset = []
+    for example in tqdm(dataset, disable=not rank_zero_only.rank == 0):
+        alpaca_example = {
+            "instruction": example["query"],
+            "input": "",
+            "output": example["response"],
+        }
+        alpaca_dataset.append(alpaca_example)
+
+    conversations = convert_alpaca_to_conversation(alpaca_dataset)
+    with timeit_context("Tokenizing dataset"):
+        tokenized_dataset = tokenizer.apply_chat_template(
+            conversations, return_dict=True
+        )
+        tokenized_dataset = Dataset.from_dict(tokenized_dataset)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
+        tokenized_dataset.save_to_disk(cache_path)
+    return tokenized_dataset

fusion_bench/dataset/llama/preference_700k.py
ADDED

@@ -0,0 +1,70 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Optional
+
+from datasets import Dataset, load_dataset, load_from_disk
+from lightning.fabric.utilities import rank_zero_only
+from tqdm.auto import tqdm
+
+from fusion_bench.utils import timeit_context
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+log = logging.getLogger(__name__)
+
+
+def load_tokenized_preference_700k_for_rlhf(
+    tokenizer: "PreTrainedTokenizer",
+    path: str = "hendrydong/preference_700K",
+    split: str = "train",
+    num_proc: int = 8,
+    cache_path: Optional[str] = None,
+):
+    R"""
+    Load and tokenized Preference 700k dataset for Bradley-Terry ranking model.
+
+    The returned dataset contains the following fields:
+
+    - chosen_input_ids: The input token ids for the winner.
+    - chosen_attention_mask: The attention mask for the winner.
+    - rejected_input_ids: The input token ids for the loser.
+    - rejected_attention_mask: The attention mask for the loser.
+    """
+    if cache_path is not None and os.path.exists(cache_path):
+        dataset = load_from_disk(cache_path)
+        return dataset
+
+    dataset = load_dataset(path, split=split)
+
+    def tokenize(sample):
+        sample["chosen_chat"] = tokenizer.apply_chat_template(
+            sample["chosen"], tokenize=False, add_generation_prompt=False
+        )
+        sample["rejected_chat"] = tokenizer.apply_chat_template(
+            sample["rejected"], tokenize=False, add_generation_prompt=False
+        )
+
+        tokenized_pos = tokenizer(sample["chosen_chat"], truncation=True)
+        tokenized_neg = tokenizer(sample["rejected_chat"], truncation=True)
+
+        # Ensure that the chosen response does not contain an PAD token
+        sample["chosen_input_ids"] = tokenized_pos["input_ids"]
+        sample["chosen_attention_mask"] = tokenized_pos["attention_mask"]
+        if tokenizer.pad_token_id in tokenized_pos["input_ids"]:
+            log.warning(f"Prompt contains PAD token: {sample['chosen_chat']}")
+
+        sample["rejected_input_ids"] = tokenized_neg["input_ids"]
+        sample["rejected_attention_mask"] = tokenized_neg["attention_mask"]
+        # Ensure that the rejected response does not contain an PAD token
+        if tokenizer.pad_token_id in tokenized_neg["input_ids"]:
+            log.warning(f"Prompt contains PAD token: {sample['rejected_chat']}")
+
+        return sample
+
+    dataset = dataset.map(tokenize, num_proc=num_proc)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
+        dataset.save_to_disk(cache_path)
+    return dataset

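For illustration, this loader pairs naturally with `bradley_terry_rm_collate` from collate.py above; the checkpoint, batch size, and pad-token fallback below are assumptions:

    from functools import partial

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
    from fusion_bench.dataset.llama.preference_700k import load_tokenized_preference_700k_for_rlhf

    # Sketch: the chosen_*/rejected_* columns produced by the loader are exactly
    # the fields bradley_terry_rm_collate consumes.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    dataset = load_tokenized_preference_700k_for_rlhf(tokenizer, cache_path=None)
    loader = DataLoader(
        dataset,
        batch_size=4,
        collate_fn=partial(bradley_terry_rm_collate, pad_token_id=tokenizer.pad_token_id or 0),
    )
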
fusion_bench/dataset/llama/stanford_shp.py
ADDED

@@ -0,0 +1,90 @@
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Optional
+
+from datasets import Dataset, load_dataset, load_from_disk
+from lightning.fabric.utilities import rank_zero_only
+from tqdm.auto import tqdm
+
+from fusion_bench.utils import timeit_context
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+def load_tokenized_stanford_shp_for_rlhf(
+    tokenizer: "PreTrainedTokenizer",
+    path: str = "stanfordnlp/SHP",
+    split: str = "train",
+    num_proc: int = 8,
+    cache_path: Optional[str] = None,
+):
+    if cache_path is not None and os.path.isdir(cache_path):
+        dataset = load_from_disk(cache_path)
+        return dataset
+
+    dataset = load_dataset(path, split=split)
+
+    def tokenize(sample):
+        """
+        - history: the post title concatented to the post body (string)
+        - human_ref_A: text of comment A (string)
+        - human_ref_B: text of comment B (string)
+        - labels: the preference label -- it is 1 if A is preferred to B; 0 if B is preferred to A. This was randomized such that the label distribution is roughly 50/50. (integer)
+        """
+        # Create a conversation with the post title and body, followed by comments
+        conversation = [{"role": "user", "content": sample["history"]}]
+        if sample["labels"] == 0:
+            sample["chosen"] = deepcopy(conversation).append(
+                {"role": "assistant", "content": sample["human_ref_B"]}
+            )
+            sample["rejected"] = deepcopy(conversation).append(
+                {"role": "assistant", "content": sample["human_ref_A"]}
+            )
+        else:
+            sample["chosen"] = deepcopy(conversation).append(
+                {"role": "assistant", "content": sample["human_ref_A"]}
+            )
+            sample["rejected"] = deepcopy(conversation).append(
+                {"role": "assistant", "content": sample["human_ref_B"]}
+            )
+
+        # apply chat template
+        sample["chosen_chat"] = tokenizer.apply_chat_template(
+            sample["chosen"], tokenize=False, add_generation_prompt=False
+        )
+        sample["rejected_chat"] = tokenizer.apply_chat_template(
+            sample["rejected"], tokenize=False, add_generation_prompt=False
+        )
+
+        # tokenize the conversation
+        tokenized_pos = tokenizer(sample["chosen_chat"], truncation=True)
+        tokenized_neg = tokenizer(sample["rejected_chat"], truncation=True)
+
+        # Ensure that the chosen response does not contain an EOS token
+        sample["chosen_input_ids"] = tokenized_pos["input_ids"]
+        sample["chosen_attention_mask"] = tokenized_pos["attention_mask"]
+        assert (
+            tokenizer.eos_token_id not in tokenized_pos["input_ids"][:-1]
+        ), f"Prompt contains EOS token: {sample['positive']}"
+        if sample["chosen_input_ids"][-1] != tokenizer.eos_token_id:
+            sample["chosen_input_ids"].append(tokenizer.eos_token_id)
+            sample["chosen_attention_mask"].append(1)
+
+        sample["rejected_input_ids"] = tokenized_neg["input_ids"]
+        sample["rejected_attention_mask"] = tokenized_neg["attention_mask"]
+        # Ensure that the rejected response does not contain an EOS token
+        assert (
+            tokenizer.eos_token_id not in tokenized_neg["input_ids"][:-1]
+        ), f"Prompt contains EOS token: {sample['rejected']}"
+        if sample["rejected_input_ids"][-1] != tokenizer.eos_token_id:
+            sample["rejected_input_ids"].append(tokenizer.eos_token_id)
+            sample["rejected_attention_mask"].append(1)
+
+        return sample
+
+    dataset = dataset.map(tokenize, num_proc=num_proc)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
+        dataset.save_to_disk(cache_path)
+    return dataset

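One detail worth noting when reading the tokenize helper above: `list.append` returns `None`, so an expression of the form `deepcopy(conversation).append(...)` evaluates to `None` rather than to the extended conversation. A sketch of the presumably intended construction (shown for the `labels == 1` branch, i.e. comment A preferred) is:

    # Sketch only, not code from the package: keep the extended conversation.
    chosen = conversation + [{"role": "assistant", "content": sample["human_ref_A"]}]
    rejected = conversation + [{"role": "assistant", "content": sample["human_ref_B"]}]
    sample["chosen"], sample["rejected"] = chosen, rejected
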
fusion_bench/dataset/llama/ultrachat.py
ADDED

@@ -0,0 +1,58 @@
+import os
+from typing import TYPE_CHECKING, Optional
+
+from datasets import Dataset, load_dataset, load_from_disk
+from lightning.fabric.utilities import rank_zero_only
+from tqdm.auto import tqdm
+
+from fusion_bench.utils import timeit_context
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+def load_tokenized_ultrachat_200k(
+    tokenizer: "PreTrainedTokenizer",
+    path: str = "HuggingFaceH4/ultrachat_200k",
+    split: str = "train_sft",
+    num_proc: int = 8,
+    cache_path: Optional[str] = None,
+):
+    R"""
+    Load and tokenized Ultrachat 200k dataset for Bradley-Terry ranking model.
+
+    The returned dataset contains the following fields:
+
+    - input_ids: The input token ids for the winner.
+    - attention_mask: The attention mask for the winner.
+    """
+    if cache_path is not None and os.path.exists(cache_path):
+        dataset = load_from_disk(cache_path)
+        return dataset
+
+    dataset = load_dataset(path, split=split)
+
+    def tokenize(sample):
+
+        # ? is it necessary to `.replace(tokenizer.bos_token, "")`?
+        sample["input_ids"] = tokenizer.apply_chat_template(
+            sample["messages"], tokenize=True, add_generation_prompt=False
+        )
+        sample["attention_mask"] = [1] * len(sample["input_ids"])
+
+        return sample
+
+    dataset = dataset.map(tokenize, num_proc=num_proc)
+
+    if cache_path is not None and rank_zero_only.rank == 0:
+        dataset.save_to_disk(cache_path)
+    return dataset
+
+
+if __name__ == "__main__":
+    # Example usage and testing
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+    dataset = load_tokenized_ultrachat_200k(tokenizer)
+    print(dataset)

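Rows from this loader carry only `input_ids` and `attention_mask`; the updated `padded_collate_sft` above copies `input_ids` into `labels` when a labels field is missing, so the two compose directly. A sketch under assumptions (batch size and pad-token fallback are illustrative):

    from functools import partial

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from fusion_bench.dataset.llama.collate import padded_collate_sft
    from fusion_bench.dataset.llama.ultrachat import load_tokenized_ultrachat_200k

    # Sketch: labels are auto-filled from input_ids inside padded_collate_sft.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    dataset = load_tokenized_ultrachat_200k(tokenizer)
    loader = DataLoader(
        dataset,
        batch_size=2,
        collate_fn=partial(padded_collate_sft, pad_token_id=tokenizer.pad_token_id or 0),
    )
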
fusion_bench/dataset/llama/utils/__init__.py
File without changes

fusion_bench/method/__init__.py
CHANGED

@@ -10,7 +10,7 @@ _import_structure = {
     "dummy": ["DummyAlgorithm"],
     # single task learning (fine-tuning)
     "classification": ["ImageClassificationFineTuningForCLIP"],
-    "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT"],
+    "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
     # analysis
     "analysis": ["TaskVectorCosSimilarity", "TaskVectorViolinPlot"],
     # model ensemble methods

@@ -49,6 +49,7 @@ _import_structure = {
         "PWEMoExactParetoOptimalForCLIP",
     ],
     "ada_svd": ["AdaSVDMergingForCLIPVisionModel"],
+    "task_singular_vector": ["TaskSingularVectorMerging"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",

@@ -153,6 +154,7 @@ if TYPE_CHECKING:
         SparseLoForLlama,
     )
     from .task_arithmetic import TaskArithmeticAlgorithm
+    from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
     from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
     from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama

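Both new entries are registered in `_import_structure`, so, assuming the usual lazy-module setup behind this mapping, they resolve at the package top level:

    # Sketch: names resolve lazily through fusion_bench.method's _import_structure;
    # the heavy submodules are only imported on first attribute access.
    from fusion_bench.method import BradleyTerryRewardModeling, TaskSingularVectorMerging
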
fusion_bench/method/adamerging/layer_wise_adamerging.py
CHANGED

@@ -1,12 +1,12 @@
 import logging
 import os
 from abc import abstractmethod
-from typing import Any, List, Mapping, Union, cast  # noqa: F401
+from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, cast  # noqa: F401
 
 import torch
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig
-from torch import Tensor
+from torch import Tensor, nn
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
 

@@ -19,10 +19,14 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
     get_layer_wise_weights,
 )
 from fusion_bench.utils.data import load_tensor_from_file
+from fusion_bench.utils.type import TorchModelType
 
 from .entropy_loss import entropy_loss
 from .utils import get_memory_usage
 
+if TYPE_CHECKING:
+    from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
+
 log = logging.getLogger(__name__)
 
 

@@ -31,6 +35,9 @@ class LayerWiseAdaMergingAlgorithm(
     LightningFabricMixin,
     SimpleProfilerMixin,
 ):
+    _program: "FabricModelFusionProgram"
+    """The program that this algorithm is running on."""
+
     """
     Implements the Layer-Wise AdaMerging Algorithm.
 

@@ -48,7 +55,7 @@ class LayerWiseAdaMergingAlgorithm(
         super().__init__(algorithm_config)
 
     @torch.no_grad()
-    def construct_layer_wise_merged_model(self, modelpool: ModelPool):
+    def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
         """
         Constructs a wrapped layer-wise merged model from model pool.
 

@@ -183,7 +190,7 @@ class LayerWiseAdaMergingAlgorithm(
         """
         pass
 
-    def test_time_adaptation(self, module: LayerWiseMergedModel):
+    def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]"):
         """
         Perform test-time adaptation on the merged model.
 