birder-clip 0.0.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
1
+ from birder_clip.model_registry.model_registry import list_pretrained_models
2
+ from birder_clip.version import __version__
3
+
4
+ __all__ = [
5
+ "list_pretrained_models",
6
+ "__version__",
7
+ ]
File without changes
@@ -0,0 +1,111 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Any
4
+ from typing import NamedTuple
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from birder.conf import settings
9
+ from birder.data.transforms.classification import RGBType
10
+
11
+ from birder_clip.common import lib
12
+ from birder_clip.model_registry import registry
13
+ from birder_clip.net.base import BaseNet
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class ModelInfo(NamedTuple):
    """
    Metadata returned by load_model next to the instantiated network

    All values are taken from the checkpoint dictionary.
    """

    # The checkpoint's "signature" entry
    signature: dict[str, Any]

    # The checkpoint's "rgb_stats" entry
    rgb_stats: RGBType

    # Configuration stored in the checkpoint, None when the checkpoint stores none
    custom_config: Optional[dict[str, Any]] = None
22
+
23
+
24
def model_path(network_name: str, *, epoch: Optional[int | str] = None) -> Path:
    """
    Build the location of a model file inside the configured models directory

    The file is named "<network_name>_<epoch>.pt" when an epoch is given and
    "<network_name>.pt" otherwise.
    """

    file_name = f"{network_name}.pt" if epoch is None else f"{network_name}_{epoch}.pt"
    return settings.MODELS_DIR.joinpath(file_name)
31
+
32
+
33
def load_model(
    device: torch.device,
    network: str,
    *,
    path: Optional[str | Path] = None,
    config: Optional[dict[str, Any]] = None,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
    image_encoder_config: Optional[dict[str, Any]] = None,
    text_encoder_config: Optional[dict[str, Any]] = None,
    epoch: Optional[int] = None,
    new_size: Optional[tuple[int, int]] = None,
    inference: bool,
    dtype: Optional[torch.dtype] = None,
) -> tuple[BaseNet, ModelInfo]:
    """
    Load an image-text network and its metadata from a checkpoint file

    When no explicit path is given, the file location is derived from the
    network name components (network, encoders, tokenizer, embedding
    dimension, tag) and the epoch.

    The network configuration is assembled in increasing priority order:
    the configuration stored in the checkpoint, the explicit tokenizer, the
    configuration derived from the encoder arguments (only when both encoders,
    the embedding dimension and a tokenizer name are all available) and
    finally the explicit 'config' argument.

    With inference=True all parameters are frozen and the network is switched
    to evaluation mode.

    Returns the network together with a ModelInfo built from the checkpoint's
    "signature" and "rgb_stats" entries and its stored configuration
    (None when the checkpoint stores none).
    """

    if path is None:
        network_name = lib.get_image_text_network_name(
            network,
            tag=tag,
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            tokenizer=tokenizer,
        )
        path = model_path(network_name, epoch=epoch)

    logger.info(f"Loading model from {path} on device {device}...")
    checkpoint: dict[str, Any] = torch.load(path, map_location=device, weights_only=True)

    # Work on a copy, the stored configuration itself is reported back untouched
    stored_config: dict[str, Any] = checkpoint.get("config", {})
    net_config: dict[str, Any] = dict(stored_config)

    tokenizer_name: Optional[str]
    if tokenizer is None:
        tokenizer_name = stored_config.get("tokenizer")
    else:
        tokenizer_name = tokenizer
        net_config["tokenizer"] = tokenizer

    if (
        image_encoder is not None
        and text_encoder is not None
        and embed_dim is not None
        and tokenizer_name is not None
    ):
        derived_config = lib.get_image_text_network_config(
            image_encoder,
            text_encoder,
            embed_dim,
            tokenizer_name,
            image_size=stored_config.get("image", {}).get("size"),
            image_config=image_encoder_config,
            text_config=text_encoder_config,
        )
        net_config.update(derived_config)

    if config is not None:
        net_config.update(config)

    # An empty configuration is passed on as "no configuration"
    net = registry.net_factory(network, config=net_config if len(net_config) > 0 else None)
    net.load_state_dict(checkpoint["state"])
    if new_size is not None:
        net.adjust_image_size(new_size)

    net.to(device)
    if dtype is not None:
        net.to(dtype)

    if inference is True:
        for param in net.parameters():
            param.requires_grad_(False)

        net.eval()

    custom_config: Optional[dict[str, Any]] = None
    if len(stored_config) > 0:
        custom_config = stored_config
        logger.debug(f"Model loaded with custom config: {custom_config}")

    model_info = ModelInfo(
        signature=checkpoint["signature"],
        rgb_stats=checkpoint["rgb_stats"],
        custom_config=custom_config,
    )

    return (net, model_info)
