fusion-bench 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +3 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +1 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +64 -11
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +12 -1
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +158 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +12 -5
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +14 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +263 -90
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from IPython.display import HTML, display
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def create_color_style():
|
|
6
|
+
return """
|
|
7
|
+
<style>
|
|
8
|
+
.token-container { font-family: monospace; white-space: pre; }
|
|
9
|
+
.attention { background-color: #90EE90; } /* Light green */
|
|
10
|
+
.label { background-color: #FFB6C6; } /* Light red */
|
|
11
|
+
.token { color: #0066cc; } /* Blue */
|
|
12
|
+
.stats { font-weight: bold; }
|
|
13
|
+
</style>
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def escape_special_chars(text):
|
|
18
|
+
"""Convert special characters to their string representation"""
|
|
19
|
+
return (
|
|
20
|
+
text.replace("\n", "\\n")
|
|
21
|
+
.replace("\t", "\\t")
|
|
22
|
+
.replace("\r", "\\r")
|
|
23
|
+
.replace(" ", "␣")
|
|
24
|
+
) # Optional: show spaces with visible character
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def visualize_tokens_html(input_ids, attention_mask, labels, tokenizer):
|
|
28
|
+
"""
|
|
29
|
+
Visualize model inputs using HTML colored text representation for Jupyter Notebook
|
|
30
|
+
with special characters shown as strings
|
|
31
|
+
"""
|
|
32
|
+
# Convert to numpy if tensors
|
|
33
|
+
attention_mask = np.array(attention_mask).flatten()
|
|
34
|
+
labels = np.array(labels).flatten()
|
|
35
|
+
input_ids = np.array(input_ids).flatten()
|
|
36
|
+
|
|
37
|
+
# Decode tokens and escape special characters
|
|
38
|
+
tokens = [escape_special_chars(tokenizer.decode(id_)) for id_ in input_ids]
|
|
39
|
+
|
|
40
|
+
# Create HTML output
|
|
41
|
+
html_output = [create_color_style()]
|
|
42
|
+
|
|
43
|
+
# Header
|
|
44
|
+
html_output.append("<h3>**Token Visualization**</h3>")
|
|
45
|
+
|
|
46
|
+
# Legend
|
|
47
|
+
html_output.append(
|
|
48
|
+
"""
|
|
49
|
+
<div style='margin: 10px 0;'>
|
|
50
|
+
<strong>Legend:</strong><br>
|
|
51
|
+
<span class='attention'> </span> Active Attention<br>
|
|
52
|
+
<span class='label'> </span> Label Present<br>
|
|
53
|
+
<span class='token'>Text</span> Token Text<br>
|
|
54
|
+
Special Characters: \\n (newline), \\t (tab), ␣ (space)
|
|
55
|
+
</div>
|
|
56
|
+
"""
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
# Token alignment
|
|
60
|
+
html_output.append("<strong>Token Alignment:</strong>")
|
|
61
|
+
html_output.append("<div class='token-container'>")
|
|
62
|
+
|
|
63
|
+
# Calculate maximum token length for better alignment
|
|
64
|
+
max_token_len = max(len(str(token)) for token in tokens)
|
|
65
|
+
|
|
66
|
+
for i, (input_id, token, mask, label) in enumerate(
|
|
67
|
+
zip(input_ids, tokens, attention_mask, labels)
|
|
68
|
+
):
|
|
69
|
+
# Pad token for alignment
|
|
70
|
+
token_text = f"{token:{max_token_len}s}"
|
|
71
|
+
|
|
72
|
+
# Create classes for styling
|
|
73
|
+
classes = []
|
|
74
|
+
if mask == 1:
|
|
75
|
+
classes.append("attention")
|
|
76
|
+
if label != -100 and label != 0:
|
|
77
|
+
classes.append("label")
|
|
78
|
+
|
|
79
|
+
class_str = f"class='{' '.join(classes)}'" if classes else ""
|
|
80
|
+
|
|
81
|
+
# Create the line
|
|
82
|
+
line = f"Position {i:3d}: <span {class_str}><span class='token'>{token_text}</span></span> "
|
|
83
|
+
line += (
|
|
84
|
+
f"(Mask: {int(mask)}, Label: {int(label)}, Inpu_id: {int(input_id)})<br>"
|
|
85
|
+
)
|
|
86
|
+
html_output.append(line)
|
|
87
|
+
|
|
88
|
+
html_output.append("</div>")
|
|
89
|
+
|
|
90
|
+
# Statistics
|
|
91
|
+
html_output.append(
|
|
92
|
+
"""
|
|
93
|
+
<div class='stats' style='margin-top: 10px;'>
|
|
94
|
+
Statistics:<br>
|
|
95
|
+
Total tokens: {}<br>
|
|
96
|
+
Active attention tokens: {}<br>
|
|
97
|
+
Labeled tokens: {}<br>
|
|
98
|
+
</div>
|
|
99
|
+
""".format(
|
|
100
|
+
len(tokens), attention_mask.sum(), sum(labels != -100)
|
|
101
|
+
)
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
# Display the HTML
|
|
105
|
+
display(HTML("".join(html_output)))
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# Example usage:
|
|
109
|
+
"""
|
|
110
|
+
from transformers import AutoTokenizer
|
|
111
|
+
import torch
|
|
112
|
+
|
|
113
|
+
# Initialize tokenizer
|
|
114
|
+
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
|
115
|
+
|
|
116
|
+
# Sample input with special characters
|
|
117
|
+
text = "Hello,\nhow are\tyou?"
|
|
118
|
+
inputs = tokenizer(text, return_tensors='pt')
|
|
119
|
+
labels = torch.zeros_like(inputs['input_ids']) # dummy labels
|
|
120
|
+
|
|
121
|
+
visualize_tokens_html(
|
|
122
|
+
inputs['attention_mask'][0],
|
|
123
|
+
labels[0],
|
|
124
|
+
inputs['input_ids'][0],
|
|
125
|
+
tokenizer
|
|
126
|
+
)
|
|
127
|
+
"""
|
fusion_bench/utils/type.py
CHANGED
|
@@ -6,18 +6,29 @@ from typing_extensions import TypeAlias
|
|
|
6
6
|
|
|
7
7
|
try:
|
|
8
8
|
import torch
|
|
9
|
-
from torch import Tensor
|
|
9
|
+
from torch import Tensor, nn
|
|
10
10
|
|
|
11
11
|
StateDictType: TypeAlias = Dict[str, Tensor]
|
|
12
|
+
TorchModelType = TypeVar("TorchModelType", bound=nn.Module)
|
|
13
|
+
|
|
12
14
|
except ImportError:
|
|
13
15
|
pass
|
|
14
16
|
|
|
15
17
|
|
|
16
|
-
|
|
18
|
+
PyModuleType = type(sys)
|
|
17
19
|
T = TypeVar("T")
|
|
18
20
|
T1 = TypeVar("T1")
|
|
19
21
|
T2 = TypeVar("T2")
|
|
20
22
|
T3 = TypeVar("T3")
|
|
21
23
|
T4 = TypeVar("T4")
|
|
22
24
|
|
|
23
|
-
__all__ = [
|
|
25
|
+
__all__ = [
|
|
26
|
+
"StateDictType",
|
|
27
|
+
"PyModuleType",
|
|
28
|
+
"TorchModelType",
|
|
29
|
+
"T",
|
|
30
|
+
"T1",
|
|
31
|
+
"T2",
|
|
32
|
+
"T3",
|
|
33
|
+
"T4",
|
|
34
|
+
]
|