fusion-bench 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/base_algorithm.py +1 -1
- fusion_bench/dataset/clip_dataset.py +3 -0
- fusion_bench/dataset/fer2013.py +12 -0
- fusion_bench/dataset/llama/preference_700k.py +1 -1
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/classification/clip_finetune.py +10 -13
- fusion_bench/method/surgery/__init__.py +1 -3
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +1 -1
- fusion_bench/method/tall_mask/__init__.py +0 -0
- fusion_bench/method/tall_mask/utils.py +234 -0
- fusion_bench/method/task_singular_vector/TSVC.py +16 -0
- fusion_bench/method/task_singular_vector/TSVM.py +63 -0
- fusion_bench/method/task_singular_vector/__init__.py +9 -0
- fusion_bench/method/task_singular_vector/utils/TSVC_utils.py +50 -0
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +642 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +7 -0
- fusion_bench/method/ties_merging/ties_merging_utils.py +7 -2
- fusion_bench/mixins/clip_classification.py +6 -6
- fusion_bench/mixins/lightning_fabric.py +3 -1
- fusion_bench/modelpool/base_pool.py +0 -1
- fusion_bench/modelpool/clip_vision/modelpool.py +92 -8
- fusion_bench/models/surgery/__init__.py +1 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +2 -1
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +1 -1
- fusion_bench/programs/fabric_fusion_program.py +7 -4
- fusion_bench/taskpool/llama/reward_model.py +1 -1
- fusion_bench/tasks/clip_classification/__init__.py +13 -45
- fusion_bench/tasks/clip_classification/clip_dataset.py +1 -16
- fusion_bench/tasks/clip_classification/cub_200_2011.py +208 -0
- fusion_bench/tasks/clip_classification/emnist_letters.py +31 -0
- fusion_bench/tasks/clip_classification/emnist_mnist.py +5 -0
- fusion_bench/tasks/clip_classification/fashion_mnist.py +18 -0
- fusion_bench/tasks/clip_classification/fer2013.py +18 -0
- fusion_bench/tasks/clip_classification/food101.py +105 -0
- fusion_bench/tasks/clip_classification/kmnist.py +17 -0
- fusion_bench/tasks/clip_classification/mongo_leaf_disease.py +19 -0
- fusion_bench/tasks/clip_classification/pcam.py +5 -0
- fusion_bench/utils/parameters.py +12 -3
- fusion_bench/utils/type.py +10 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/RECORD +195 -62
- fusion_bench_config/dataset/image_classification/README.md +6 -0
- fusion_bench_config/dataset/image_classification/test/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/test/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +5 -0
- fusion_bench_config/dataset/image_classification/test/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/test/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/test/stl10.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/TALL14.yaml +20 -0
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/cifar10.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cifar100.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/cub-200-2011.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_letters.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/emnist_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/fer2013.yaml +3 -0
- fusion_bench_config/dataset/image_classification/train/food101.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/kmnist.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/mango-leaf-disease.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford-iiit-pet.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/oxford_flowers102.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/pcam.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/rendered-sst2.yaml +4 -0
- fusion_bench_config/dataset/image_classification/train/stl10.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -0
- fusion_bench_config/model/clip-vit/README.md +38 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eight_tasks.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL14.yaml +22 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_TALL20.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_cifar100.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_dtd.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_emnist_letters.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_eurosat.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fashion_mnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_fer2013.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_food101.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_gtsrb.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_kmnist.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_mnist.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -0
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -3
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -3
- fusion_bench_config/model/clip-vit/download_TALL20_models.sh +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +15 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +9 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +15 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +9 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +15 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +27 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL14.yaml +19 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TALL20.yaml +26 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_cifar100.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_dtd.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_emnist_letters.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_eurosat.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fashion_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_fer2013.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_food101.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_gtsrb.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_kmnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_mnist.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford-iiit-pet.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_oxford_flowers102_val.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_pcam.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_rendered-sst2.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_resisc45.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stanford-cars.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_stl10.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397.yaml +3 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml +3 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.7.dist-info → fusion_bench-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -132,13 +132,13 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
132
132
|
|
|
133
133
|
# get cache directory
|
|
134
134
|
if self.modelpool.has_pretrained:
|
|
135
|
-
model_name = self.modelpool.get_model_config(
|
|
136
|
-
|
|
137
|
-
|
|
135
|
+
model_name = self.modelpool.get_model_config("_pretrained_")
|
|
136
|
+
if not isinstance(model_name, str):
|
|
137
|
+
model_name = model_name.pretrained_model_name_or_path
|
|
138
138
|
else:
|
|
139
|
-
model_name = self.modelpool.get_model_config(
|
|
140
|
-
|
|
141
|
-
|
|
139
|
+
model_name = self.modelpool.get_model_config(self.modelpool.model_names[0])
|
|
140
|
+
if not isinstance(model_name, str):
|
|
141
|
+
model_name = model_name.pretrained_model_name_or_path
|
|
142
142
|
cache_dir = os.path.join(
|
|
143
143
|
self.zeroshot_weights_cache_dir,
|
|
144
144
|
os.path.normpath(model_name.split("/")[-1]),
|
|
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
|
|
|
5
5
|
|
|
6
6
|
import lightning as L
|
|
7
7
|
import torch
|
|
8
|
+
from lightning.fabric.connector import _is_using_cli
|
|
8
9
|
from lightning.fabric.loggers import TensorBoardLogger
|
|
9
10
|
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
10
11
|
from omegaconf import DictConfig, OmegaConf
|
|
@@ -79,7 +80,8 @@ class LightningFabricMixin:
|
|
|
79
80
|
self._fabric_instance = L.Fabric()
|
|
80
81
|
else:
|
|
81
82
|
self._fabric_instance = instantiate(config.fabric)
|
|
82
|
-
|
|
83
|
+
if not _is_using_cli(): # if not using cli, launch the fabric
|
|
84
|
+
self._fabric_instance.launch()
|
|
83
85
|
# Set the log directory in config if it is not already set
|
|
84
86
|
if (
|
|
85
87
|
self.log_dir is not None
|
|
@@ -147,7 +147,6 @@ class BaseModelPool(BaseYAMLSerializableModel):
|
|
|
147
147
|
DictConfig: The configuration for the specified model.
|
|
148
148
|
"""
|
|
149
149
|
model_config = self._models[model_name]
|
|
150
|
-
assert isinstance(model_config, DictConfig), "Model config must be a DictConfig"
|
|
151
150
|
if return_copy:
|
|
152
151
|
model_config = deepcopy(model_config)
|
|
153
152
|
return model_config
|
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Optional, Union
|
|
4
4
|
|
|
5
|
+
from datasets import load_dataset
|
|
5
6
|
from omegaconf import DictConfig, open_dict
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torch.utils.data import Dataset
|
|
6
9
|
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
|
|
7
10
|
from typing_extensions import override
|
|
8
11
|
|
|
@@ -36,17 +39,29 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
36
39
|
|
|
37
40
|
def load_processor(self, *args, **kwargs) -> CLIPProcessor:
|
|
38
41
|
assert self._processor is not None, "Processor is not defined in the config"
|
|
39
|
-
|
|
42
|
+
if isinstance(self._processor, str):
|
|
43
|
+
log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
|
|
44
|
+
processor = CLIPProcessor.from_pretrained(self._processor)
|
|
45
|
+
else:
|
|
46
|
+
processor = instantiate(self._processor, *args, **kwargs)
|
|
40
47
|
return processor
|
|
41
48
|
|
|
42
49
|
def load_clip_model(self, model_name: str, *args, **kwargs) -> CLIPModel:
|
|
43
50
|
model_config = self._models[model_name]
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
51
|
+
|
|
52
|
+
if isinstance(model_config, str):
|
|
53
|
+
log.info(f"Loading `transformers.CLIPModel`: {model_config}")
|
|
54
|
+
clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
|
|
55
|
+
return clip_model
|
|
56
|
+
else:
|
|
57
|
+
assert isinstance(
|
|
58
|
+
model_config, DictConfig
|
|
59
|
+
), "Model config must be a DictConfig"
|
|
60
|
+
model_config = deepcopy(model_config)
|
|
61
|
+
with open_dict(model_config):
|
|
62
|
+
model_config._target_ = "transformers.CLIPModel.from_pretrained"
|
|
63
|
+
clip_model = instantiate(model_config, *args, **kwargs)
|
|
64
|
+
return clip_model
|
|
50
65
|
|
|
51
66
|
@override
|
|
52
67
|
def save_model(self, model: CLIPVisionModel, path: str):
|
|
@@ -59,3 +74,72 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
59
74
|
"""
|
|
60
75
|
with timeit_context(f'Saving clip vision model to "{path}"'):
|
|
61
76
|
model.save_pretrained(path)
|
|
77
|
+
|
|
78
|
+
def load_model(
|
|
79
|
+
self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
|
|
80
|
+
) -> CLIPVisionModel:
|
|
81
|
+
"""
|
|
82
|
+
This method is used to load a CLIPVisionModel from the model pool.
|
|
83
|
+
|
|
84
|
+
Example configuration could be:
|
|
85
|
+
|
|
86
|
+
```yaml
|
|
87
|
+
models:
|
|
88
|
+
cifar10: tanganke/clip-vit-base-patch32_cifar10
|
|
89
|
+
sun397: tanganke/clip-vit-base-patch32_sun397
|
|
90
|
+
stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
model_name_or_config (Union[str, DictConfig]): The name of the model or the model configuration.
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
CLIPVisionModel: The loaded CLIPVisionModel.
|
|
98
|
+
"""
|
|
99
|
+
if (
|
|
100
|
+
isinstance(model_name_or_config, str)
|
|
101
|
+
and model_name_or_config in self._models
|
|
102
|
+
):
|
|
103
|
+
model = self._models[model_name_or_config]
|
|
104
|
+
if isinstance(model, str):
|
|
105
|
+
log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
|
|
106
|
+
return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
|
|
107
|
+
if isinstance(model, nn.Module):
|
|
108
|
+
log.info(f"Returning existing model: {model}")
|
|
109
|
+
return model
|
|
110
|
+
|
|
111
|
+
# If the model is not a string, we use the default load_model method
|
|
112
|
+
return super().load_model(model_name_or_config, *args, **kwargs)
|
|
113
|
+
|
|
114
|
+
def load_train_dataset(self, dataset_name: str, *args, **kwargs):
|
|
115
|
+
dataset_config = self._train_datasets[dataset_name]
|
|
116
|
+
if isinstance(dataset_config, str):
|
|
117
|
+
log.info(
|
|
118
|
+
f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
|
|
119
|
+
)
|
|
120
|
+
dataset = load_dataset(dataset_config, split="train")
|
|
121
|
+
else:
|
|
122
|
+
dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
|
|
123
|
+
return dataset
|
|
124
|
+
|
|
125
|
+
def load_val_dataset(self, dataset_name: str, *args, **kwargs):
|
|
126
|
+
dataset_config = self._val_datasets[dataset_name]
|
|
127
|
+
if isinstance(dataset_config, str):
|
|
128
|
+
log.info(
|
|
129
|
+
f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
|
|
130
|
+
)
|
|
131
|
+
dataset = load_dataset(dataset_config, split="validation")
|
|
132
|
+
else:
|
|
133
|
+
dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
|
|
134
|
+
return dataset
|
|
135
|
+
|
|
136
|
+
def load_test_dataset(self, dataset_name: str, *args, **kwargs):
|
|
137
|
+
dataset_config = self._test_datasets[dataset_name]
|
|
138
|
+
if isinstance(dataset_config, str):
|
|
139
|
+
log.info(
|
|
140
|
+
f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
|
|
141
|
+
)
|
|
142
|
+
dataset = load_dataset(dataset_config, split="test")
|
|
143
|
+
else:
|
|
144
|
+
dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
|
|
145
|
+
return dataset
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .surgerymodelwrapper import SurgeryModelWrapper
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
|
-
from typing import TYPE_CHECKING,
|
|
2
|
+
from typing import TYPE_CHECKING, Callable, Generic, List, Union
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import nn
|
|
@@ -7,6 +7,7 @@ from transformers.models.clip.modeling_clip import (
|
|
|
7
7
|
CLIPVisionModel,
|
|
8
8
|
CLIPVisionTransformer,
|
|
9
9
|
)
|
|
10
|
+
|
|
10
11
|
from fusion_bench.utils.type import TorchModelType
|
|
11
12
|
|
|
12
13
|
|
|
@@ -16,7 +16,7 @@ import torch
|
|
|
16
16
|
from torch import Tensor, nn
|
|
17
17
|
from torch.func import functional_call
|
|
18
18
|
|
|
19
|
-
from fusion_bench.utils.type import
|
|
19
|
+
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
20
20
|
|
|
21
21
|
__all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
|
|
22
22
|
|
|
@@ -22,7 +22,7 @@ import torch
|
|
|
22
22
|
from torch import Tensor, nn
|
|
23
23
|
from torch.func import functional_call
|
|
24
24
|
|
|
25
|
-
from fusion_bench.utils.type import
|
|
25
|
+
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
26
26
|
|
|
27
27
|
log = logging.getLogger(__name__)
|
|
28
28
|
|
|
@@ -185,10 +185,13 @@ class FabricModelFusionProgram(
|
|
|
185
185
|
report = taskpool.evaluate(merged_model)
|
|
186
186
|
return report
|
|
187
187
|
elif isinstance(merged_model, Dict):
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
188
|
+
report = {}
|
|
189
|
+
for key, item in merged_model.items():
|
|
190
|
+
if isinstance(item, nn.Module):
|
|
191
|
+
report[key] = taskpool.evaluate(item)
|
|
192
|
+
else:
|
|
193
|
+
# metadata
|
|
194
|
+
report[key] = item
|
|
192
195
|
return report
|
|
193
196
|
elif isinstance(merged_model, Iterable):
|
|
194
197
|
return [
|
|
@@ -11,10 +11,10 @@ import functools
|
|
|
11
11
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
|
12
12
|
|
|
13
13
|
import lightning as L
|
|
14
|
+
import numpy as np
|
|
14
15
|
import torch
|
|
15
16
|
from omegaconf import DictConfig
|
|
16
17
|
from torch.utils.data import Subset
|
|
17
|
-
import numpy as np
|
|
18
18
|
from tqdm.auto import tqdm
|
|
19
19
|
|
|
20
20
|
from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
|
|
@@ -58,11 +58,24 @@ class CLIPTemplateFactory:
|
|
|
58
58
|
"templates": "templates",
|
|
59
59
|
},
|
|
60
60
|
"nateraw/rendered-sst2": ".rendered_sst2",
|
|
61
|
+
"rendered-sst2": ".rendered_sst2",
|
|
61
62
|
"tanganke/stl10": ".stl10",
|
|
63
|
+
"stl10": ".stl10",
|
|
62
64
|
"dpdl-benchmark/oxford_flowers102": ".flower102",
|
|
65
|
+
"oxford_flowers102": ".flower102",
|
|
63
66
|
"timm/oxford-iiit-pet": ".oxford_iiit_pet",
|
|
67
|
+
"oxford-iiit-pet": ".oxford_iiit_pet",
|
|
64
68
|
"imagenet": ".imagenet",
|
|
65
69
|
"tiny-imagenet": ".tiny_imagenet",
|
|
70
|
+
"pcam": ".pcam",
|
|
71
|
+
"fer2013": ".fer2013",
|
|
72
|
+
"emnist_mnist": ".emnist_mnist",
|
|
73
|
+
"emnist_letters": ".emnist_letters",
|
|
74
|
+
"kmnist": ".kmnist",
|
|
75
|
+
"food101": ".food101",
|
|
76
|
+
"fashion_mnist": ".fashion_mnist",
|
|
77
|
+
"cub-200-2011": ".cub_200_2011",
|
|
78
|
+
"mango-leaf-disease": ".mango_leaf_disease",
|
|
66
79
|
}
|
|
67
80
|
|
|
68
81
|
@staticmethod
|
|
@@ -168,48 +181,3 @@ class CLIPTemplateFactory:
|
|
|
168
181
|
|
|
169
182
|
def get_classnames_and_templates(dataset_name: str):
|
|
170
183
|
return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
def _load_hf_dataset(dataset_name: str):
|
|
174
|
-
"""
|
|
175
|
-
Load a dataset from the Hugging Face datasets library based on the specified dataset name.
|
|
176
|
-
|
|
177
|
-
This function handles specific preprocessing steps for certain datasets to ensure consistency in dataset format.
|
|
178
|
-
For example, it renames columns, removes unnecessary columns, and specifies subsets for certain datasets.
|
|
179
|
-
|
|
180
|
-
Expected dataset format:
|
|
181
|
-
- The dataset should have an "image" column containing the image data.
|
|
182
|
-
- The dataset should have a "label" column containing the class labels.
|
|
183
|
-
|
|
184
|
-
Args:
|
|
185
|
-
dataset_name (str): The name of the dataset to load. Can be one of "svhn", "cifar10", "cifar100", "timm/oxford-iiit-pet", or any other dataset name supported by the Hugging Face datasets library. By default, the datasets have two columns: "image" and "label".
|
|
186
|
-
|
|
187
|
-
Returns:
|
|
188
|
-
A dataset object loaded from the Hugging Face datasets library, with any necessary preprocessing applied.
|
|
189
|
-
"""
|
|
190
|
-
if dataset_name == "svhn":
|
|
191
|
-
return load_dataset(dataset_name, "cropped_digits")
|
|
192
|
-
elif dataset_name == "cifar10":
|
|
193
|
-
dataset = load_dataset(dataset_name)
|
|
194
|
-
dataset = dataset.rename_columns({"img": "image"})
|
|
195
|
-
return dataset
|
|
196
|
-
elif dataset_name == "cifar100":
|
|
197
|
-
dataset = load_dataset(dataset_name)
|
|
198
|
-
dataset = dataset.remove_columns(["coarse_label"]).rename_columns(
|
|
199
|
-
{"img": "image", "fine_label": "label"}
|
|
200
|
-
)
|
|
201
|
-
return dataset
|
|
202
|
-
elif dataset_name == "timm/oxford-iiit-pet":
|
|
203
|
-
dataset = load_dataset(dataset_name)
|
|
204
|
-
dataset = dataset.remove_columns(["image_id", "label_cat_dog"])
|
|
205
|
-
return dataset
|
|
206
|
-
else:
|
|
207
|
-
return load_dataset(dataset_name)
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
def load_clip_dataset(dataset: str, processor):
|
|
211
|
-
hf_dataset = _load_hf_dataset(dataset)
|
|
212
|
-
return (
|
|
213
|
-
CLIPDataset(hf_dataset["train"], processor),
|
|
214
|
-
CLIPDataset(hf_dataset["test"], processor),
|
|
215
|
-
)
|
|
@@ -1,16 +1 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
class CLIPDataset(torch.utils.data.Dataset):
|
|
5
|
-
def __init__(self, dataset, processor):
|
|
6
|
-
self.dataset = dataset
|
|
7
|
-
self.processor = processor
|
|
8
|
-
|
|
9
|
-
def __len__(self):
|
|
10
|
-
return len(self.dataset)
|
|
11
|
-
|
|
12
|
-
def __getitem__(self, idx):
|
|
13
|
-
item = self.dataset[idx]
|
|
14
|
-
image = item["image"]
|
|
15
|
-
inputs = self.processor(images=[image], return_tensors="pt")["pixel_values"][0]
|
|
16
|
-
return inputs, item["label"]
|
|
1
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
classname_mapping = {
|
|
2
|
+
"0": "Black_footed_Albatross",
|
|
3
|
+
"1": "Laysan_Albatross",
|
|
4
|
+
"2": "Sooty_Albatross",
|
|
5
|
+
"3": "Groove_billed_Ani",
|
|
6
|
+
"4": "Crested_Auklet",
|
|
7
|
+
"5": "Least_Auklet",
|
|
8
|
+
"6": "Parakeet_Auklet",
|
|
9
|
+
"7": "Rhinoceros_Auklet",
|
|
10
|
+
"8": "Brewer_Blackbird",
|
|
11
|
+
"9": "Red_winged_Blackbird",
|
|
12
|
+
"10": "Rusty_Blackbird",
|
|
13
|
+
"11": "Yellow_headed_Blackbird",
|
|
14
|
+
"12": "Bobolink",
|
|
15
|
+
"13": "Indigo_Bunting",
|
|
16
|
+
"14": "Lazuli_Bunting",
|
|
17
|
+
"15": "Painted_Bunting",
|
|
18
|
+
"16": "Cardinal",
|
|
19
|
+
"17": "Spotted_Catbird",
|
|
20
|
+
"18": "Gray_Catbird",
|
|
21
|
+
"19": "Yellow_breasted_Chat",
|
|
22
|
+
"20": "Eastern_Towhee",
|
|
23
|
+
"21": "Chuck_will_Widow",
|
|
24
|
+
"22": "Brandt_Cormorant",
|
|
25
|
+
"23": "Red_faced_Cormorant",
|
|
26
|
+
"24": "Pelagic_Cormorant",
|
|
27
|
+
"25": "Bronzed_Cowbird",
|
|
28
|
+
"26": "Shiny_Cowbird",
|
|
29
|
+
"27": "Brown_Creeper",
|
|
30
|
+
"28": "American_Crow",
|
|
31
|
+
"29": "Fish_Crow",
|
|
32
|
+
"30": "Black_billed_Cuckoo",
|
|
33
|
+
"31": "Mangrove_Cuckoo",
|
|
34
|
+
"32": "Yellow_billed_Cuckoo",
|
|
35
|
+
"33": "Gray_crowned_Rosy_Finch",
|
|
36
|
+
"34": "Purple_Finch",
|
|
37
|
+
"35": "Northern_Flicker",
|
|
38
|
+
"36": "Acadian_Flycatcher",
|
|
39
|
+
"37": "Great_Crested_Flycatcher",
|
|
40
|
+
"38": "Least_Flycatcher",
|
|
41
|
+
"39": "Olive_sided_Flycatcher",
|
|
42
|
+
"40": "Scissor_tailed_Flycatcher",
|
|
43
|
+
"41": "Vermilion_Flycatcher",
|
|
44
|
+
"42": "Yellow_bellied_Flycatcher",
|
|
45
|
+
"43": "Frigatebird",
|
|
46
|
+
"44": "Northern_Fulmar",
|
|
47
|
+
"45": "Gadwall",
|
|
48
|
+
"46": "American_Goldfinch",
|
|
49
|
+
"47": "European_Goldfinch",
|
|
50
|
+
"48": "Boat_tailed_Grackle",
|
|
51
|
+
"49": "Eared_Grebe",
|
|
52
|
+
"50": "Horned_Grebe",
|
|
53
|
+
"51": "Pied_billed_Grebe",
|
|
54
|
+
"52": "Western_Grebe",
|
|
55
|
+
"53": "Blue_Grosbeak",
|
|
56
|
+
"54": "Evening_Grosbeak",
|
|
57
|
+
"55": "Pine_Grosbeak",
|
|
58
|
+
"56": "Rose_breasted_Grosbeak",
|
|
59
|
+
"57": "Pigeon_Guillemot",
|
|
60
|
+
"58": "California_Gull",
|
|
61
|
+
"59": "Glaucous_winged_Gull",
|
|
62
|
+
"60": "Heermann_Gull",
|
|
63
|
+
"61": "Herring_Gull",
|
|
64
|
+
"62": "Ivory_Gull",
|
|
65
|
+
"63": "Ring_billed_Gull",
|
|
66
|
+
"64": "Slaty_backed_Gull",
|
|
67
|
+
"65": "Western_Gull",
|
|
68
|
+
"66": "Anna_Hummingbird",
|
|
69
|
+
"67": "Ruby_throated_Hummingbird",
|
|
70
|
+
"68": "Rufous_Hummingbird",
|
|
71
|
+
"69": "Green_Violetear",
|
|
72
|
+
"70": "Long_tailed_Jaeger",
|
|
73
|
+
"71": "Pomarine_Jaeger",
|
|
74
|
+
"72": "Blue_Jay",
|
|
75
|
+
"73": "Florida_Jay",
|
|
76
|
+
"74": "Green_Jay",
|
|
77
|
+
"75": "Dark_eyed_Junco",
|
|
78
|
+
"76": "Tropical_Kingbird",
|
|
79
|
+
"77": "Gray_Kingbird",
|
|
80
|
+
"78": "Belted_Kingfisher",
|
|
81
|
+
"79": "Green_Kingfisher",
|
|
82
|
+
"80": "Pied_Kingfisher",
|
|
83
|
+
"81": "Ringed_Kingfisher",
|
|
84
|
+
"82": "White_breasted_Kingfisher",
|
|
85
|
+
"83": "Red_legged_Kittiwake",
|
|
86
|
+
"84": "Horned_Lark",
|
|
87
|
+
"85": "Pacific_Loon",
|
|
88
|
+
"86": "Mallard",
|
|
89
|
+
"87": "Western_Meadowlark",
|
|
90
|
+
"88": "Hooded_Merganser",
|
|
91
|
+
"89": "Red_breasted_Merganser",
|
|
92
|
+
"90": "Mockingbird",
|
|
93
|
+
"91": "Nighthawk",
|
|
94
|
+
"92": "Clark_Nutcracker",
|
|
95
|
+
"93": "White_breasted_Nuthatch",
|
|
96
|
+
"94": "Baltimore_Oriole",
|
|
97
|
+
"95": "Hooded_Oriole",
|
|
98
|
+
"96": "Orchard_Oriole",
|
|
99
|
+
"97": "Scott_Oriole",
|
|
100
|
+
"98": "Ovenbird",
|
|
101
|
+
"99": "Brown_Pelican",
|
|
102
|
+
"100": "White_Pelican",
|
|
103
|
+
"101": "Western_Wood_Pewee",
|
|
104
|
+
"102": "Sayornis",
|
|
105
|
+
"103": "American_Pipit",
|
|
106
|
+
"104": "Whip_poor_Will",
|
|
107
|
+
"105": "Horned_Puffin",
|
|
108
|
+
"106": "Common_Raven",
|
|
109
|
+
"107": "White_necked_Raven",
|
|
110
|
+
"108": "American_Redstart",
|
|
111
|
+
"109": "Geococcyx",
|
|
112
|
+
"110": "Loggerhead_Shrike",
|
|
113
|
+
"111": "Great_Grey_Shrike",
|
|
114
|
+
"112": "Baird_Sparrow",
|
|
115
|
+
"113": "Black_throated_Sparrow",
|
|
116
|
+
"114": "Brewer_Sparrow",
|
|
117
|
+
"115": "Chipping_Sparrow",
|
|
118
|
+
"116": "Clay_colored_Sparrow",
|
|
119
|
+
"117": "House_Sparrow",
|
|
120
|
+
"118": "Field_Sparrow",
|
|
121
|
+
"119": "Fox_Sparrow",
|
|
122
|
+
"120": "Grasshopper_Sparrow",
|
|
123
|
+
"121": "Harris_Sparrow",
|
|
124
|
+
"122": "Henslow_Sparrow",
|
|
125
|
+
"123": "Le_Conte_Sparrow",
|
|
126
|
+
"124": "Lincoln_Sparrow",
|
|
127
|
+
"125": "Nelson_Sharp_tailed_Sparrow",
|
|
128
|
+
"126": "Savannah_Sparrow",
|
|
129
|
+
"127": "Seaside_Sparrow",
|
|
130
|
+
"128": "Song_Sparrow",
|
|
131
|
+
"129": "Tree_Sparrow",
|
|
132
|
+
"130": "Vesper_Sparrow",
|
|
133
|
+
"131": "White_crowned_Sparrow",
|
|
134
|
+
"132": "White_throated_Sparrow",
|
|
135
|
+
"133": "Cape_Glossy_Starling",
|
|
136
|
+
"134": "Bank_Swallow",
|
|
137
|
+
"135": "Barn_Swallow",
|
|
138
|
+
"136": "Cliff_Swallow",
|
|
139
|
+
"137": "Tree_Swallow",
|
|
140
|
+
"138": "Scarlet_Tanager",
|
|
141
|
+
"139": "Summer_Tanager",
|
|
142
|
+
"140": "Artic_Tern",
|
|
143
|
+
"141": "Black_Tern",
|
|
144
|
+
"142": "Caspian_Tern",
|
|
145
|
+
"143": "Common_Tern",
|
|
146
|
+
"144": "Elegant_Tern",
|
|
147
|
+
"145": "Forsters_Tern",
|
|
148
|
+
"146": "Least_Tern",
|
|
149
|
+
"147": "Green_tailed_Towhee",
|
|
150
|
+
"148": "Brown_Thrasher",
|
|
151
|
+
"149": "Sage_Thrasher",
|
|
152
|
+
"150": "Black_capped_Vireo",
|
|
153
|
+
"151": "Blue_headed_Vireo",
|
|
154
|
+
"152": "Philadelphia_Vireo",
|
|
155
|
+
"153": "Red_eyed_Vireo",
|
|
156
|
+
"154": "Warbling_Vireo",
|
|
157
|
+
"155": "White_eyed_Vireo",
|
|
158
|
+
"156": "Yellow_throated_Vireo",
|
|
159
|
+
"157": "Bay_breasted_Warbler",
|
|
160
|
+
"158": "Black_and_white_Warbler",
|
|
161
|
+
"159": "Black_throated_Blue_Warbler",
|
|
162
|
+
"160": "Blue_winged_Warbler",
|
|
163
|
+
"161": "Canada_Warbler",
|
|
164
|
+
"162": "Cape_May_Warbler",
|
|
165
|
+
"163": "Cerulean_Warbler",
|
|
166
|
+
"164": "Chestnut_sided_Warbler",
|
|
167
|
+
"165": "Golden_winged_Warbler",
|
|
168
|
+
"166": "Hooded_Warbler",
|
|
169
|
+
"167": "Kentucky_Warbler",
|
|
170
|
+
"168": "Magnolia_Warbler",
|
|
171
|
+
"169": "Mourning_Warbler",
|
|
172
|
+
"170": "Myrtle_Warbler",
|
|
173
|
+
"171": "Nashville_Warbler",
|
|
174
|
+
"172": "Orange_crowned_Warbler",
|
|
175
|
+
"173": "Palm_Warbler",
|
|
176
|
+
"174": "Pine_Warbler",
|
|
177
|
+
"175": "Prairie_Warbler",
|
|
178
|
+
"176": "Prothonotary_Warbler",
|
|
179
|
+
"177": "Swainson_Warbler",
|
|
180
|
+
"178": "Tennessee_Warbler",
|
|
181
|
+
"179": "Wilson_Warbler",
|
|
182
|
+
"180": "Worm_eating_Warbler",
|
|
183
|
+
"181": "Yellow_Warbler",
|
|
184
|
+
"182": "Northern_Waterthrush",
|
|
185
|
+
"183": "Louisiana_Waterthrush",
|
|
186
|
+
"184": "Bohemian_Waxwing",
|
|
187
|
+
"185": "Cedar_Waxwing",
|
|
188
|
+
"186": "American_Three_toed_Woodpecker",
|
|
189
|
+
"187": "Pileated_Woodpecker",
|
|
190
|
+
"188": "Red_bellied_Woodpecker",
|
|
191
|
+
"189": "Red_cockaded_Woodpecker",
|
|
192
|
+
"190": "Red_headed_Woodpecker",
|
|
193
|
+
"191": "Downy_Woodpecker",
|
|
194
|
+
"192": "Bewick_Wren",
|
|
195
|
+
"193": "Cactus_Wren",
|
|
196
|
+
"194": "Carolina_Wren",
|
|
197
|
+
"195": "House_Wren",
|
|
198
|
+
"196": "Marsh_Wren",
|
|
199
|
+
"197": "Rock_Wren",
|
|
200
|
+
"198": "Winter_Wren",
|
|
201
|
+
"199": "Common_Yellowthroat",
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
classnames = [classname_mapping[str(i)] for i in range(200)]
|
|
205
|
+
templates = [
|
|
206
|
+
lambda c: f"a photo of a {c}.",
|
|
207
|
+
lambda c: f"a photo of the {c}.",
|
|
208
|
+
]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
classnames_mapping = {
|
|
2
|
+
"0": "A",
|
|
3
|
+
"1": "B",
|
|
4
|
+
"2": "C",
|
|
5
|
+
"3": "D",
|
|
6
|
+
"4": "E",
|
|
7
|
+
"5": "F",
|
|
8
|
+
"6": "G",
|
|
9
|
+
"7": "H",
|
|
10
|
+
"8": "I",
|
|
11
|
+
"9": "J",
|
|
12
|
+
"10": "K",
|
|
13
|
+
"11": "L",
|
|
14
|
+
"12": "M",
|
|
15
|
+
"13": "N",
|
|
16
|
+
"14": "O",
|
|
17
|
+
"15": "P",
|
|
18
|
+
"16": "Q",
|
|
19
|
+
"17": "R",
|
|
20
|
+
"18": "S",
|
|
21
|
+
"19": "T",
|
|
22
|
+
"20": "U",
|
|
23
|
+
"21": "V",
|
|
24
|
+
"22": "W",
|
|
25
|
+
"23": "X",
|
|
26
|
+
"24": "Y",
|
|
27
|
+
"25": "Z",
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
classnames = [classnames_mapping[str(i)] for i in range(26)]
|
|
31
|
+
templates = [lambda c: f'a photo of the digit character: "{c}".']
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
classname_mapping = {
|
|
2
|
+
"0": "T - shirt / top",
|
|
3
|
+
"1": "Trouser",
|
|
4
|
+
"2": "Pullover",
|
|
5
|
+
"3": "Dress",
|
|
6
|
+
"4": "Coat",
|
|
7
|
+
"5": "Sandal",
|
|
8
|
+
"6": "Shirt",
|
|
9
|
+
"7": "Sneaker",
|
|
10
|
+
"8": "Bag",
|
|
11
|
+
"9": "Ankle boot",
|
|
12
|
+
}
|
|
13
|
+
classnames = [classname_mapping[str(i)] for i in range(10)]
|
|
14
|
+
|
|
15
|
+
templates = [
|
|
16
|
+
lambda c: f"a photo of a {c}.",
|
|
17
|
+
lambda c: f"a photo of the {c}.",
|
|
18
|
+
]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
classnames = [
|
|
2
|
+
"angry",
|
|
3
|
+
"disgusted",
|
|
4
|
+
"fearful",
|
|
5
|
+
"happy",
|
|
6
|
+
"neutral",
|
|
7
|
+
"sad",
|
|
8
|
+
"surprised",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
templates = [
|
|
12
|
+
lambda c: f"a photo of a {c} looking face.",
|
|
13
|
+
lambda c: f"a photo of a face showing the emotion: {c}.",
|
|
14
|
+
lambda c: f"a photo of a face looking {c}.",
|
|
15
|
+
lambda c: f"a face that looks {c}.",
|
|
16
|
+
lambda c: f"they look {c}.",
|
|
17
|
+
lambda c: f"look at how {c} they are.",
|
|
18
|
+
]
|