@@ -0,0 +1,65 @@
1
+ from typing import Any
2
+ from typing import Optional
3
+
4
+
5
def get_size_from_signature(signature: dict[str, Any]) -> tuple[int, int]:
    """
    Return the entries at indices 2 and 3 of the first input's "data_shape"

    The shape is read from signature["inputs"][0]["data_shape"]
    (index 1 of the same shape is what get_channels_from_signature reports).
    """

    data_shape = signature["inputs"][0]["data_shape"]
    return tuple(data_shape[2:4])
7
+
8
+
9
def get_channels_from_signature(signature: dict[str, Any]) -> int:
    """
    Return the entry at index 1 of the first input's "data_shape"
    """

    first_input = signature["inputs"][0]
    return first_input["data_shape"][1]  # type: ignore[no-any-return]
11
+
12
+
13
def get_image_text_network_name(
    network: str,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
) -> str:
    """
    Compose the canonical name of an image-text network

    The name is the underscore-joined sequence of: network, image encoder,
    text encoder, tokenizer, "d<embed_dim>" and tag.
    Components that are None are left out.
    """

    optional_parts = (image_encoder, text_encoder, tokenizer)
    name_parts = [network, *(part for part in optional_parts if part is not None)]
    if embed_dim is not None:
        name_parts.append(f"d{embed_dim}")

    network_name = "_".join(name_parts)

    return network_name if tag is None else f"{network_name}_{tag}"
36
+
37
+
38
def get_image_text_network_config(
    image_encoder: str,
    text_encoder: str,
    embed_dim: int,
    tokenizer: Optional[str],
    *,
    image_size: Optional[tuple[int, int]] = None,
    image_config: Optional[dict[str, Any]] = None,
    text_config: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """
    Build the configuration dictionary of an image-text network

    The result always holds the keys "image", "text", "embed_dim" and
    "tokenizer". The "image" and "text" sections always name their encoder
    under "network"; "size" (image only) and "config" are added only when the
    corresponding argument is not None.
    """

    image_section: dict[str, Any] = {"network": image_encoder}
    optional_image_entries = (("size", image_size), ("config", image_config))
    image_section.update({key: value for key, value in optional_image_entries if value is not None})

    text_section: dict[str, Any] = {"network": text_encoder}
    if text_config is not None:
        text_section["config"] = text_config

    return {
        "image": image_section,
        "text": text_section,
        "embed_dim": embed_dim,
        "tokenizer": tokenizer,
    }
File without changes
@@ -0,0 +1,27 @@
1
+ import logging.config
2
+ import os
3
+ from typing import Any
4
+
5
+ # Logging
6
+ # https://docs.python.org/3/library/logging.config.html
7
# Level of the package logger, taken from the LOG_LEVEL environment variable
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()

LOGGING: dict[str, Any] = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "verbose": {
            "format": "[{asctime}.{msecs:04.0f} {levelname} {filename}:{lineno:<4d}] {message}",
            "style": "{",
            "datefmt": "%d/%b/%Y %H:%M:%S",
        },
        "simple": {
            "format": "[{asctime} {levelname}] {message}",
            "style": "{",
        },
    },
    "handlers": {
        # The handler lets everything through, filtering happens at the logger level
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "verbose",
        },
    },
    "loggers": {
        "birder_clip": {
            "handlers": ["console"],
            "level": LOG_LEVEL,
            "propagate": False,
        },
    },
}

