birder-clip 0.0.2.dev2__tar.gz → 0.0.2.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/PKG-INFO +2 -2
  2. birder_clip-0.0.2.dev4/birder_clip/__init__.py +20 -0
  3. birder_clip-0.0.2.dev4/birder_clip/common/fs_ops.py +610 -0
  4. birder_clip-0.0.2.dev4/birder_clip/common/lib.py +162 -0
  5. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/common/training_cli.py +150 -37
  6. birder_clip-0.0.2.dev4/birder_clip/common/training_utils.py +61 -0
  7. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/conf/settings.py +3 -0
  8. birder_clip-0.0.2.dev4/birder_clip/data/datasets/fake.py +23 -0
  9. birder_clip-0.0.2.dev4/birder_clip/data/datasets/webdataset.py +106 -0
  10. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/inference/zero_shot.py +2 -2
  11. birder_clip-0.0.2.dev4/birder_clip/loss/__init__.py +5 -0
  12. birder_clip-0.0.2.dev4/birder_clip/loss/contrastive.py +57 -0
  13. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/__init__.py +2 -2
  14. birder_clip-0.0.2.dev4/birder_clip/model_registry/manifest.py +52 -0
  15. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/model_registry.py +92 -8
  16. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/net/base.py +22 -1
  17. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/net/clip.py +62 -18
  18. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/net/text/__init__.py +0 -2
  19. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/net/text/base.py +17 -1
  20. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/net/text/transformer.py +19 -4
  21. birder_clip-0.0.2.dev4/birder_clip/scripts/train.py +940 -0
  22. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/scripts/zero_shot.py +13 -2
  23. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/__init__.py +6 -0
  24. birder_clip-0.0.2.dev4/birder_clip/tokenizers/base.py +54 -0
  25. birder_clip-0.0.2.dev4/birder_clip/tokenizers/hf.py +86 -0
  26. birder_clip-0.0.2.dev4/birder_clip/tokenizers/registry.py +56 -0
  27. birder_clip-0.0.2.dev2/birder_clip/tokenizers/openai_clip_bpe.py → birder_clip-0.0.2.dev4/birder_clip/tokenizers/simple_tokenizer.py +8 -44
  28. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/tools/download_tokenizer.py +4 -21
  29. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/tools/show_iterator.py +77 -11
  30. birder_clip-0.0.2.dev4/birder_clip/version.py +1 -0
  31. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/PKG-INFO +2 -2
  32. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/SOURCES.txt +10 -1
  33. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/requires.txt +1 -1
  34. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/requirements/requirements.txt +1 -1
  35. birder_clip-0.0.2.dev4/tests/test_common.py +90 -0
  36. birder_clip-0.0.2.dev4/tests/test_datasets.py +63 -0
  37. birder_clip-0.0.2.dev4/tests/test_loss.py +47 -0
  38. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/tests/test_model_registry.py +0 -1
  39. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/tests/test_net.py +33 -3
  40. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/tests/test_tokenizers.py +8 -8
  41. birder_clip-0.0.2.dev2/birder_clip/__init__.py +0 -7
  42. birder_clip-0.0.2.dev2/birder_clip/common/fs_ops.py +0 -111
  43. birder_clip-0.0.2.dev2/birder_clip/common/lib.py +0 -65
  44. birder_clip-0.0.2.dev2/birder_clip/tokenizers/base.py +0 -8
  45. birder_clip-0.0.2.dev2/birder_clip/tokenizers/hf.py +0 -43
  46. birder_clip-0.0.2.dev2/birder_clip/tokenizers/registry.py +0 -53
  47. birder_clip-0.0.2.dev2/birder_clip/version.py +0 -1
  48. birder_clip-0.0.2.dev2/tests/test_common.py +0 -27
  49. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/LICENSE +0 -0
  50. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/README.md +0 -0
  51. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/common/__init__.py +0 -0
  52. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/conf/__init__.py +0 -0
  53. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/data/__init__.py +0 -0
  54. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/__init__.py +0 -0
  55. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/csv.py +0 -0
  56. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/inference/__init__.py +0 -0
  57. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/inference/zero_shot_templates.py +0 -0
  58. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/net/__init__.py +0 -0
  59. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/py.typed +0 -0
  60. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/scripts/__init__.py +0 -0
  61. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  62. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/tools/__init__.py +0 -0
  63. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip/tools/__main__.py +0 -0
  64. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/dependency_links.txt +0 -0
  65. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/top_level.txt +0 -0
  66. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/pyproject.toml +0 -0
  67. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/requirements/_requirements-dev.txt +0 -0
  68. {birder_clip-0.0.2.dev2 → birder_clip-0.0.2.dev4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder_clip
3
- Version: 0.0.2.dev2
3
+ Version: 0.0.2.dev4
4
4
  Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
24
24
  Requires-Python: >=3.11
25
25
  Description-Content-Type: text/markdown
26
26
  License-File: LICENSE
27
- Requires-Dist: birder>=0.5.2
27
+ Requires-Dist: birder>=0.5.4
28
28
  Requires-Dist: ftfy>=6.3.1
29
29
  Requires-Dist: regex>=2025.7.29
30
30
  Requires-Dist: tqdm>=4.67.0
@@ -0,0 +1,20 @@
1
+ from birder.data.transforms.classification import inference_preset as inference_transform
2
+
3
+ from birder_clip.common.fs_ops import load_pretrained_model
4
+ from birder_clip.common.fs_ops import load_pretrained_model_and_transform
5
+ from birder_clip.common.fs_ops import load_pretrained_tokenizer
6
+ from birder_clip.common.lib import get_channels_from_signature
7
+ from birder_clip.common.lib import get_size_from_signature
8
+ from birder_clip.model_registry.model_registry import list_pretrained_models
9
+ from birder_clip.version import __version__
10
+
11
+ __all__ = [
12
+ "inference_transform",
13
+ "get_channels_from_signature",
14
+ "get_size_from_signature",
15
+ "list_pretrained_models",
16
+ "load_pretrained_model",
17
+ "load_pretrained_model_and_transform",
18
+ "load_pretrained_tokenizer",
19
+ "__version__",
20
+ ]
@@ -0,0 +1,610 @@
1
+ import json
2
+ import logging
3
+ import re
4
+ from collections.abc import Callable
5
+ from pathlib import Path
6
+ from typing import Any
7
+ from typing import NamedTuple
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from birder.common import cli
12
+ from birder.conf import settings
13
+ from birder.data.transforms.classification import RGBType
14
+ from birder.data.transforms.classification import inference_preset
15
+
16
+ from birder_clip.common import lib
17
+ from birder_clip.model_registry import Task
18
+ from birder_clip.model_registry import registry
19
+ from birder_clip.model_registry.manifest import EncoderMetadataType
20
+ from birder_clip.model_registry.manifest import FileFormatType
21
+ from birder_clip.net.base import BaseNet
22
+ from birder_clip.net.base import SignatureType
23
+ from birder_clip.tokenizers import Tokenizer
24
+ from birder_clip.tokenizers import get_tokenizer
25
+ from birder_clip.tokenizers.hf import download_hf_tokenizer
26
+ from birder_clip.tokenizers.hf import get_hf_tokenizer_source
27
+ from birder_clip.version import __version__
28
+
29
+ try:
30
+ import safetensors
31
+ import safetensors.torch
32
+
33
+ _HAS_SAFETENSORS = True
34
+ except ImportError:
35
+ _HAS_SAFETENSORS = False
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class ModelInfo(NamedTuple):
    """
    Metadata returned next to a loaded network
    """

    # Model signature as read from the model file
    signature: SignatureType
    # RGB statistics as read from the model file
    rgb_stats: RGBType
    # Configuration stored in the model file, None when nothing was stored
    custom_config: Optional[dict[str, Any]] = None
44
+
45
+
46
def write_config(network_name: str, net: BaseNet, signature: SignatureType, rgb_stats: RGBType) -> None:
    """
    Write the network configuration as a JSON file in the models directory

    Parameters
    ----------
    network_name
        Base name of the output file, the '.json' suffix is appended.
    net
        Network forwarded to lib.get_image_text_network_config.
    signature
        Model signature forwarded to lib.get_image_text_network_config.
    rgb_stats
        RGB statistics forwarded to lib.get_image_text_network_config.
    """

    model_config = lib.get_image_text_network_config(net, signature, rgb_stats)
    config_file = settings.MODELS_DIR.joinpath(f"{network_name}.json")

    logger.info(f"Writing {config_file}")
    with config_file.open("w", encoding="utf-8") as stream:
        json.dump(model_config, stream, indent=2)
52
+
53
+
54
def _split_encoder_metadata(encoder: Optional[EncoderMetadataType]) -> tuple[Optional[str], Optional[dict[str, Any]]]:
    """
    Split encoder metadata into an encoder name and an encoder configuration

    Parameters
    ----------
    encoder
        Either None, an encoder name, or a mapping describing the encoder.

    Returns
    -------
    A (name, config) tuple. A string input is returned as the name with no config,
    a mapping input is returned as the config with no name, and None yields (None, None).

    Raises
    ------
    ValueError
        If a mapping is given without a 'network' key.
    """

    # A missing entry or a plain name carries no configuration
    if encoder is None or isinstance(encoder, str):
        return (encoder, None)

    if "network" in encoder:
        return (None, encoder)  # type: ignore[return-value]

    raise ValueError("Encoder metadata must include a 'network' field")
64
+
65
+
66
def model_path(
    network_name: str,
    *,
    epoch: Optional[int | str] = None,
    file_format: FileFormatType = "pt",
    states: bool = False,
) -> Path:
    """
    Build the path of a model file or a training states file inside the models directory

    Parameters
    ----------
    network_name
        Base name of the file.
    epoch
        Optional epoch (or glob expression) appended to the name after an underscore.
    file_format
        File extension of the model file, ignored when states is True.
    states
        When True, return the path of the matching '_states.pt' file instead.

    Returns
    -------
    The path under settings.MODELS_DIR.
    """

    stem = network_name if epoch is None else f"{network_name}_{epoch}"
    # Training states are always stored as 'pt', regardless of the requested model format
    suffix = "_states.pt" if states is True else f".{file_format}"

    return settings.MODELS_DIR.joinpath(f"{stem}{suffix}")
84
+
85
+
86
+ def _checkpoint_states(
87
+ states_path: Path,
88
+ optimizer: Optional[torch.optim.Optimizer],
89
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
90
+ scaler: Optional[torch.amp.grad_scaler.GradScaler],
91
+ model_base: Optional[torch.nn.Module],
92
+ **extra_states: Optional[dict[str, Any]],
93
+ ) -> None:
94
+ if optimizer is None and scheduler is None and scaler is None and model_base is None and len(extra_states) == 0:
95
+ return
96
+
97
+ kwargs = {}
98
+ if optimizer is not None:
99
+ kwargs["optimizer_state"] = optimizer.state_dict()
100
+ if scheduler is not None:
101
+ kwargs["scheduler_state"] = scheduler.state_dict()
102
+ if scaler is not None:
103
+ kwargs["scaler_state"] = scaler.state_dict()
104
+ if model_base is not None:
105
+ kwargs["model_base_state"] = model_base.state_dict()
106
+ kwargs.update({k: v for k, v in extra_states.items() if v is not None})
107
+
108
+ logger.info(f"Saving training states {states_path}...")
109
+ torch.save(kwargs, states_path)
110
+
111
+
112
class TrainingStates(NamedTuple):
    """
    Container for the training states stored next to a model checkpoint
    """

    # State dicts of the training components, None when a component was not saved
    optimizer_state: Optional[dict[str, Any]]
    scheduler_state: Optional[dict[str, Any]]
    scaler_state: Optional[dict[str, Any]]
    model_base_state: Optional[dict[str, Any]]
    ema_model_state: Optional[dict[str, Any]] = None
    # Any remaining entries of the states file
    extra_states: Optional[dict[str, Any]] = None

    @classmethod
    def empty(cls) -> "TrainingStates":
        """
        Return an instance with every field set to None
        """

        return cls(
            optimizer_state=None,
            scheduler_state=None,
            scaler_state=None,
            model_base_state=None,
        )
123
+
124
+
125
def _load_states(states_path: Path, device: torch.device) -> TrainingStates:
    """
    Load training states from disk, or return empty states when the file is missing

    Parameters
    ----------
    states_path
        Path of the training states file.
    device
        Device the states are mapped to while loading.

    Returns
    -------
    The loaded training states. Entries other than the optimizer, scheduler, scaler and
    model base states are collected into extra_states.
    """

    if states_path.exists() is False:
        logger.debug("Checkpoint training states not found, returning empty states")
        return TrainingStates.empty()

    logger.info(f"Loading states from {states_path} on device {device}...")
    states_dict: dict[str, Any] = torch.load(states_path, map_location=device, weights_only=True)

    # Pop the well-known entries, whatever is left belongs to the extra states
    known_keys = ("optimizer_state", "scheduler_state", "scaler_state", "model_base_state")
    known_states = {key: states_dict.pop(key, None) for key in known_keys}

    return TrainingStates(**known_states, extra_states=dict(states_dict))
147
+
148
+
149
def checkpoint_model(
    network_name: str,
    epoch: int,
    net: torch.nn.Module,
    signature: SignatureType,
    rgb_stats: RGBType,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    scaler: Optional[torch.amp.grad_scaler.GradScaler],
    model_base: Optional[torch.nn.Module],
    *,
    external_config: Optional[dict[str, Any]] = None,
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Save a model checkpoint and its training states for the given epoch

    Parameters
    ----------
    network_name
        Base name of the checkpoint files.
    epoch
        Epoch number appended to the file names.
    net
        Network to save, its state dict and task are stored in the checkpoint.
    signature
        Model signature stored in the checkpoint.
    rgb_stats
        RGB statistics stored in the checkpoint.
    optimizer
        Optional optimizer saved with the training states.
    scheduler
        Optional scheduler saved with the training states.
    scaler
        Optional gradient scaler saved with the training states.
    model_base
        Optional module saved with the training states.
    external_config
        Optional configuration stored in the checkpoint under 'config'.
    extra_states
        Additional states saved with the training states.
    """

    path = model_path(network_name, epoch=epoch)
    states_path = model_path(network_name, epoch=epoch, states=True)

    logger.info(f"Saving model checkpoint {path}...")
    checkpoint: dict[str, Any] = {
        "state": net.state_dict(),
        "birder_clip_version": __version__,
        "task": net.task,
        "signature": signature,
        "rgb_stats": rgb_stats,
    }
    # The 'config' key is only present when a configuration was supplied
    if external_config is not None:
        checkpoint["config"] = external_config

    torch.save(checkpoint, path)

    _checkpoint_states(states_path, optimizer, scheduler, scaler, model_base, **extra_states)
183
+
184
+
185
def clean_checkpoints(network_name: str, keep_last: int) -> None:
    """
    Remove old epoch checkpoints and their training states

    Files are ordered by modification time and all but the newest keep_last
    model files (and, separately, states files) are removed.

    Parameters
    ----------
    network_name
        Base name of the checkpoint files.
    keep_last
        Number of most recently modified files to keep of each kind.
    """

    # NOTE(review): keep_last=0 yields the empty slice [:-0], so nothing is removed in that case
    numbered = "*[0-9]"
    model_pattern = re.compile(r".*_([1-9][0-9]*)\.pt$")
    states_pattern = re.compile(r".*_([1-9][0-9]*)_states\.pt$")
    models_glob = str(model_path(network_name, epoch=numbered))
    states_glob = str(model_path(network_name, epoch=numbered, states=True))

    def modified_at(candidate: Path) -> float:
        return candidate.stat().st_mtime

    oldest_first = sorted(settings.BASE_DIR.glob(models_glob), key=modified_at)
    for p in oldest_first[:-keep_last]:
        # The glob is loose, only remove files that end with a proper epoch number
        if model_pattern.search(str(p)) is None:
            continue

        logger.info(f"Removing checkpoint {p}...")
        p.unlink()

    oldest_first = sorted(settings.BASE_DIR.glob(states_glob), key=modified_at)
    for p in oldest_first[:-keep_last]:
        if states_pattern.search(str(p)) is None:
            continue

        logger.info(f"Removing checkpoint states {p}...")
        p.unlink()
203
+
204
+
205
class CheckpointStates(NamedTuple):
    """
    Result of loading a training checkpoint
    """

    # Network built from the checkpoint
    net: BaseNet
    # RGB statistics read from the checkpoint
    rgb_stats: RGBType
    # Training states loaded from the matching states file
    training_states: TrainingStates
209
+
210
+
211
def load_checkpoint(
    device: torch.device,
    network: str,
    *,
    config: Optional[dict[str, Any]] = None,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
    image_encoder_config: Optional[dict[str, Any]] = None,
    text_encoder_config: Optional[dict[str, Any]] = None,
    epoch: Optional[int] = None,
    new_size: Optional[tuple[int, int]] = None,
    new_context_length: Optional[int] = None,
    strict: bool = True,
) -> CheckpointStates:
    """
    Load a training checkpoint together with its training states

    Parameters
    ----------
    device
        Device to load the checkpoint on.
    network
        Registered network name.
    config
        Configuration entries that override the configuration stored in the checkpoint.
    tag
        Optional tag used to build the checkpoint file name.
    image_encoder
        Optional image encoder name.
    text_encoder
        Optional text encoder name.
    embed_dim
        Optional embedding dimension.
    tokenizer
        Optional tokenizer name.
    image_encoder_config
        Optional image encoder configuration.
    text_encoder_config
        Optional text encoder configuration.
    epoch
        Epoch of the checkpoint to load.
    new_size
        When given, the network image size is adjusted after loading the weights.
    new_context_length
        When given, the network context length is adjusted after loading the weights.
    strict
        Forwarded to load_state_dict.

    Returns
    -------
    The network, the RGB stats and the training states. When the states file holds a
    model base state, it is loaded into the network and the checkpoint weights are
    returned as the EMA model state.
    """

    network_name = lib.get_image_text_network_name(
        network,
        tag=tag,
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=embed_dim,
        tokenizer=tokenizer,
    )
    path = model_path(network_name, epoch=epoch)
    states_path = model_path(network_name, epoch=epoch, states=True)

    logger.info(f"Loading model from {path} on device {device}...")
    model_dict: dict[str, Any] = torch.load(path, map_location=device, weights_only=True)
    training_states = _load_states(states_path, device)

    signature: SignatureType = model_dict["signature"]
    rgb_stats: RGBType = model_dict["rgb_stats"]
    input_channels = lib.get_channels_from_signature(signature)
    size = lib.get_size_from_signature(signature)
    context_length = lib.get_context_length_from_signature(signature)
    logger.debug(f"Loaded model with RGB stats: {rgb_stats}")
    logger.debug(f"Loaded model input size is {size}")

    registered_config = registry.all_nets[network.lower()].config  # type: ignore[misc]

    # Start from the configuration stored in the checkpoint, caller entries take precedence
    loaded_config = model_dict.get("config", {})
    if loaded_config is None:
        checkpoint_config = {}
    else:
        checkpoint_config = loaded_config.copy()

    if config is not None:
        checkpoint_config.update(config)

    model_config = lib.get_image_text_model_config(
        registered_config,
        checkpoint_config,
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=embed_dim,
        tokenizer=tokenizer,
        image_encoder_config=image_encoder_config,
        text_encoder_config=text_encoder_config,
        input_channels=input_channels,
        image_size=size,
        context_length=context_length,
    )
    net = registry.net_factory(network, config=model_config)

    base_state = training_states.model_base_state
    if base_state is None:
        net.load_state_dict(model_dict["state"], strict=strict)
    else:
        # The states file holds the base weights, keep the checkpoint weights as the EMA state
        net.load_state_dict(base_state, strict=strict)
        training_states = training_states._replace(ema_model_state=model_dict["state"])

    if new_size is not None:
        net.adjust_image_size(new_size)
    if new_context_length is not None:
        net.adjust_context_length(new_context_length)

    net.to(device)

    return CheckpointStates(net, rgb_stats, training_states)
286
+
287
+
288
def load_model(
    device: torch.device,
    network: str,
    *,
    path: Optional[str | Path] = None,
    config: Optional[dict[str, Any]] = None,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
    image_encoder_config: Optional[dict[str, Any]] = None,
    text_encoder_config: Optional[dict[str, Any]] = None,
    epoch: Optional[int] = None,
    new_size: Optional[tuple[int, int]] = None,
    new_context_length: Optional[int] = None,
    inference: bool,
    st: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> tuple[BaseNet, ModelInfo]:
    """
    Load a model from a 'pt' or a 'safetensors' file

    Parameters
    ----------
    device
        Device to load the model on.
    network
        Registered network name.
    path
        Explicit model file path, when None the path is derived from the network name,
        its naming options and the epoch.
    config
        Configuration entries that override the configuration stored in the model file.
    tag
        Optional tag used to build the model file name.
    image_encoder
        Optional image encoder name.
    text_encoder
        Optional text encoder name.
    embed_dim
        Optional embedding dimension.
    tokenizer
        Optional tokenizer name.
    image_encoder_config
        Optional image encoder configuration.
    text_encoder_config
        Optional text encoder configuration.
    epoch
        Epoch used to build the model file name when path is None.
    new_size
        When given, the network image size is adjusted after loading the weights.
    new_context_length
        When given, the network context length is adjusted after loading the weights.
    inference
        When True, gradients are disabled on all parameters and the network is set to eval mode.
    st
        Load from a safetensors file instead of a 'pt' file.
    dtype
        Optional data type the network is converted to.

    Returns
    -------
    A tuple containing two elements:
        - The loaded network.
        - Model info with the signature, the RGB stats and the configuration stored in the
          model file (None when the stored configuration is missing or empty).

    Raises
    ------
    AssertionError
        If st is True and safetensors is not installed or the file holds no metadata.
    """

    if path is None:
        _network_name = lib.get_image_text_network_name(
            network,
            tag=tag,
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            tokenizer=tokenizer,
        )
        path = model_path(_network_name, epoch=epoch, file_format="safetensors" if st is True else "pt")

    logger.info(f"Loading model from {path} on device {device}...")

    loaded_config: Optional[dict[str, Any]]
    if st is True:
        assert _HAS_SAFETENSORS, "'pip install safetensors' to use .safetensors"
        # Only the metadata is read here, the tensors are loaded once the network is built
        with safetensors.safe_open(path, framework="pt", device="cpu") as handle:
            extra_files = handle.metadata()

        assert extra_files is not None

        signature: SignatureType = json.loads(extra_files["signature"])
        rgb_stats: RGBType = json.loads(extra_files["rgb_stats"])
        if "config" in extra_files and len(extra_files["config"]) > 0:
            loaded_config = json.loads(extra_files["config"])
        else:
            loaded_config = {}

    else:
        model_dict: dict[str, Any] = torch.load(path, map_location=device, weights_only=True)
        signature = model_dict["signature"]
        rgb_stats = model_dict["rgb_stats"]
        loaded_config = model_dict.get("config", {})

    # A stored null/None configuration is treated like an empty one, otherwise the
    # emptiness check at the end of this function fails on len(None)
    if loaded_config is None:
        loaded_config = {}

    size = lib.get_size_from_signature(signature)
    input_channels = lib.get_channels_from_signature(signature)
    context_length = lib.get_context_length_from_signature(signature)
    registered_config = registry.all_nets[network.lower()].config  # type: ignore[misc]

    # Copy so that caller overrides never leak into the configuration reported back
    checkpoint_config = loaded_config.copy()
    if config is not None:
        checkpoint_config.update(config)

    model_config = lib.get_image_text_model_config(
        registered_config,
        checkpoint_config,
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=embed_dim,
        tokenizer=tokenizer,
        image_encoder_config=image_encoder_config,
        text_encoder_config=text_encoder_config,
        input_channels=input_channels,
        image_size=size,
        context_length=context_length,
    )

    net = registry.net_factory(network, config=model_config)
    if st is True:
        model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
        net.load_state_dict(model_state)
    else:
        net.load_state_dict(model_dict["state"])

    if new_size is not None:
        net.adjust_image_size(new_size)
    if new_context_length is not None:
        net.adjust_context_length(new_context_length)

    net.to(device)
    if dtype is not None:
        net.to(dtype)
    if inference is True:
        for param in net.parameters():
            param.requires_grad_(False)

        net.eval()

    if len(loaded_config) == 0:
        custom_config = None
    else:
        custom_config = loaded_config
        logger.debug(f"Model loaded with custom config: {custom_config}")

    return (net, ModelInfo(signature, rgb_stats, custom_config))
390
+
391
+
392
def load_pretrained_model(
    weights: str,
    *,
    dst: Optional[str | Path] = None,
    file_format: FileFormatType = "pt",
    inference: bool = False,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    custom_config: Optional[dict[str, Any]] = None,
    progress_bar: bool = True,
) -> tuple[BaseNet, ModelInfo]:
    """
    Download (when needed) and load pretrained weights from the model registry

    Parameters
    ----------
    weights
        Registry name of the pretrained weights.
    dst
        Path the weights file is downloaded to and loaded from.
        When None, the file is placed in the default models directory.
    file_format
        Format of the weights file.
    inference
        Whether to prepare the model for inference.
    device
        Target device, CPU when None.
    dtype
        Optional data type the model is converted to.
    custom_config
        Model configuration entries that override or extend the registry configuration.
    progress_bar
        Whether to show a progress bar while downloading.

    Returns
    -------
    A tuple containing two elements:
        - The network loaded with the pretrained weights.
        - Model info holding the signature and the RGB stats.

    Raises
    ------
    ValueError
        If the registry entry is not an image-text model.

    Notes
    -----
    - The models directory is created when it does not exist.
    - The weights are downloaded when they are not available locally.
    - With inference=True the model is put in evaluation mode with gradients disabled.

    Examples
    --------
    >>> net, model_info = load_pretrained_model("openai_clip_vit_l14")
    >>> net, model_info = load_pretrained_model(
    ...     "openai_clip_vit_l14", inference=True, device=torch.device("cuda"))
    """

    download_model_by_weights(weights, dst=dst, file_format=file_format, progress_bar=progress_bar)
    model_metadata = registry.get_pretrained_metadata(weights)
    model_file, _ = lib.get_pretrained_model_url(weights, file_format)
    if dst is None:
        dst = settings.MODELS_DIR.joinpath(model_file)

    if device is None:
        device = torch.device("cpu")

    if model_metadata["task"] != Task.IMAGE_TEXT:
        raise ValueError(f"Unknown model type: {model_metadata['task']}")

    net_metadata = model_metadata["net"]
    image_encoder, image_config = _split_encoder_metadata(net_metadata.get("image_encoder", None))
    text_encoder, text_config = _split_encoder_metadata(net_metadata.get("text_encoder", None))

    # Encoder configurations from the registry come first, caller entries override them
    pretrained_config: dict[str, Any] = {
        key: encoder_config
        for (key, encoder_config) in (("image", image_config), ("text", text_config))
        if encoder_config is not None
    }
    if custom_config is not None:
        pretrained_config.update(custom_config)

    config = pretrained_config if len(pretrained_config) > 0 else None

    return load_model(
        device,
        net_metadata["network"],
        path=dst,
        config=config,
        tag=net_metadata.get("tag", None),
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=net_metadata.get("embed_dim", None),
        tokenizer=net_metadata.get("tokenizer", None),
        inference=inference,
        st=file_format == "safetensors",
        dtype=dtype,
    )
489
+
490
+
491
def load_pretrained_model_and_transform(
    weights: str,
    *,
    dst: Optional[str | Path] = None,
    file_format: FileFormatType = "pt",
    inference: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    custom_config: Optional[dict[str, Any]] = None,
    progress_bar: bool = True,
    classification_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[BaseNet, ModelInfo, Callable[..., torch.Tensor]]:
    """
    Load pretrained weights and build the inference transform that goes with them

    Convenience wrapper around load_pretrained_model for the usual inference flow,
    the transform is built with inference_preset from the model size and RGB stats.

    Parameters
    ----------
    weights
        Registry name of the pretrained weights.
    dst
        Path the weights file is downloaded to and loaded from.
    file_format
        Format of the weights file.
    inference
        Whether to prepare the model for inference.
    device
        Target device.
    dtype
        Optional data type the model is converted to.
    custom_config
        Model configuration entries that override or extend the registry configuration.
    progress_bar
        Whether to show a progress bar while downloading.
    classification_kwargs
        Optional keyword arguments passed on to inference_preset.

    Returns
    -------
    A tuple containing three elements:
        - The network loaded with the pretrained weights.
        - Model info holding the signature and the RGB stats.
        - The inference transform.
    """

    (net, model_info) = load_pretrained_model(
        weights,
        dst=dst,
        file_format=file_format,
        inference=inference,
        device=device,
        dtype=dtype,
        custom_config=custom_config,
        progress_bar=progress_bar,
    )

    # Work on a copy so the caller's mapping is never touched
    preset_kwargs: dict[str, Any] = {}
    if classification_kwargs is not None:
        preset_kwargs.update(classification_kwargs)

    size = lib.get_size_from_signature(model_info.signature)
    transform = inference_preset(size, model_info.rgb_stats, **preset_kwargs)

    return (net, model_info, transform)
554
+
555
+
556
def load_pretrained_tokenizer(weights: str, *, download: bool = True, **kwargs: Any) -> Tokenizer:
    """
    Build the tokenizer that belongs to pretrained weights

    Parameters
    ----------
    weights
        Registry name of the pretrained weights.
    download
        Whether tokenizer files may be downloaded when needed.
    kwargs
        Tokenizer keyword arguments that override or extend the registry configuration.

    Returns
    -------
    A tokenizer configured for the pretrained weights.

    Raises
    ------
    ValueError
        If neither the weights metadata nor the network defines a tokenizer.
    """

    model_metadata = registry.get_pretrained_metadata(weights)
    net_metadata = model_metadata["net"]

    tokenizer_name = net_metadata.get("tokenizer", None)
    if tokenizer_name is None:
        # Fall back to the default tokenizer of the network
        tokenizer_name = registry.get_default_tokenizer(net_metadata["network"])
        if tokenizer_name is None:
            raise ValueError(f"Tokenizer is not available for {weights}")

    tokenizer_kwargs = {"context_length": model_metadata["context_length"], **kwargs}
    if download is True:
        default_source = get_hf_tokenizer_source(tokenizer_name)
        if default_source is not None:
            # Download from the source get_tokenizer ends up using once caller overrides are applied
            download_hf_tokenizer(tokenizer_kwargs.get("source", default_source))

    return get_tokenizer(tokenizer_name, **tokenizer_kwargs)
590
+
591
+
592
def download_model_by_weights(
    weights: str, *, dst: Optional[str | Path] = None, file_format: FileFormatType = "pt", progress_bar: bool = True
) -> None:
    """
    Download pretrained weights from the model registry

    Parameters
    ----------
    weights
        Registry name of the pretrained weights.
    dst
        Destination path of the downloaded file.
        When None, the file is placed in the default models directory.
    file_format
        Format of the weights file to download.
    progress_bar
        Whether to show a progress bar while downloading.

    Raises
    ------
    ValueError
        If the requested format is not available for the given weights.
    """

    if settings.MODELS_DIR.exists() is False:
        logger.info(f"Creating {settings.MODELS_DIR} directory...")
        # Another process may create the directory between the check and this call,
        # exist_ok avoids a FileExistsError in that case
        settings.MODELS_DIR.mkdir(parents=True, exist_ok=True)

    model_metadata = registry.get_pretrained_metadata(weights)
    if file_format not in model_metadata["formats"]:
        available_formats = ", ".join(model_metadata["formats"].keys())
        raise ValueError(
            f"Requested format '{file_format}' not available for {weights}, available formats are: {available_formats}"
        )

    model_file, url = lib.get_pretrained_model_url(weights, file_format)
    if dst is None:
        dst = settings.MODELS_DIR.joinpath(model_file)

    cli.download_file(url, dst, model_metadata["formats"][file_format]["sha256"], progress_bar=progress_bar)