birder-clip 0.0.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder_clip/__init__.py +7 -0
- birder_clip/common/__init__.py +0 -0
- birder_clip/common/fs_ops.py +111 -0
- birder_clip/common/lib.py +65 -0
- birder_clip/conf/__init__.py +0 -0
- birder_clip/conf/settings.py +27 -0
- birder_clip/inference/__init__.py +0 -0
- birder_clip/inference/zero_shot.py +50 -0
- birder_clip/inference/zero_shot_templates.py +229 -0
- birder_clip/model_registry/__init__.py +9 -0
- birder_clip/model_registry/model_registry.py +133 -0
- birder_clip/net/__init__.py +5 -0
- birder_clip/net/base.py +44 -0
- birder_clip/net/clip.py +91 -0
- birder_clip/net/text/__init__.py +7 -0
- birder_clip/net/text/base.py +29 -0
- birder_clip/net/text/transformer.py +325 -0
- birder_clip/py.typed +0 -0
- birder_clip/scripts/__init__.py +0 -0
- birder_clip/scripts/zero_shot.py +281 -0
- birder_clip/tokenizers/__init__.py +11 -0
- birder_clip/tokenizers/base.py +8 -0
- birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- birder_clip/tokenizers/openai_clip_bpe.py +225 -0
- birder_clip/tokenizers/registry.py +29 -0
- birder_clip/version.py +1 -0
- birder_clip-0.0.2.dev0.dist-info/METADATA +70 -0
- birder_clip-0.0.2.dev0.dist-info/RECORD +31 -0
- birder_clip-0.0.2.dev0.dist-info/WHEEL +5 -0
- birder_clip-0.0.2.dev0.dist-info/licenses/LICENSE +201 -0
- birder_clip-0.0.2.dev0.dist-info/top_level.txt +1 -0
birder_clip/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
4
|
+
from typing import NamedTuple
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from birder.conf import settings
|
|
9
|
+
from birder.data.transforms.classification import RGBType
|
|
10
|
+
|
|
11
|
+
from birder_clip.common import lib
|
|
12
|
+
from birder_clip.model_registry import registry
|
|
13
|
+
from birder_clip.net.base import BaseNet
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelInfo(NamedTuple):
    """Metadata returned next to a loaded network (as populated by load_model)."""

    # Model signature; load_model takes it from the checkpoint's "signature" entry
    signature: dict[str, Any]
    # RGB normalization stats; load_model takes them from the checkpoint's "rgb_stats" entry
    rgb_stats: RGBType
    # Config stored in the checkpoint; load_model sets None when that config is missing or empty
    custom_config: Optional[dict[str, Any]] = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def model_path(network_name: str, *, epoch: Optional[int | str] = None) -> Path:
    """
    Return the checkpoint path of a network inside settings.MODELS_DIR.

    The file is "<network_name>.pt", or "<network_name>_<epoch>.pt" when an epoch is given
    (an epoch of 0 still produces the "_0" suffix).
    """

    suffix = "" if epoch is None else f"_{epoch}"
    return settings.MODELS_DIR.joinpath(f"{network_name}{suffix}.pt")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def load_model(
    device: torch.device,
    network: str,
    *,
    path: Optional[str | Path] = None,
    config: Optional[dict[str, Any]] = None,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
    image_encoder_config: Optional[dict[str, Any]] = None,
    text_encoder_config: Optional[dict[str, Any]] = None,
    epoch: Optional[int] = None,
    new_size: Optional[tuple[int, int]] = None,
    inference: bool,
    dtype: Optional[torch.dtype] = None,
) -> tuple[BaseNet, ModelInfo]:
    """
    Load an image-text network checkpoint and build the network from it.

    Parameters
    ----------
    device
        Passed to torch.load as map_location; the network is also moved to it.
    network
        Registered image-text network name, instantiated through registry.net_factory().
    path
        Explicit checkpoint path. When None, the file name is derived from network, image_encoder,
        text_encoder, tokenizer, embed_dim, tag and epoch (see get_image_text_network_name / model_path).
    config
        Config entries applied last, overriding every other source.
    tag, image_encoder, text_encoder, embed_dim, tokenizer
        Used for the derived file name when path is None. When image_encoder, text_encoder, embed_dim
        and a tokenizer name (given, or read from the checkpoint config) are all available, they also
        rebuild the "image", "text", "embed_dim" and "tokenizer" config entries.
    image_encoder_config, text_encoder_config
        Optional nested configs placed into the rebuilt "image" / "text" entries.
    epoch
        Epoch suffix for the derived file name.
    new_size
        If given, net.adjust_image_size(new_size) is called after the weights are loaded.
    inference
        If True, every parameter gets requires_grad=False and the network is put in eval mode.
    dtype
        If given, the network is cast to this dtype after being moved to the device.

    Returns
    -------
    The network and a ModelInfo holding the checkpoint's "signature", "rgb_stats" and stored config.
    """

    if path is None:
        _network_name = lib.get_image_text_network_name(
            network,
            tag=tag,
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            tokenizer=tokenizer,
        )
        path = model_path(_network_name, epoch=epoch)

    logger.info(f"Loading model from {path} on device {device}...")
    model_dict: dict[str, Any] = torch.load(path, map_location=device, weights_only=True)

    # Config precedence (lowest to highest): checkpoint config, tokenizer argument,
    # rebuilt encoder config, explicit `config` argument.
    loaded_config: dict[str, Any] = model_dict.get("config", {})
    merged_config = {**loaded_config}
    if tokenizer is not None:
        tokenizer_name: Optional[str] = tokenizer
        merged_config["tokenizer"] = tokenizer
    else:
        tokenizer_name = loaded_config.get("tokenizer")

    if image_encoder is not None and text_encoder is not None and embed_dim is not None and tokenizer_name is not None:
        # Keep the checkpoint's image size; note that the whole "image"/"text" entries are replaced,
        # so nested checkpoint configs survive only if passed again via *_encoder_config
        image_size = loaded_config.get("image", {}).get("size")
        merged_config.update(
            lib.get_image_text_network_config(
                image_encoder,
                text_encoder,
                embed_dim,
                tokenizer_name,
                image_size=image_size,
                image_config=image_encoder_config,
                text_config=text_encoder_config,
            )
        )
    if config is not None:
        merged_config.update(config)
    if len(merged_config) == 0:
        # Nothing to override: let the network fall back to its registered config
        merged_config = None  # type: ignore[assignment]

    net = registry.net_factory(network, config=merged_config)
    net.load_state_dict(model_dict["state"])
    if new_size is not None:
        net.adjust_image_size(new_size)

    net.to(device)
    if dtype is not None:
        net.to(dtype)
    if inference is True:
        for param in net.parameters():
            param.requires_grad_(False)

        net.eval()

    # NOTE(review): custom_config is the checkpoint config only; overrides from config / tokenizer /
    # encoder arguments are not reflected in the returned ModelInfo - confirm this is intended.
    if len(loaded_config) == 0:
        custom_config = None
    else:
        custom_config = loaded_config
        logger.debug(f"Model loaded with custom config: {custom_config}")

    return (net, ModelInfo(model_dict["signature"], model_dict["rgb_stats"], custom_config))
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get_size_from_signature(signature: dict[str, Any]) -> tuple[int, int]:
    """
    Return elements 2 and 3 of the first input's ``data_shape``.

    These are the spatial dimensions, assuming an NCHW-style data_shape.
    """

    data_shape = signature["inputs"][0]["data_shape"]
    spatial_dims = data_shape[2:4]
    return tuple(spatial_dims)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_channels_from_signature(signature: dict[str, Any]) -> int:
    """Return element 1 of the first input's ``data_shape`` (the channel count for NCHW-style shapes)."""

    first_input = signature["inputs"][0]
    return first_input["data_shape"][1]  # type: ignore[no-any-return]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_image_text_network_name(
    network: str,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
) -> str:
    """
    Build a composite network name from its components, joined by underscores.

    Component order: network, image_encoder, text_encoder, tokenizer, "d<embed_dim>", tag.
    Components that are None are omitted.
    """

    optional_parts = (
        image_encoder,
        text_encoder,
        tokenizer,
        None if embed_dim is None else f"d{embed_dim}",
    )
    base_name = "_".join([network, *(part for part in optional_parts if part is not None)])
    if tag is None:
        return base_name

    return f"{base_name}_{tag}"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_image_text_network_config(
    image_encoder: str,
    text_encoder: str,
    embed_dim: int,
    tokenizer: Optional[str],
    *,
    image_size: Optional[tuple[int, int]] = None,
    image_config: Optional[dict[str, Any]] = None,
    text_config: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """
    Assemble the nested config dict describing an image-text network.

    The result has the keys "image", "text", "embed_dim" and "tokenizer".
    - "image" holds {"network", optional "size", optional "config"}.
    - "text" holds {"network", optional "config"}.

    Optional entries are only present when the corresponding argument is not None.
    """

    optional_image_entries = (("size", image_size), ("config", image_config))
    image_section: dict[str, Any] = {"network": image_encoder}
    image_section.update({key: value for key, value in optional_image_entries if value is not None})

    text_section: dict[str, Any] = {"network": text_encoder}
    if text_config is not None:
        text_section["config"] = text_config

    return {"image": image_section, "text": text_section, "embed_dim": embed_dim, "tokenizer": tokenizer}
|
|
File without changes
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import logging.config
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
# Logging
# https://docs.python.org/3/library/logging.config.html
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()

# dictConfig schema version 1; loggers that already exist are left enabled
LOGGING: dict[str, Any] = {"version": 1, "disable_existing_loggers": False}

# "verbose" is used by the console handler; "simple" is not referenced by any handler here
LOGGING["formatters"] = {
    "verbose": {
        "format": "[{asctime}.{msecs:04.0f} {levelname} {filename}:{lineno:<4d}] {message}",
        "style": "{",
        "datefmt": "%d/%b/%Y %H:%M:%S",
    },
    "simple": {"format": "[{asctime} {levelname}] {message}", "style": "{"},
}

# The handler passes everything from DEBUG up; the effective level is set on the logger below
LOGGING["handlers"] = {
    "console": {"class": "logging.StreamHandler", "level": "DEBUG", "formatter": "verbose"},
}

# Package logger: level taken from the LOG_LEVEL environment variable, no propagation upwards
LOGGING["loggers"] = {
    "birder_clip": {"handlers": ["console"], "level": LOG_LEVEL, "propagate": False},
}

# Applied at import time
logging.config.dictConfig(LOGGING)
|
|
File without changes
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Zero-shot text embedding helpers
|
|
3
|
+
|
|
4
|
+
Zero-shot classification compares image features against one text feature per
|
|
5
|
+
candidate class. When multiple prompt templates are used, this module follows
|
|
6
|
+
the OpenCLIP/OpenAI CLIP convention: encode every class/template prompt,
|
|
7
|
+
normalize prompt embeddings, average them per class and normalize the averaged
|
|
8
|
+
class embedding again.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from collections.abc import Sequence
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
from birder_clip.net.base import BaseNet
|
|
17
|
+
from birder_clip.tokenizers.base import Tokenizer
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
    """
    Fill every template with every class name.

    Prompts are grouped by class: all templates for the first class, then all templates for the
    next one, so the result holds ``len(class_names) * len(templates)`` prompts.
    """

    prompts: list[str] = []
    for class_name in class_names:
        prompts.extend(template.format(class_name) for template in templates)

    return prompts
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_class_text_embeddings(
    model: BaseNet,
    tokenizer: Tokenizer,
    class_names: Sequence[str],
    templates: Sequence[str],
    *,
    device: torch.device,
    context_length: int | None = None,
    batch_size: int | None = None,
) -> torch.Tensor:
    """
    Build one L2-normalized text embedding per class.

    Every class is rendered with every template and encoded with
    ``model.encode_text(..., normalize=True)``. The per-template embeddings of each class are then
    averaged, and the average is L2-normalized again.

    Parameters
    ----------
    model
        Image-text network providing ``encode_text``.
    tokenizer
        Callable mapping a list of prompts to a token tensor; called with ``context_length``.
    class_names
        Class names, in output row order.
    templates
        Prompt templates, each filled with a class name via ``str.format``.
    device
        Device the token tensor is moved to before encoding.
    context_length
        Forwarded to the tokenizer.
    batch_size
        Number of classes encoded per forward pass. Each pass encodes
        ``batch_size * len(templates)`` prompts. Defaults to all classes in one pass.

    Returns
    -------
    Tensor with one row per class, in ``class_names`` order.

    Raises
    ------
    ValueError
        If class_names or templates is empty, or batch_size is not a positive integer.
    """

    # Validate up front: otherwise these cases fail deep inside range()/reshape()/concat
    if len(class_names) == 0:
        raise ValueError("class_names must contain at least one class name")
    if len(templates) == 0:
        raise ValueError("templates must contain at least one template")
    if batch_size is None:
        batch_size = len(class_names)
    elif batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_templates = len(templates)
    class_text_embeddings = []
    with torch.inference_mode():
        for start in range(0, len(class_names), batch_size):
            batch_class_names = class_names[start : start + batch_size]
            prompts = render_prompts(batch_class_names, templates)
            tokens = tokenizer(prompts, context_length=context_length).to(device)
            prompt_embeddings = model.encode_text(tokens, normalize=True)

            # render_prompts groups prompts per class, so consecutive runs of num_templates rows
            # belong to the same class
            class_embeddings = prompt_embeddings.reshape(len(batch_class_names), num_templates, -1).mean(dim=1)
            class_embeddings = F.normalize(class_embeddings, dim=-1)
            class_text_embeddings.append(class_embeddings)

    return torch.concat(class_text_embeddings, dim=0)
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Zero-shot classification prompt templates
|
|
3
|
+
|
|
4
|
+
OpenAI simple ImageNet templates adapted from OpenCLIP:
|
|
5
|
+
https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/zero_shot_metadata.py
|
|
6
|
+
|
|
7
|
+
OpenAI templates adapted from:
|
|
8
|
+
https://github.com/openai/CLIP/blob/main/data/prompts.md
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
# Reference license: MIT
|
|
12
|
+
|
|
13
|
+
# Every template contains a single "{}" placeholder that is filled with a class name (str.format)

# Single plain-photo prompt (the "default" entry of TEMPLATE_SETS)
DEFAULT_TEMPLATES = ("a photo of a {}.",)

# The class name itself is used as the prompt, since "{}".format(name) == name
IDENTITY_TEMPLATES = ("{}",)

# Small 7-prompt set; every entry also appears in OPENAI_IMAGENET_TEMPLATES
SIMPLE_IMAGENET_TEMPLATES = (
    "itap of a {}.",
    "a bad photo of the {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
)
|
|
26
|
+
|
|
27
|
+
# 80 ImageNet prompt templates adapted from OpenAI CLIP / OpenCLIP (see module docstring).
# NOTE: odd phrasings such as "a origami {}." are presumably kept verbatim to match the reference
# prompts - do not "fix" them without re-checking zero-shot results.
OPENAI_IMAGENET_TEMPLATES = (
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
)
|
|
109
|
+
|
|
110
|
+
# CIFAR-10 prompts: 9 phrasings with "a", followed by the same 9 with "the"
OPENAI_CIFAR10_TEMPLATES = (
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a black and white photo of a {}.",
    "a low contrast photo of a {}.",
    "a high contrast photo of a {}.",
    "a bad photo of a {}.",
    "a good photo of a {}.",
    "a photo of a small {}.",
    "a photo of a big {}.",
    "a photo of the {}.",
    "a blurry photo of the {}.",
    "a black and white photo of the {}.",
    "a low contrast photo of the {}.",
    "a high contrast photo of the {}.",
    "a bad photo of the {}.",
    "a good photo of the {}.",
    "a photo of the small {}.",
    "a photo of the big {}.",
)

# Caltech-101 prompts: 17 renditions (photo, painting, toy, ...) in an indefinite form,
# followed by the same 17 in a definite form
OPENAI_CALTECH101_TEMPLATES = (
    "a photo of a {}.",
    "a painting of a {}.",
    "a plastic {}.",
    "a sculpture of a {}.",
    "a sketch of a {}.",
    "a tattoo of a {}.",
    "a toy {}.",
    "a rendition of a {}.",
    "a embroidered {}.",
    "a cartoon {}.",
    "a {} in a video game.",
    "a plushie {}.",
    "a origami {}.",
    "art of a {}.",
    "graffiti of a {}.",
    "a drawing of a {}.",
    "a doodle of a {}.",
    "a photo of the {}.",
    "a painting of the {}.",
    "the plastic {}.",
    "a sculpture of the {}.",
    "a sketch of the {}.",
    "a tattoo of the {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "the embroidered {}.",
    "the cartoon {}.",
    "the {} in a video game.",
    "the plushie {}.",
    "the origami {}.",
    "art of the {}.",
    "graffiti of the {}.",
    "a drawing of the {}.",
    "a doodle of the {}.",
)

# Country211 prompts; the placeholder is filled with a country name
OPENAI_COUNTRY211_TEMPLATES = (
    "a photo i took in {}.",
    "a photo i took while visiting {}.",
    "a photo from my home country of {}.",
    "a photo from my visit to {}.",
    "a photo showing the country of {}.",
)

# EuroSAT prompts for satellite imagery
OPENAI_EUROSAT_TEMPLATES = (
    "a centered satellite photo of {}.",
    "a centered satellite photo of a {}.",
    "a centered satellite photo of the {}.",
)

# FGVC-Aircraft prompts; the "a type of aircraft" suffix adds domain context to the class name
OPENAI_FGVC_AIRCRAFT_TEMPLATES = (
    "a photo of a {}, a type of aircraft.",
    "a photo of the {}, a type of aircraft.",
)

# Kinetics action prompts: 7 phrasings for each of the 4 prefixes
# ("a photo of", "a video of", "a example of", "a demonstration of")
OPENAI_KINETICS_TEMPLATES = (
    "a photo of {}.",
    "a photo of a person {}.",
    "a photo of a person using {}.",
    "a photo of a person doing {}.",
    "a photo of a person during {}.",
    "a photo of a person performing {}.",
    "a photo of a person practicing {}.",
    "a video of {}.",
    "a video of a person {}.",
    "a video of a person using {}.",
    "a video of a person doing {}.",
    "a video of a person during {}.",
    "a video of a person performing {}.",
    "a video of a person practicing {}.",
    "a example of {}.",
    "a example of a person {}.",
    "a example of a person using {}.",
    "a example of a person doing {}.",
    "a example of a person during {}.",
    "a example of a person performing {}.",
    "a example of a person practicing {}.",
    "a demonstration of {}.",
    "a demonstration of a person {}.",
    "a demonstration of a person using {}.",
    "a demonstration of a person doing {}.",
    "a demonstration of a person during {}.",
    "a demonstration of a person performing {}.",
    "a demonstration of a person practicing {}.",
)
|
|
217
|
+
|
|
218
|
+
# Template sets selectable by name; every value is a tuple of "{}"-placeholder template strings
TEMPLATE_SETS: dict[str, tuple[str, ...]] = {
    "default": DEFAULT_TEMPLATES,
    "identity": IDENTITY_TEMPLATES,
    "simple_imagenet": SIMPLE_IMAGENET_TEMPLATES,
    "openai_imagenet": OPENAI_IMAGENET_TEMPLATES,
    "openai_cifar10": OPENAI_CIFAR10_TEMPLATES,
    "openai_caltech101": OPENAI_CALTECH101_TEMPLATES,
    "openai_country211": OPENAI_COUNTRY211_TEMPLATES,
    "openai_eurosat": OPENAI_EUROSAT_TEMPLATES,
    "openai_fgvc_aircraft": OPENAI_FGVC_AIRCRAFT_TEMPLATES,
    "openai_kinetics": OPENAI_KINETICS_TEMPLATES,
}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from birder_clip.model_registry.model_registry import Task
|
|
2
|
+
from birder_clip.model_registry.model_registry import list_pretrained_models
|
|
3
|
+
from birder_clip.model_registry.model_registry import registry
|
|
4
|
+
|
|
5
|
+
# Public names re-exported from birder_clip.model_registry.model_registry
__all__ = [
    "Task",
    "list_pretrained_models",
    "registry",
]
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
import fnmatch
|
|
2
|
+
import warnings
|
|
3
|
+
from enum import StrEnum
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
from typing import Any
|
|
6
|
+
from typing import Literal
|
|
7
|
+
from typing import Optional
|
|
8
|
+
from typing import TypeAlias
|
|
9
|
+
|
|
10
|
+
from birder.model_registry.model_registry import group_sort
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING is True:
    # Imported for type checking only: birder_clip.net.base imports this registry package,
    # so importing the network base classes at runtime would be cyclic
    from birder_clip.net.base import BaseNet  # pylint: disable=cyclic-import
    from birder_clip.net.text.base import TextBaseNet  # pylint: disable=cyclic-import

    # Any registrable network class. It exists only for type checkers, which is why runtime
    # annotations in this module refer to it as the string "NetType".
    NetType: TypeAlias = type[BaseNet] | type[TextBaseNet]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Task(StrEnum):
    """Kind of network held by the registry; each kind is stored in its own table."""

    # Image-text networks (see birder_clip.net.base.BaseNet.task)
    IMAGE_TEXT = "image_text"
    # Text-only networks
    TEXT = "text"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ModelRegistry:
|
|
25
|
+
def __init__(self) -> None:
|
|
26
|
+
self.registered_configs: dict[str, "NetType"] = {}
|
|
27
|
+
self._image_text_nets: dict[str, type["BaseNet"]] = {}
|
|
28
|
+
self._text_nets: dict[str, type[TextBaseNet]] = {}
|
|
29
|
+
self._pretrained_nets: dict[str, dict[str, Any]] = {}
|
|
30
|
+
|
|
31
|
+
@property
|
|
32
|
+
def all_nets(self) -> dict[str, "NetType"]:
|
|
33
|
+
return {**self._image_text_nets, **self._text_nets}
|
|
34
|
+
|
|
35
|
+
def _get_models_for_task(self, task: Task) -> dict[str, "NetType"]:
|
|
36
|
+
if task == Task.IMAGE_TEXT:
|
|
37
|
+
nets: dict[str, "NetType"] = self._image_text_nets
|
|
38
|
+
elif task == Task.TEXT:
|
|
39
|
+
nets = self._text_nets
|
|
40
|
+
else:
|
|
41
|
+
raise ValueError(f"Unsupported model task: {task}")
|
|
42
|
+
|
|
43
|
+
return nets
|
|
44
|
+
|
|
45
|
+
def _register_model(self, name: str, net_type: "NetType") -> None:
|
|
46
|
+
name_key = name.lower()
|
|
47
|
+
task = Task(net_type.task)
|
|
48
|
+
nets = self._get_models_for_task(task)
|
|
49
|
+
if name_key in self.all_nets and name_key not in nets:
|
|
50
|
+
raise ValueError(f"Registered model name '{name}' collides with an existing registered model name")
|
|
51
|
+
if name_key in nets:
|
|
52
|
+
warnings.warn(f"Network '{name}' is already registered and will be overwritten", UserWarning)
|
|
53
|
+
|
|
54
|
+
nets[name_key] = net_type
|
|
55
|
+
|
|
56
|
+
def register_model_config(self, name: str, net_type: "NetType", *, config: Optional[dict[str, Any]] = None) -> None:
|
|
57
|
+
name_key = name.lower()
|
|
58
|
+
registered_net_type = type(name, (net_type,), {"config": config})
|
|
59
|
+
self._register_model(name_key, registered_net_type)
|
|
60
|
+
if name_key in self.registered_configs:
|
|
61
|
+
warnings.warn(f"Registered config '{name}' is already registered and will be overwritten", UserWarning)
|
|
62
|
+
|
|
63
|
+
self.registered_configs[name_key] = registered_net_type
|
|
64
|
+
|
|
65
|
+
def register_weights(self, name: str, weights_info: dict[str, Any]) -> None:
|
|
66
|
+
if name in self._pretrained_nets:
|
|
67
|
+
warnings.warn(f"Weights '{name}' are already registered and will be overwritten", UserWarning)
|
|
68
|
+
|
|
69
|
+
self._pretrained_nets[name] = weights_info
|
|
70
|
+
|
|
71
|
+
def list_models(
|
|
72
|
+
self,
|
|
73
|
+
include_filter: Optional[str] = None,
|
|
74
|
+
*,
|
|
75
|
+
task: Optional[Task] = None,
|
|
76
|
+
net_type: Optional[type | tuple[type, ...]] = None,
|
|
77
|
+
net_type_op: Literal["AND", "OR"] = "AND",
|
|
78
|
+
) -> list[str]:
|
|
79
|
+
nets = self.all_nets
|
|
80
|
+
if task is not None:
|
|
81
|
+
nets = self._get_models_for_task(task)
|
|
82
|
+
|
|
83
|
+
if net_type is not None:
|
|
84
|
+
if not isinstance(net_type, tuple):
|
|
85
|
+
net_type = (net_type,)
|
|
86
|
+
|
|
87
|
+
if net_type_op == "OR":
|
|
88
|
+
nets = {name: t for name, t in nets.items() if issubclass(t, net_type) is True}
|
|
89
|
+
elif net_type_op == "AND":
|
|
90
|
+
nets = {name: t for name, t in nets.items() if all(issubclass(t, nt) for nt in net_type)}
|
|
91
|
+
else:
|
|
92
|
+
raise ValueError(f"Unknown op {net_type_op}")
|
|
93
|
+
|
|
94
|
+
model_list = list(nets.keys())
|
|
95
|
+
if include_filter is not None:
|
|
96
|
+
model_list = fnmatch.filter(model_list, include_filter)
|
|
97
|
+
|
|
98
|
+
return group_sort(model_list)
|
|
99
|
+
|
|
100
|
+
def list_pretrained_models(self, include_filter: Optional[str] = None) -> list[str]:
|
|
101
|
+
model_list = list(self._pretrained_nets.keys())
|
|
102
|
+
if include_filter is not None:
|
|
103
|
+
model_list = fnmatch.filter(model_list, include_filter)
|
|
104
|
+
|
|
105
|
+
return group_sort(model_list)
|
|
106
|
+
|
|
107
|
+
def exists(self, name: str, task: Optional[Task] = None, net_type: Optional[type] = None) -> bool:
|
|
108
|
+
nets = self.all_nets
|
|
109
|
+
if task is not None:
|
|
110
|
+
nets = self._get_models_for_task(task)
|
|
111
|
+
|
|
112
|
+
if net_type is not None:
|
|
113
|
+
nets = {name: t for name, t in nets.items() if issubclass(t, net_type) is True}
|
|
114
|
+
|
|
115
|
+
return name.lower() in nets
|
|
116
|
+
|
|
117
|
+
def pretrained_exists(self, name: str) -> bool:
|
|
118
|
+
return name in self._pretrained_nets
|
|
119
|
+
|
|
120
|
+
def get_pretrained_metadata(self, name: str) -> dict[str, Any]:
|
|
121
|
+
return self._pretrained_nets[name]
|
|
122
|
+
|
|
123
|
+
def text_factory(self, name: str, *, config: Optional[dict[str, Any]] = None) -> "TextBaseNet":
|
|
124
|
+
name_key = name.lower()
|
|
125
|
+
return self._text_nets[name_key](config=config)
|
|
126
|
+
|
|
127
|
+
def net_factory(self, name: str, *, config: Optional[dict[str, Any]] = None) -> "BaseNet":
|
|
128
|
+
name_key = name.lower()
|
|
129
|
+
return self._image_text_nets[name_key](config=config)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Module-level registry instance; it is imported by other modules (e.g. fs_ops) and re-exported by the package
registry = ModelRegistry()
# Convenience alias for the singleton's bound method
list_pretrained_models = registry.list_pretrained_models
|
birder_clip/net/base.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from typing import Any
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from birder.net.base import BaseNet as ImageEncoder
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
from birder_clip.model_registry import Task
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseNet(nn.Module):
    """
    Base class for image-text networks.

    ``task`` identifies the registry table ("image_text") the network is registered in.

    Registered aliases are subclasses carrying a ``config`` class attribute (see
    ModelRegistry.register_model_config). Instances of an alias work on a deep copy of that config,
    so the registered config is never mutated. Any ``config`` passed to the constructor overrides
    its top-level keys.
    """

    task = str(Task.IMAGE_TEXT)

    def __init__(self, *, config: Optional[dict[str, Any]] = None) -> None:
        super().__init__()
        if hasattr(self, "config") is False:
            # Not a registered alias: take the provided config as-is (may be None)
            self.config = config
        else:
            if self.config is not None:
                self.config = copy.deepcopy(self.config)  # Avoid mutating registered configs

            if config is not None:
                if self.config is None:
                    # Alias registered without a config: the custom config becomes the whole config
                    # (instead of failing on None; an assert here would also vanish under -O)
                    self.config = dict(config)
                else:
                    self.config.update(config)  # Override with custom config (top-level keys only)

        # Declared for type checkers only; presumably assigned by concrete subclasses
        self.image_encoder: ImageEncoder
        self.embedding_size: int
        self.tokenizer_name: str

    def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Return image features (``normalize`` presumably requests L2 normalization). Implemented by subclasses."""
        raise NotImplementedError

    def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Return features for tokenized text (``normalize`` presumably requests L2 normalization). Implemented by subclasses."""
        raise NotImplementedError

    def forward_logits(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
        """Presumably compute image-text similarity logits from precomputed features. Implemented by subclasses."""
        raise NotImplementedError

    def forward(self, image: torch.Tensor, text: torch.Tensor) -> torch.Tensor:
        """Full image + text forward pass. Implemented by subclasses."""
        raise NotImplementedError

    def adjust_image_size(self, new_size: tuple[int, int]) -> None:
        """Forward a new input size to the image encoder via its ``adjust_size`` method."""
        self.image_encoder.adjust_size(new_size)
|