# Applied at import time
logging.config.dictConfig(LOGGING)
File without changes
@@ -0,0 +1,50 @@
1
+ """
2
+ Zero-shot text embedding helpers
3
+
4
+ Zero-shot classification compares image features against one text feature per
5
+ candidate class. When multiple prompt templates are used, this module follows
6
+ the OpenCLIP/OpenAI CLIP convention: encode every class/template prompt,
7
+ normalize prompt embeddings, average them per class and normalize the averaged
8
+ class embedding again.
9
+ """
10
+
11
+ from collections.abc import Sequence
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from birder_clip.net.base import BaseNet
17
+ from birder_clip.tokenizers.base import Tokenizer
18
+
19
+
20
def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
    """
    Fill every template with every class name

    The result is ordered class by class: all templates of the first class
    come first, then all templates of the second class and so on.
    """

    prompts: list[str] = []
    for class_name in class_names:
        prompts.extend(template.format(class_name) for template in templates)

    return prompts
22
+
23
+
24
def build_class_text_embeddings(
    model: BaseNet,
    tokenizer: Tokenizer,
    class_names: Sequence[str],
    templates: Sequence[str],
    *,
    device: torch.device,
    context_length: int | None = None,
    batch_size: int | None = None,
) -> torch.Tensor:
    """
    Build one normalized text embedding per class

    Every class name is rendered with every template, the prompts are encoded
    with normalize=True, the prompt embeddings of each class are averaged and
    the average is normalized again.

    Classes are processed in chunks of 'batch_size' classes (all classes at
    once when None), so a single encoder call receives
    batch_size * len(templates) prompts.

    Returns a tensor with one row per class, in the order of 'class_names'.

    Raises ValueError when 'class_names' or 'templates' is empty or when
    'batch_size' is not positive.
    """

    num_classes = len(class_names)
    num_templates = len(templates)

    # Fail early with a clear message, otherwise these cases surface as unrelated
    # errors from range(), reshape() or concat() further down
    if num_classes == 0:
        raise ValueError("class_names must contain at least one class")
    if num_templates == 0:
        raise ValueError("templates must contain at least one template")
    if batch_size is None:
        batch_size = num_classes
    elif batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    class_text_embeddings = []
    with torch.inference_mode():
        for start in range(0, num_classes, batch_size):
            batch_class_names = class_names[start : start + batch_size]
            prompts = render_prompts(batch_class_names, templates)
            tokens = tokenizer(prompts, context_length=context_length).to(device)
            prompt_embeddings = model.encode_text(tokens, normalize=True)

            # Prompts are ordered class by class, so each class owns 'num_templates' consecutive rows
            class_embeddings = prompt_embeddings.reshape(len(batch_class_names), num_templates, -1).mean(dim=1)
            class_text_embeddings.append(F.normalize(class_embeddings, dim=-1))

    return torch.concat(class_text_embeddings, dim=0)
@@ -0,0 +1,229 @@
1
+ """
2
+ Zero-shot classification prompt templates
3
+
4
+ OpenAI simple ImageNet templates adapted from OpenCLIP:
5
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/zero_shot_metadata.py
6
+
7
+ OpenAI templates adapted from:
8
+ https://github.com/openai/CLIP/blob/main/data/prompts.md
9
+ """
10
+
11
+ # Reference license: MIT
12
+
13
def _with_articles(patterns: tuple[str, ...]) -> tuple[str, ...]:
    """
    Expand "%s" in every pattern with "a", then in every pattern with "the"
    """

    return tuple(pattern % article for article in ("a", "the") for pattern in patterns)


DEFAULT_TEMPLATES = ("a photo of a {}.",)

IDENTITY_TEMPLATES = ("{}",)

SIMPLE_IMAGENET_TEMPLATES = (
    "itap of a {}.",
    "a bad photo of the {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
)

OPENAI_IMAGENET_TEMPLATES = (
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
)

