fusion-bench 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/base_algorithm.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/preference_700k.py +1 -1
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/surgery/__init__.py +1 -3
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +1 -1
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/clip_classification.py +6 -6
- fusion_bench/mixins/lightning_fabric.py +3 -1
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +2 -1
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +1 -1
- fusion_bench/programs/fabric_fusion_program.py +7 -4
- fusion_bench/taskpool/llama/reward_model.py +1 -1
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/type.py +10 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +195 -62
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
classnames = [
|
|
2
|
+
"apple pie",
|
|
3
|
+
"baby back ribs",
|
|
4
|
+
"baklava",
|
|
5
|
+
"beef carpaccio",
|
|
6
|
+
"beef tartare",
|
|
7
|
+
"beet salad",
|
|
8
|
+
"beignets",
|
|
9
|
+
"bibimbap",
|
|
10
|
+
"bread pudding",
|
|
11
|
+
"breakfast burrito",
|
|
12
|
+
"bruschetta",
|
|
13
|
+
"caesar salad",
|
|
14
|
+
"cannoli",
|
|
15
|
+
"caprese salad",
|
|
16
|
+
"carrot cake",
|
|
17
|
+
"ceviche",
|
|
18
|
+
"cheese plate",
|
|
19
|
+
"cheesecake",
|
|
20
|
+
"chicken curry",
|
|
21
|
+
"chicken quesadilla",
|
|
22
|
+
"chicken wings",
|
|
23
|
+
"chocolate cake",
|
|
24
|
+
"chocolate mousse",
|
|
25
|
+
"churros",
|
|
26
|
+
"clam chowder",
|
|
27
|
+
"club sandwich",
|
|
28
|
+
"crab cakes",
|
|
29
|
+
"creme brulee",
|
|
30
|
+
"croque madame",
|
|
31
|
+
"cup cakes",
|
|
32
|
+
"deviled eggs",
|
|
33
|
+
"donuts",
|
|
34
|
+
"dumplings",
|
|
35
|
+
"edamame",
|
|
36
|
+
"eggs benedict",
|
|
37
|
+
"escargots",
|
|
38
|
+
"falafel",
|
|
39
|
+
"filet mignon",
|
|
40
|
+
"fish and chips",
|
|
41
|
+
"foie gras",
|
|
42
|
+
"french fries",
|
|
43
|
+
"french onion soup",
|
|
44
|
+
"french toast",
|
|
45
|
+
"fried calamari",
|
|
46
|
+
"fried rice",
|
|
47
|
+
"frozen yogurt",
|
|
48
|
+
"garlic bread",
|
|
49
|
+
"gnocchi",
|
|
50
|
+
"greek salad",
|
|
51
|
+
"grilled cheese sandwich",
|
|
52
|
+
"grilled salmon",
|
|
53
|
+
"guacamole",
|
|
54
|
+
"gyoza",
|
|
55
|
+
"hamburger",
|
|
56
|
+
"hot and sour soup",
|
|
57
|
+
"hot dog",
|
|
58
|
+
"huevos rancheros",
|
|
59
|
+
"hummus",
|
|
60
|
+
"ice cream",
|
|
61
|
+
"lasagna",
|
|
62
|
+
"lobster bisque",
|
|
63
|
+
"lobster roll sandwich",
|
|
64
|
+
"macaroni and cheese",
|
|
65
|
+
"macarons",
|
|
66
|
+
"miso soup",
|
|
67
|
+
"mussels",
|
|
68
|
+
"nachos",
|
|
69
|
+
"omelette",
|
|
70
|
+
"onion rings",
|
|
71
|
+
"oysters",
|
|
72
|
+
"pad thai",
|
|
73
|
+
"paella",
|
|
74
|
+
"pancakes",
|
|
75
|
+
"panna cotta",
|
|
76
|
+
"peking duck",
|
|
77
|
+
"pho",
|
|
78
|
+
"pizza",
|
|
79
|
+
"pork chop",
|
|
80
|
+
"poutine",
|
|
81
|
+
"prime rib",
|
|
82
|
+
"pulled pork sandwich",
|
|
83
|
+
"ramen",
|
|
84
|
+
"ravioli",
|
|
85
|
+
"red velvet cake",
|
|
86
|
+
"risotto",
|
|
87
|
+
"samosa",
|
|
88
|
+
"sashimi",
|
|
89
|
+
"scallops",
|
|
90
|
+
"seaweed salad",
|
|
91
|
+
"shrimp and grits",
|
|
92
|
+
"spaghetti bolognese",
|
|
93
|
+
"spaghetti carbonara",
|
|
94
|
+
"spring rolls",
|
|
95
|
+
"steak",
|
|
96
|
+
"strawberry shortcake",
|
|
97
|
+
"sushi",
|
|
98
|
+
"tacos",
|
|
99
|
+
"takoyaki",
|
|
100
|
+
"tiramisu",
|
|
101
|
+
"tuna tartare",
|
|
102
|
+
"waffles",
|
|
103
|
+
]
|
|
104
|
+
|
|
105
|
+
templates = [lambda c: f"a photo of {c}, a type of food."]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
classnames_mapping = {
|
|
2
|
+
"0": "お",
|
|
3
|
+
"1": "き",
|
|
4
|
+
"2": "す",
|
|
5
|
+
"3": "つ",
|
|
6
|
+
"4": "な",
|
|
7
|
+
"5": "は",
|
|
8
|
+
"6": "ま",
|
|
9
|
+
"7": "や",
|
|
10
|
+
"8": "れ",
|
|
11
|
+
"9": "を",
|
|
12
|
+
}
|
|
13
|
+
classnames = [classnames_mapping[str(c)] for c in range(10)]
|
|
14
|
+
|
|
15
|
+
templates = [
|
|
16
|
+
lambda c: f"a photo of the character {c}.",
|
|
17
|
+
]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
classnames = [
|
|
2
|
+
"Anthracnose",
|
|
3
|
+
"Bacterial Canker",
|
|
4
|
+
"Cutting Weevil",
|
|
5
|
+
"Die Back",
|
|
6
|
+
"Gall Midge",
|
|
7
|
+
"Healthy",
|
|
8
|
+
"Powdery Mildew",
|
|
9
|
+
"Sooty Mould",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
templates = [
|
|
13
|
+
lambda c: f"a photo of a mango leaf with {c}.",
|
|
14
|
+
lambda c: f"a mango leaf showing symptoms of {c}.",
|
|
15
|
+
lambda c: f"a close-up photo of {c} on a mango leaf.",
|
|
16
|
+
lambda c: f"this mango leaf is affected by {c}.",
|
|
17
|
+
lambda c: f"a mango leaf disease identified as {c}.",
|
|
18
|
+
lambda c: f"a {c} infection on a mango leaf.",
|
|
19
|
+
]
|
fusion_bench/utils/parameters.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
from collections import OrderedDict
|
|
3
|
-
from typing import List, Mapping, Union
|
|
3
|
+
from typing import List, Mapping, Optional, Union
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import nn
|
|
@@ -43,7 +43,10 @@ def trainable_state_dict(
|
|
|
43
43
|
return state_dict
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
def state_dict_to_vector(
|
|
46
|
+
def state_dict_to_vector(
|
|
47
|
+
state_dict: StateDictType,
|
|
48
|
+
remove_keys: Optional[List[str]] = None,
|
|
49
|
+
):
|
|
47
50
|
"""
|
|
48
51
|
Convert a state dictionary to a vector.
|
|
49
52
|
|
|
@@ -54,6 +57,7 @@ def state_dict_to_vector(state_dict, remove_keys=[]):
|
|
|
54
57
|
Returns:
|
|
55
58
|
torch.Tensor: The converted vector.
|
|
56
59
|
"""
|
|
60
|
+
remove_keys = remove_keys if remove_keys is not None else []
|
|
57
61
|
shared_state_dict = copy.deepcopy(state_dict)
|
|
58
62
|
for key in remove_keys:
|
|
59
63
|
if key in shared_state_dict:
|
|
@@ -64,7 +68,11 @@ def state_dict_to_vector(state_dict, remove_keys=[]):
|
|
|
64
68
|
)
|
|
65
69
|
|
|
66
70
|
|
|
67
|
-
def vector_to_state_dict(
|
|
71
|
+
def vector_to_state_dict(
|
|
72
|
+
vector: torch.Tensor,
|
|
73
|
+
state_dict: StateDictType,
|
|
74
|
+
remove_keys: Optional[List[str]] = None,
|
|
75
|
+
):
|
|
68
76
|
"""
|
|
69
77
|
Convert a vector to a state dictionary.
|
|
70
78
|
|
|
@@ -76,6 +84,7 @@ def vector_to_state_dict(vector, state_dict, remove_keys=[]):
|
|
|
76
84
|
Returns:
|
|
77
85
|
dict: The converted state dictionary.
|
|
78
86
|
"""
|
|
87
|
+
remove_keys = remove_keys if remove_keys is not None else []
|
|
79
88
|
# create a reference dict to define the order of the vector
|
|
80
89
|
reference_dict = copy.deepcopy(state_dict)
|
|
81
90
|
for key in remove_keys:
|
fusion_bench/utils/type.py
CHANGED
|
@@ -22,4 +22,13 @@ T2 = TypeVar("T2")
|
|
|
22
22
|
T3 = TypeVar("T3")
|
|
23
23
|
T4 = TypeVar("T4")
|
|
24
24
|
|
|
25
|
-
__all__ = [
|
|
25
|
+
__all__ = [
|
|
26
|
+
"StateDictType",
|
|
27
|
+
"PyModuleType",
|
|
28
|
+
"TorchModelType",
|
|
29
|
+
"T",
|
|
30
|
+
"T1",
|
|
31
|
+
"T2",
|
|
32
|
+
"T3",
|
|
33
|
+
"T4",
|
|
34
|
+
]
|