# Every pattern appears once with "a" (first half) and once with "the" (second half)
OPENAI_CIFAR10_TEMPLATES = _with_articles(
    (
        "a photo of %s {}.",
        "a blurry photo of %s {}.",
        "a black and white photo of %s {}.",
        "a low contrast photo of %s {}.",
        "a high contrast photo of %s {}.",
        "a bad photo of %s {}.",
        "a good photo of %s {}.",
        "a photo of %s small {}.",
        "a photo of %s big {}.",
    )
)

# Every pattern appears once with "a" (first half) and once with "the" (second half)
OPENAI_CALTECH101_TEMPLATES = _with_articles(
    (
        "a photo of %s {}.",
        "a painting of %s {}.",
        "%s plastic {}.",
        "a sculpture of %s {}.",
        "a sketch of %s {}.",
        "a tattoo of %s {}.",
        "%s toy {}.",
        "a rendition of %s {}.",
        "%s embroidered {}.",
        "%s cartoon {}.",
        "%s {} in a video game.",
        "%s plushie {}.",
        "%s origami {}.",
        "art of %s {}.",
        "graffiti of %s {}.",
        "a drawing of %s {}.",
        "a doodle of %s {}.",
    )
)

OPENAI_COUNTRY211_TEMPLATES = (
    "a photo i took in {}.",
    "a photo i took while visiting {}.",
    "a photo from my home country of {}.",
    "a photo from my visit to {}.",
    "a photo showing the country of {}.",
)

OPENAI_EUROSAT_TEMPLATES = (
    "a centered satellite photo of {}.",
    "a centered satellite photo of a {}.",
    "a centered satellite photo of the {}.",
)

OPENAI_FGVC_AIRCRAFT_TEMPLATES = (
    "a photo of a {}, a type of aircraft.",
    "a photo of the {}, a type of aircraft.",
)

# The Kinetics set is the cross product of four openings and seven endings, grouped by opening
_KINETICS_OPENINGS = ("a photo of", "a video of", "a example of", "a demonstration of")
_KINETICS_ENDINGS = (
    "{}.",
    "a person {}.",
    "a person using {}.",
    "a person doing {}.",
    "a person during {}.",
    "a person performing {}.",
    "a person practicing {}.",
)
OPENAI_KINETICS_TEMPLATES = tuple(
    f"{opening} {ending}" for opening in _KINETICS_OPENINGS for ending in _KINETICS_ENDINGS
)

# Lookup of every template set by name
TEMPLATE_SETS = {
    "default": DEFAULT_TEMPLATES,
    "identity": IDENTITY_TEMPLATES,
    "simple_imagenet": SIMPLE_IMAGENET_TEMPLATES,
    "openai_imagenet": OPENAI_IMAGENET_TEMPLATES,
    "openai_cifar10": OPENAI_CIFAR10_TEMPLATES,
    "openai_caltech101": OPENAI_CALTECH101_TEMPLATES,
    "openai_country211": OPENAI_COUNTRY211_TEMPLATES,
    "openai_eurosat": OPENAI_EUROSAT_TEMPLATES,
    "openai_fgvc_aircraft": OPENAI_FGVC_AIRCRAFT_TEMPLATES,
    "openai_kinetics": OPENAI_KINETICS_TEMPLATES,
}
@@ -0,0 +1,9 @@
1
+ from birder_clip.model_registry.model_registry import Task
2
+ from birder_clip.model_registry.model_registry import list_pretrained_models
3
+ from birder_clip.model_registry.model_registry import registry
4
+
5
+ __all__ = [
6
+ "Task",
7
+ "list_pretrained_models",
8
+ "registry",
9
+ ]
@@ -0,0 +1,133 @@
1
+ import fnmatch
2
+ import warnings
3
+ from enum import StrEnum
4
+ from typing import TYPE_CHECKING
5
+ from typing import Any
6
+ from typing import Literal
7
+ from typing import Optional
8
+ from typing import TypeAlias
9
+
10
+ from birder.model_registry.model_registry import group_sort
11
+
12
+ if TYPE_CHECKING is True:
13
+ from birder_clip.net.base import BaseNet # pylint: disable=cyclic-import
14
+ from birder_clip.net.text.base import TextBaseNet # pylint: disable=cyclic-import
15
+
16
+ NetType: TypeAlias = type[BaseNet] | type[TextBaseNet]
17
+
18
+
19
+ class Task(StrEnum):
20
+ IMAGE_TEXT = "image_text"
21
+ TEXT = "text"
22
+
23
+
24
class ModelRegistry:
    """
    Registry of network classes and pretrained weights metadata

    Network classes are stored per task (image-text and text) under their
    lower-cased name, a name can exist in one task table only.
    Pretrained weights metadata is stored under the name exactly as given.
    """

    def __init__(self) -> None:
        self.registered_configs: dict[str, "NetType"] = {}
        self._image_text_nets: dict[str, type["BaseNet"]] = {}
        self._text_nets: dict[str, type[TextBaseNet]] = {}
        self._pretrained_nets: dict[str, dict[str, Any]] = {}

    @property
    def all_nets(self) -> dict[str, "NetType"]:
        """
        A new dictionary holding the networks of all tasks
        """

        return self._image_text_nets | self._text_nets

    def _get_models_for_task(self, task: Task) -> dict[str, "NetType"]:
        """
        Return the (live) table of the given task, raise ValueError for an unknown task
        """

        if task == Task.IMAGE_TEXT:
            return self._image_text_nets
        if task == Task.TEXT:
            return self._text_nets

        raise ValueError(f"Unsupported model task: {task}")

    def _register_model(self, name: str, net_type: "NetType") -> None:
        """
        Store a network class in the table selected by its 'task' attribute

        Re-registering a name within the same task only warns, a name already
        used by another task raises ValueError.
        """

        key = name.lower()
        task_nets = self._get_models_for_task(Task(net_type.task))
        if key in task_nets:
            warnings.warn(f"Network '{name}' is already registered and will be overwritten", UserWarning)
        elif key in self.all_nets:
            raise ValueError(f"Registered model name '{name}' collides with an existing registered model name")

        task_nets[key] = net_type

    def register_model_config(self, name: str, net_type: "NetType", *, config: Optional[dict[str, Any]] = None) -> None:
        """
        Register a named configuration of a network

        A subclass of 'net_type' carrying 'config' as class attribute is created
        and registered both as a network and as a registered config.
        """

        key = name.lower()
        configured_type = type(name, (net_type,), {"config": config})
        self._register_model(key, configured_type)
        if key in self.registered_configs:
            warnings.warn(f"Registered config '{name}' is already registered and will be overwritten", UserWarning)

        self.registered_configs[key] = configured_type

    def register_weights(self, name: str, weights_info: dict[str, Any]) -> None:
        """
        Register pretrained weights metadata, warn when the name is overwritten
        """

        if name in self._pretrained_nets:
            warnings.warn(f"Weights '{name}' are already registered and will be overwritten", UserWarning)

        self._pretrained_nets[name] = weights_info

    def list_models(
        self,
        include_filter: Optional[str] = None,
        *,
        task: Optional[Task] = None,
        net_type: Optional[type | tuple[type, ...]] = None,
        net_type_op: Literal["AND", "OR"] = "AND",
    ) -> list[str]:
        """
        List registered network names

        The list can be narrowed by task, by base class(es) and by an fnmatch
        style pattern. With several classes in 'net_type', "AND" keeps networks
        deriving from all of them while "OR" keeps networks deriving from any.

        Raises ValueError for an unknown 'net_type_op'.
        """

        candidates = self.all_nets if task is None else self._get_models_for_task(task)

        if net_type is not None:
            wanted_types = net_type if isinstance(net_type, tuple) else (net_type,)
            if net_type_op == "OR":
                candidates = {key: t for key, t in candidates.items() if issubclass(t, wanted_types)}
            elif net_type_op == "AND":
                candidates = {
                    key: t for key, t in candidates.items() if all(issubclass(t, wanted) for wanted in wanted_types)
                }
            else:
                raise ValueError(f"Unknown op {net_type_op}")

        names = list(candidates)
        if include_filter is not None:
            names = fnmatch.filter(names, include_filter)

        return group_sort(names)

    def list_pretrained_models(self, include_filter: Optional[str] = None) -> list[str]:
        """
        List registered pretrained weights names, optionally narrowed by an fnmatch style pattern
        """

        names = list(self._pretrained_nets)
        if include_filter is not None:
            names = fnmatch.filter(names, include_filter)

        return group_sort(names)

    def exists(self, name: str, task: Optional[Task] = None, net_type: Optional[type] = None) -> bool:
        """
        Tell whether a network is registered (case-insensitive), optionally restricted by task and base class
        """

        candidates = self.all_nets if task is None else self._get_models_for_task(task)
        if net_type is not None:
            candidates = {key: t for key, t in candidates.items() if issubclass(t, net_type)}

        return name.lower() in candidates

    def pretrained_exists(self, name: str) -> bool:
        """
        Tell whether pretrained weights are registered under exactly this name
        """

        return name in self._pretrained_nets

    def get_pretrained_metadata(self, name: str) -> dict[str, Any]:
        """
        Return the metadata registered for the given weights, KeyError when unknown
        """

        return self._pretrained_nets[name]

    def text_factory(self, name: str, *, config: Optional[dict[str, Any]] = None) -> "TextBaseNet":
        """
        Instantiate a registered text network, KeyError when the name is unknown
        """

        text_net_type = self._text_nets[name.lower()]
        return text_net_type(config=config)

    def net_factory(self, name: str, *, config: Optional[dict[str, Any]] = None) -> "BaseNet":
        """
        Instantiate a registered image-text network, KeyError when the name is unknown
        """

        image_text_net_type = self._image_text_nets[name.lower()]
        return image_text_net_type(config=config)


# Module-level registry instance and a shortcut to its pretrained models listing
registry = ModelRegistry()
list_pretrained_models = registry.list_pretrained_models
@@ -0,0 +1,5 @@
1
+ from birder_clip.net.clip import CLIP
2
+
3
+ __all__ = [
4
+ "CLIP",
5
+ ]
@@ -0,0 +1,44 @@
1
+ import copy
2
+ from typing import Any
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from birder.net.base import BaseNet as ImageEncoder
7
+ from torch import nn
8
+
9
+ from birder_clip.model_registry import Task
10
+
11
+
12
class BaseNet(nn.Module):
    """
    Base class of image-text networks

    Subclasses are expected to assign 'image_encoder', 'embedding_size' and
    'tokenizer_name' (only declared here) and to implement the encode / forward
    methods, which raise NotImplementedError in this class.
    """

    task = str(Task.IMAGE_TEXT)

    def __init__(self, *, config: Optional[dict[str, Any]] = None) -> None:
        """
        Resolve the effective configuration of the network

        Without a class level 'config' attribute the given config is stored as is.
        Otherwise the class level config is deep-copied (when not None) and the
        given config, if any, is applied on top of it.
        """

        super().__init__()
        if hasattr(self, "config") is False:
            self.config = config
        else:
            if self.config is not None:
                self.config = copy.deepcopy(self.config)  # Avoid mutating registered configs

            if config is not None:
                if self.config is None:
                    # A config can be registered as None, the custom config is then the whole
                    # configuration (this used to trip an assert, which is also stripped under -O)
                    self.config = dict(config)
                else:
                    self.config.update(config)  # Override with custom config

        # Declarations only, no value is assigned here
        self.image_encoder: ImageEncoder
        self.embedding_size: int
        self.tokenizer_name: str

    def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """
        Encode images, to be implemented by subclasses
        """

        raise NotImplementedError

    def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """
        Encode tokenized text, to be implemented by subclasses
        """

        raise NotImplementedError

    def forward_logits(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
        """
        Compute logits from image and text features, to be implemented by subclasses
        """

        raise NotImplementedError

    def forward(self, image: torch.Tensor, text: torch.Tensor) -> torch.Tensor:
        """
        Forward pass over an image / text pair, to be implemented by subclasses
        """

        raise NotImplementedError

    def adjust_image_size(self, new_size: tuple[int, int]) -> None:
        """
        Delegate the image size change to the image encoder
        """

        self.image_encoder.adjust_size(new_size)