birder-clip 0.0.2.dev1__tar.gz → 0.0.2.dev3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/PKG-INFO +1 -1
- birder_clip-0.0.2.dev3/birder_clip/__init__.py +20 -0
- birder_clip-0.0.2.dev3/birder_clip/common/fs_ops.py +610 -0
- birder_clip-0.0.2.dev3/birder_clip/common/lib.py +154 -0
- birder_clip-0.0.2.dev3/birder_clip/common/training_cli.py +562 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/conf/settings.py +8 -0
- birder_clip-0.0.2.dev3/birder_clip/data/datasets/csv.py +90 -0
- birder_clip-0.0.2.dev3/birder_clip/data/datasets/fake.py +23 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/inference/zero_shot.py +8 -4
- birder_clip-0.0.2.dev3/birder_clip/loss/__init__.py +5 -0
- birder_clip-0.0.2.dev3/birder_clip/loss/contrastive.py +46 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/model_registry/__init__.py +2 -2
- birder_clip-0.0.2.dev3/birder_clip/model_registry/manifest.py +52 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/model_registry/model_registry.py +92 -8
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/net/base.py +22 -1
- birder_clip-0.0.2.dev3/birder_clip/net/clip.py +263 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/net/text/__init__.py +0 -2
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/net/text/base.py +13 -1
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/net/text/transformer.py +19 -4
- birder_clip-0.0.2.dev3/birder_clip/py.typed +0 -0
- birder_clip-0.0.2.dev3/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/scripts/zero_shot.py +246 -13
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/tokenizers/__init__.py +6 -0
- birder_clip-0.0.2.dev3/birder_clip/tokenizers/base.py +54 -0
- birder_clip-0.0.2.dev3/birder_clip/tokenizers/hf.py +86 -0
- birder_clip-0.0.2.dev3/birder_clip/tokenizers/registry.py +56 -0
- birder_clip-0.0.2.dev1/birder_clip/tokenizers/openai_clip_bpe.py → birder_clip-0.0.2.dev3/birder_clip/tokenizers/simple_tokenizer.py +8 -44
- birder_clip-0.0.2.dev3/birder_clip/tools/__init__.py +0 -0
- birder_clip-0.0.2.dev3/birder_clip/tools/__main__.py +30 -0
- birder_clip-0.0.2.dev3/birder_clip/tools/download_tokenizer.py +65 -0
- birder_clip-0.0.2.dev3/birder_clip/tools/show_iterator.py +172 -0
- birder_clip-0.0.2.dev3/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip.egg-info/PKG-INFO +1 -1
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip.egg-info/SOURCES.txt +16 -1
- birder_clip-0.0.2.dev3/tests/test_common.py +90 -0
- birder_clip-0.0.2.dev3/tests/test_datasets.py +57 -0
- birder_clip-0.0.2.dev3/tests/test_loss.py +47 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/tests/test_model_registry.py +0 -1
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/tests/test_net.py +33 -3
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/tests/test_tokenizers.py +8 -8
- birder_clip-0.0.2.dev1/birder_clip/__init__.py +0 -7
- birder_clip-0.0.2.dev1/birder_clip/common/fs_ops.py +0 -111
- birder_clip-0.0.2.dev1/birder_clip/common/lib.py +0 -65
- birder_clip-0.0.2.dev1/birder_clip/net/clip.py +0 -91
- birder_clip-0.0.2.dev1/birder_clip/tokenizers/base.py +0 -8
- birder_clip-0.0.2.dev1/birder_clip/tokenizers/registry.py +0 -29
- birder_clip-0.0.2.dev1/birder_clip/version.py +0 -1
- birder_clip-0.0.2.dev1/tests/test_common.py +0 -27
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/LICENSE +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/README.md +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev1/birder_clip/inference → birder_clip-0.0.2.dev3/birder_clip/data}/__init__.py +0 -0
- {birder_clip-0.0.2.dev1/birder_clip/scripts → birder_clip-0.0.2.dev3/birder_clip/data/datasets}/__init__.py +0 -0
- /birder_clip-0.0.2.dev1/birder_clip/py.typed → /birder_clip-0.0.2.dev3/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/net/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip.egg-info/requires.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/requirements/_requirements-dev.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/requirements/requirements.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev3}/setup.cfg +0 -0
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from birder.data.transforms.classification import inference_preset as inference_transform
|
|
2
|
+
|
|
3
|
+
from birder_clip.common.fs_ops import load_pretrained_model
|
|
4
|
+
from birder_clip.common.fs_ops import load_pretrained_model_and_transform
|
|
5
|
+
from birder_clip.common.fs_ops import load_pretrained_tokenizer
|
|
6
|
+
from birder_clip.common.lib import get_channels_from_signature
|
|
7
|
+
from birder_clip.common.lib import get_size_from_signature
|
|
8
|
+
from birder_clip.model_registry.model_registry import list_pretrained_models
|
|
9
|
+
from birder_clip.version import __version__
|
|
10
|
+
|
|
11
|
+
# Public names of the package root; this list controls what
# `from birder_clip import *` exports.
__all__ = [
    "inference_transform",
    "get_channels_from_signature",
    "get_size_from_signature",
    "list_pretrained_models",
    "load_pretrained_model",
    "load_pretrained_model_and_transform",
    "load_pretrained_tokenizer",
    "__version__",
]
|
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
from typing import NamedTuple
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from birder.common import cli
|
|
12
|
+
from birder.conf import settings
|
|
13
|
+
from birder.data.transforms.classification import RGBType
|
|
14
|
+
from birder.data.transforms.classification import inference_preset
|
|
15
|
+
|
|
16
|
+
from birder_clip.common import lib
|
|
17
|
+
from birder_clip.model_registry import Task
|
|
18
|
+
from birder_clip.model_registry import registry
|
|
19
|
+
from birder_clip.model_registry.manifest import EncoderMetadataType
|
|
20
|
+
from birder_clip.model_registry.manifest import FileFormatType
|
|
21
|
+
from birder_clip.net.base import BaseNet
|
|
22
|
+
from birder_clip.net.base import SignatureType
|
|
23
|
+
from birder_clip.tokenizers import Tokenizer
|
|
24
|
+
from birder_clip.tokenizers import get_tokenizer
|
|
25
|
+
from birder_clip.tokenizers.hf import download_hf_tokenizer
|
|
26
|
+
from birder_clip.tokenizers.hf import get_hf_tokenizer_source
|
|
27
|
+
from birder_clip.version import __version__
|
|
28
|
+
|
|
29
|
+
# safetensors is an optional dependency: record whether the import succeeded
# instead of failing at module import time.
try:
    import safetensors
    import safetensors.torch

    _HAS_SAFETENSORS = True
except ImportError:
    _HAS_SAFETENSORS = False

# Module-level logger named after this module
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ModelInfo(NamedTuple):
    """
    Metadata returned alongside a loaded model
    """

    # Model signature as stored with the weights
    signature: SignatureType
    # RGB statistics as stored with the weights
    rgb_stats: RGBType
    # Configuration stored with the weights, None when nothing was stored
    custom_config: Optional[dict[str, Any]] = None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def write_config(network_name: str, net: BaseNet, signature: SignatureType, rgb_stats: RGBType) -> None:
    """
    Write the network configuration as a JSON file in the models directory

    The configuration is produced by lib.get_image_text_network_config and is
    written to "<network_name>.json" under settings.MODELS_DIR.

    Parameters
    ----------
    network_name
        Base name of the JSON file (without extension).
    net
        Network passed to the configuration builder.
    signature
        Model signature passed to the configuration builder.
    rgb_stats
        RGB statistics passed to the configuration builder.
    """

    model_config = lib.get_image_text_network_config(net, signature, rgb_stats)
    config_file = settings.MODELS_DIR.joinpath(f"{network_name}.json")
    logger.info(f"Writing {config_file}")
    with config_file.open("w", encoding="utf-8") as handle:
        json.dump(model_config, handle, indent=2)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _split_encoder_metadata(encoder: Optional[EncoderMetadataType]) -> tuple[Optional[str], Optional[dict[str, Any]]]:
|
|
55
|
+
if encoder is None:
|
|
56
|
+
return (None, None)
|
|
57
|
+
if isinstance(encoder, str):
|
|
58
|
+
return (encoder, None)
|
|
59
|
+
|
|
60
|
+
if "network" not in encoder:
|
|
61
|
+
raise ValueError("Encoder metadata must include a 'network' field")
|
|
62
|
+
|
|
63
|
+
return (None, encoder) # type: ignore[return-value]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def model_path(
    network_name: str,
    *,
    epoch: Optional[int | str] = None,
    file_format: FileFormatType = "pt",
    states: bool = False,
) -> Path:
    """
    Build the path of a model (or training states) file in the models directory

    Parameters
    ----------
    network_name
        Base name of the file.
    epoch
        When given, appended to the name as "_<epoch>".
    file_format
        File extension used for model files.
    states
        When True the file name ends with "_states.pt" and file_format is ignored.

    Returns
    -------
    Path under settings.MODELS_DIR.
    """

    stem = network_name if epoch is None else f"{network_name}_{epoch}"
    suffix = "_states.pt" if states is True else f".{file_format}"

    return settings.MODELS_DIR.joinpath(f"{stem}{suffix}")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _checkpoint_states(
|
|
87
|
+
states_path: Path,
|
|
88
|
+
optimizer: Optional[torch.optim.Optimizer],
|
|
89
|
+
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
|
|
90
|
+
scaler: Optional[torch.amp.grad_scaler.GradScaler],
|
|
91
|
+
model_base: Optional[torch.nn.Module],
|
|
92
|
+
**extra_states: Optional[dict[str, Any]],
|
|
93
|
+
) -> None:
|
|
94
|
+
if optimizer is None and scheduler is None and scaler is None and model_base is None and len(extra_states) == 0:
|
|
95
|
+
return
|
|
96
|
+
|
|
97
|
+
kwargs = {}
|
|
98
|
+
if optimizer is not None:
|
|
99
|
+
kwargs["optimizer_state"] = optimizer.state_dict()
|
|
100
|
+
if scheduler is not None:
|
|
101
|
+
kwargs["scheduler_state"] = scheduler.state_dict()
|
|
102
|
+
if scaler is not None:
|
|
103
|
+
kwargs["scaler_state"] = scaler.state_dict()
|
|
104
|
+
if model_base is not None:
|
|
105
|
+
kwargs["model_base_state"] = model_base.state_dict()
|
|
106
|
+
kwargs.update({k: v for k, v in extra_states.items() if v is not None})
|
|
107
|
+
|
|
108
|
+
logger.info(f"Saving training states {states_path}...")
|
|
109
|
+
torch.save(kwargs, states_path)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class TrainingStates(NamedTuple):
    """
    Container for the training states stored next to a model checkpoint
    """

    optimizer_state: Optional[dict[str, Any]]
    scheduler_state: Optional[dict[str, Any]]
    scaler_state: Optional[dict[str, Any]]
    model_base_state: Optional[dict[str, Any]]
    ema_model_state: Optional[dict[str, Any]] = None
    extra_states: Optional[dict[str, Any]] = None

    @classmethod
    def empty(cls) -> "TrainingStates":
        """
        Return an instance with every field set to None
        """

        return cls(
            optimizer_state=None,
            scheduler_state=None,
            scaler_state=None,
            model_base_state=None,
            ema_model_state=None,
        )
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _load_states(states_path: Path, device: torch.device) -> TrainingStates:
    """
    Load training states from a states file

    Parameters
    ----------
    states_path
        File written by torch.save holding the training states.
    device
        Passed to torch.load as map_location.

    Returns
    -------
    Empty training states when the file does not exist. Otherwise the known
    states taken from the file, with every remaining entry collected in extra_states.
    """

    if not states_path.exists():
        logger.debug("Checkpoint training states not found, returning empty states")
        return TrainingStates.empty()

    logger.info(f"Loading states from {states_path} on device {device}...")
    states_dict: dict[str, Any] = torch.load(states_path, map_location=device, weights_only=True)

    # Known entries are popped, so whatever is left afterwards is an "extra" state
    known = {
        key: states_dict.pop(key, None)
        for key in ("optimizer_state", "scheduler_state", "scaler_state", "model_base_state")
    }

    return TrainingStates(extra_states=dict(states_dict), **known)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def checkpoint_model(
    network_name: str,
    epoch: int,
    net: torch.nn.Module,
    signature: SignatureType,
    rgb_stats: RGBType,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    scaler: Optional[torch.amp.grad_scaler.GradScaler],
    model_base: Optional[torch.nn.Module],
    *,
    external_config: Optional[dict[str, Any]] = None,
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Save a model checkpoint and its training states for the given epoch

    The model file holds the network state_dict, the package version, the
    network task, the signature and the RGB stats. external_config is stored
    under "config" only when it is not None. Training states are saved to a
    separate states file through _checkpoint_states.

    Parameters
    ----------
    network_name
        Base name used to build both file paths.
    epoch
        Epoch appended to the file names.
    net
        Network to save, its 'task' attribute is stored as well.
    signature
        Model signature stored in the checkpoint.
    rgb_stats
        RGB statistics stored in the checkpoint.
    optimizer
        Optional optimizer forwarded to the states file.
    scheduler
        Optional scheduler forwarded to the states file.
    scaler
        Optional gradient scaler forwarded to the states file.
    model_base
        Optional module forwarded to the states file.
    external_config
        Optional configuration stored in the model file.
    extra_states
        Additional named states forwarded to the states file.
    """

    path = model_path(network_name, epoch=epoch)
    states_path = model_path(network_name, epoch=epoch, states=True)

    checkpoint: dict[str, Any] = {
        "state": net.state_dict(),
        "birder_clip_version": __version__,
        "task": net.task,
        "signature": signature,
        "rgb_stats": rgb_stats,
    }
    if external_config is not None:
        checkpoint["config"] = external_config

    logger.info(f"Saving model checkpoint {path}...")
    torch.save(checkpoint, path)

    _checkpoint_states(states_path, optimizer, scheduler, scaler, model_base, **extra_states)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def clean_checkpoints(network_name: str, keep_last: int) -> None:
    """
    Remove old checkpoint and training states files of a network

    Candidate files are globbed from settings.BASE_DIR, ordered by modification
    time and the newest keep_last candidates are left out. Of the remaining
    candidates, only files whose name ends with "_<epoch>.pt" (or
    "_<epoch>_states.pt"), where the epoch has no leading zero, are deleted.

    Parameters
    ----------
    network_name
        Base name of the checkpoint files.
    keep_last
        Number of most recently modified candidates to keep.
    """

    # NOTE(review): with keep_last == 0 the slice [:-0] is empty, so nothing is
    # removed - confirm this is the intended behavior

    def _all_but_newest(candidates: list[Path]) -> list[Path]:
        oldest_first = sorted(candidates, key=lambda p: p.stat().st_mtime)
        return oldest_first[:-keep_last]

    epoch = "*[0-9]"
    models_glob = str(model_path(network_name, epoch=epoch))
    states_glob = str(model_path(network_name, epoch=epoch, states=True))
    model_pattern = re.compile(r".*_([1-9][0-9]*)\.pt$")
    states_pattern = re.compile(r".*_([1-9][0-9]*)_states\.pt$")

    for p in _all_but_newest(list(settings.BASE_DIR.glob(models_glob))):
        if model_pattern.search(str(p)) is None:
            continue

        logger.info(f"Removing checkpoint {p}...")
        p.unlink()

    for p in _all_but_newest(list(settings.BASE_DIR.glob(states_glob))):
        if states_pattern.search(str(p)) is None:
            continue

        logger.info(f"Removing checkpoint states {p}...")
        p.unlink()
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class CheckpointStates(NamedTuple):
    """
    Result of loading a training checkpoint
    """

    # Network built from the checkpoint
    net: BaseNet
    # RGB statistics read from the checkpoint
    rgb_stats: RGBType
    # Training states loaded from the accompanying states file
    training_states: TrainingStates
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def load_checkpoint(
    device: torch.device,
    network: str,
    *,
    config: Optional[dict[str, Any]] = None,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
    image_encoder_config: Optional[dict[str, Any]] = None,
    text_encoder_config: Optional[dict[str, Any]] = None,
    epoch: Optional[int] = None,
    new_size: Optional[tuple[int, int]] = None,
    new_context_length: Optional[int] = None,
    strict: bool = True,
) -> CheckpointStates:
    """
    Load a training checkpoint together with its training states

    The network is rebuilt from the registered configuration, the configuration
    stored in the checkpoint and the given overrides. When the states file holds
    a "model_base_state", those weights are loaded into the network and the
    checkpoint "state" is returned as ema_model_state. Otherwise the checkpoint
    "state" is loaded into the network.

    Parameters
    ----------
    device
        Device used as map_location and as the final network device.
    network
        Registered network name.
    config
        Configuration entries applied on top of the configuration stored in the checkpoint.
    tag
        Used to build the checkpoint file name.
    image_encoder
        Used to build the file name and the model configuration.
    text_encoder
        Used to build the file name and the model configuration.
    embed_dim
        Used to build the file name and the model configuration.
    tokenizer
        Used to build the file name and the model configuration.
    image_encoder_config
        Forwarded to the model configuration builder.
    text_encoder_config
        Forwarded to the model configuration builder.
    epoch
        Epoch of the checkpoint to load.
    new_size
        When given, net.adjust_image_size is called with it after loading.
    new_context_length
        When given, net.adjust_context_length is called with it after loading.
    strict
        Forwarded to load_state_dict.

    Returns
    -------
    The network, the RGB stats stored in the checkpoint and the training states.
    """

    network_name = lib.get_image_text_network_name(
        network,
        tag=tag,
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=embed_dim,
        tokenizer=tokenizer,
    )
    path = model_path(network_name, epoch=epoch)
    states_path = model_path(network_name, epoch=epoch, states=True)

    logger.info(f"Loading model from {path} on device {device}...")
    model_dict: dict[str, Any] = torch.load(path, map_location=device, weights_only=True)
    training_states = _load_states(states_path, device)

    signature: SignatureType = model_dict["signature"]
    rgb_stats: RGBType = model_dict["rgb_stats"]
    input_channels = lib.get_channels_from_signature(signature)
    size = lib.get_size_from_signature(signature)
    context_length = lib.get_context_length_from_signature(signature)
    logger.debug(f"Loaded model with RGB stats: {rgb_stats}")
    logger.debug(f"Loaded model input size is {size}")

    registered_config = registry.all_nets[network.lower()].config  # type: ignore[misc]

    # Work on a copy so that the caller overrides never touch the loaded checkpoint
    loaded_config = model_dict.get("config", {})
    if loaded_config is None:
        checkpoint_config = {}
    else:
        checkpoint_config = loaded_config.copy()

    if config is not None:
        checkpoint_config.update(config)

    model_config = lib.get_image_text_model_config(
        registered_config,
        checkpoint_config,
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=embed_dim,
        tokenizer=tokenizer,
        image_encoder_config=image_encoder_config,
        text_encoder_config=text_encoder_config,
        input_channels=input_channels,
        image_size=size,
        context_length=context_length,
    )
    net = registry.net_factory(network, config=model_config)

    base_state = training_states.model_base_state
    if base_state is None:
        net.load_state_dict(model_dict["state"], strict=strict)
    else:
        # The states file carries the base weights, the checkpoint weights are kept as the EMA state
        net.load_state_dict(base_state, strict=strict)
        training_states = training_states._replace(ema_model_state=model_dict["state"])

    if new_size is not None:
        net.adjust_image_size(new_size)
    if new_context_length is not None:
        net.adjust_context_length(new_context_length)

    net.to(device)

    return CheckpointStates(net, rgb_stats, training_states)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def load_model(
    device: torch.device,
    network: str,
    *,
    path: Optional[str | Path] = None,
    config: Optional[dict[str, Any]] = None,
    tag: Optional[str] = None,
    image_encoder: Optional[str] = None,
    text_encoder: Optional[str] = None,
    embed_dim: Optional[int] = None,
    tokenizer: Optional[str] = None,
    image_encoder_config: Optional[dict[str, Any]] = None,
    text_encoder_config: Optional[dict[str, Any]] = None,
    epoch: Optional[int] = None,
    new_size: Optional[tuple[int, int]] = None,
    new_context_length: Optional[int] = None,
    inference: bool,
    st: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> tuple[BaseNet, ModelInfo]:
    """
    Load a model from a "pt" checkpoint or a "safetensors" file

    Parameters
    ----------
    device
        Device the network is moved to after loading.
    network
        Registered network name.
    path
        File to load. When None, the path is derived from the network name,
        the naming arguments, the epoch and the file format.
    config
        Configuration entries applied on top of the configuration stored in the file.
    tag
        Used to build the file name when path is None.
    image_encoder
        Used to build the file name and the model configuration.
    text_encoder
        Used to build the file name and the model configuration.
    embed_dim
        Used to build the file name and the model configuration.
    tokenizer
        Used to build the file name and the model configuration.
    image_encoder_config
        Forwarded to the model configuration builder.
    text_encoder_config
        Forwarded to the model configuration builder.
    epoch
        Epoch used to build the file name when path is None.
    new_size
        When given, net.adjust_image_size is called with it after loading.
    new_context_length
        When given, net.adjust_context_length is called with it after loading.
    inference
        When True, gradients are disabled for all parameters and the network is set to eval mode.
    st
        Load a safetensors file instead of a "pt" checkpoint.
    dtype
        When given, the network is converted to this data type.

    Returns
    -------
    A tuple containing two elements:
    - The loaded network.
    - Model info with the signature, the RGB stats and the configuration stored
      in the file (None when no configuration was stored).
    """

    if path is None:
        _network_name = lib.get_image_text_network_name(
            network,
            tag=tag,
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            embed_dim=embed_dim,
            tokenizer=tokenizer,
        )
        path = model_path(_network_name, epoch=epoch, file_format="safetensors" if st is True else "pt")

    logger.info(f"Loading model from {path} on device {device}...")

    if st is True:
        assert _HAS_SAFETENSORS, "'pip install safetensors' to use .safetensors"

        # Only the metadata is read here, the tensors are loaded after the network is built
        with safetensors.safe_open(path, framework="pt", device="cpu") as handle:
            extra_files = handle.metadata()
            assert extra_files is not None

        signature: SignatureType = json.loads(extra_files["signature"])
        rgb_stats: RGBType = json.loads(extra_files["rgb_stats"])
        if "config" in extra_files and len(extra_files["config"]) > 0:
            loaded_config: dict[str, Any] = json.loads(extra_files["config"])
        else:
            loaded_config = {}

    else:
        model_dict: dict[str, Any] = torch.load(path, map_location=device, weights_only=True)
        signature = model_dict["signature"]
        rgb_stats = model_dict["rgb_stats"]
        loaded_config = model_dict.get("config", {})

    size = lib.get_size_from_signature(signature)
    input_channels = lib.get_channels_from_signature(signature)
    context_length = lib.get_context_length_from_signature(signature)
    registered_config = registry.all_nets[network.lower()].config  # type: ignore[misc]

    # Work on a copy so that the caller overrides never touch the loaded configuration
    checkpoint_config = {} if loaded_config is None else loaded_config.copy()
    if config is not None:
        checkpoint_config.update(config)

    model_config = lib.get_image_text_model_config(
        registered_config,
        checkpoint_config,
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=embed_dim,
        tokenizer=tokenizer,
        image_encoder_config=image_encoder_config,
        text_encoder_config=text_encoder_config,
        input_channels=input_channels,
        image_size=size,
        context_length=context_length,
    )

    net = registry.net_factory(network, config=model_config)
    if st is True:
        model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
        net.load_state_dict(model_state)
    else:
        net.load_state_dict(model_dict["state"])

    if new_size is not None:
        net.adjust_image_size(new_size)
    if new_context_length is not None:
        net.adjust_context_length(new_context_length)

    net.to(device)
    if dtype is not None:
        net.to(dtype)
    if inference is True:
        for param in net.parameters():
            param.requires_grad_(False)

        net.eval()

    # A checkpoint may store "config" as None (handled above as well), a truthiness
    # test covers both None and an empty mapping where len() would raise on None
    if not loaded_config:
        custom_config = None
    else:
        custom_config = loaded_config
        logger.debug(f"Model loaded with custom config: {custom_config}")

    return (net, ModelInfo(signature, rgb_stats, custom_config))
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def load_pretrained_model(
    weights: str,
    *,
    dst: Optional[str | Path] = None,
    file_format: FileFormatType = "pt",
    inference: bool = False,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    custom_config: Optional[dict[str, Any]] = None,
    progress_bar: bool = True,
) -> tuple[BaseNet, ModelInfo]:
    """
    Download (when needed) and load a pretrained model

    Parameters
    ----------
    weights
        Name of the pretrained weights in the model registry.
    dst
        Path the weights are downloaded to and loaded from.
        If None, the file is placed in the default models directory.
    file_format
        Format of the weights file.
    inference
        Whether to prepare the model for inference mode.
    device
        Device to load the model on, CPU when None.
    dtype
        Data type for model parameters and computations.
    custom_config
        Additional model configuration that overrides or extends the predefined configuration.
    progress_bar
        Whether to display a progress bar during file download.

    Returns
    -------
    A tuple containing two elements:
    - A PyTorch module (neural network model) loaded with pretrained weights.
    - Model info containing signature and RGB stats.

    Raises
    ------
    ValueError
        If the registered task of the weights is not an image-text task.

    Notes
    -----
    - Creates the models directory if it does not exist.
    - Downloads the model weights if not already present locally.
    - When inference is True, the model is set to evaluation mode with gradient calculation disabled.

    Examples
    --------
    >>> net, model_info = load_pretrained_model("openai_clip_vit_l14")
    >>> net, model_info = load_pretrained_model(
    ...     "openai_clip_vit_l14", inference=True, device=torch.device("cuda"))
    """

    download_model_by_weights(weights, dst=dst, file_format=file_format, progress_bar=progress_bar)
    model_metadata = registry.get_pretrained_metadata(weights)
    (model_file, _) = lib.get_pretrained_model_url(weights, file_format)
    dst = settings.MODELS_DIR.joinpath(model_file) if dst is None else dst
    device = torch.device("cpu") if device is None else device

    if model_metadata["task"] != Task.IMAGE_TEXT:
        raise ValueError(f"Unknown model type: {model_metadata['task']}")

    net_metadata = model_metadata["net"]
    image_encoder, image_config = _split_encoder_metadata(net_metadata.get("image_encoder", None))
    text_encoder, text_config = _split_encoder_metadata(net_metadata.get("text_encoder", None))

    # Encoder configs from the registry come first, caller overrides are applied on top
    pretrained_config: dict[str, Any] = {}
    for key, encoder_config in (("image", image_config), ("text", text_config)):
        if encoder_config is not None:
            pretrained_config[key] = encoder_config

    if custom_config is not None:
        pretrained_config.update(custom_config)

    config = pretrained_config if len(pretrained_config) > 0 else None

    return load_model(
        device,
        net_metadata["network"],
        path=dst,
        config=config,
        tag=net_metadata.get("tag", None),
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        embed_dim=net_metadata.get("embed_dim", None),
        tokenizer=net_metadata.get("tokenizer", None),
        inference=inference,
        st=file_format == "safetensors",
        dtype=dtype,
    )
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def load_pretrained_model_and_transform(
    weights: str,
    *,
    dst: Optional[str | Path] = None,
    file_format: FileFormatType = "pt",
    inference: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    custom_config: Optional[dict[str, Any]] = None,
    progress_bar: bool = True,
    classification_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[BaseNet, ModelInfo, Callable[..., torch.Tensor]]:
    """
    Load a pretrained model together with its inference transform

    Convenience wrapper around load_pretrained_model that also builds the
    transform with inference_preset, using the size from the model signature
    and the model RGB stats.

    Parameters
    ----------
    weights
        Name of the pretrained weights in the model registry.
    dst
        Path the weights are downloaded to and loaded from.
    file_format
        Format of the weights file.
    inference
        Whether to prepare the model for inference mode.
    device
        Device to load the model on.
    dtype
        Data type for model parameters and computations.
    custom_config
        Additional model configuration that overrides or extends the predefined configuration.
    progress_bar
        Whether to display a progress bar during file download.
    classification_kwargs
        Optional keyword arguments forwarded to inference_preset.

    Returns
    -------
    A tuple containing three elements:
    - A PyTorch module (neural network model) loaded with pretrained weights.
    - Model info containing signature and RGB stats.
    - The inference transform built by inference_preset.
    """

    (net, model_info) = load_pretrained_model(
        weights,
        dst=dst,
        file_format=file_format,
        inference=inference,
        device=device,
        dtype=dtype,
        custom_config=custom_config,
        progress_bar=progress_bar,
    )

    # Copy the caller mapping before unpacking it
    if classification_kwargs is None:
        classification_args = {}
    else:
        classification_args = dict(classification_kwargs)

    size = lib.get_size_from_signature(model_info.signature)
    transform = inference_preset(size, model_info.rgb_stats, **classification_args)

    return (net, model_info, transform)
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
def load_pretrained_tokenizer(weights: str, *, download: bool = True, **kwargs: Any) -> Tokenizer:
    """
    Load the tokenizer that belongs to the given pretrained weights

    The tokenizer name is taken from the weights metadata and falls back to the
    registry default tokenizer of the network.

    Parameters
    ----------
    weights
        Name of the pretrained weights in the model registry.
    download
        Whether to download tokenizer files when needed.
    kwargs
        Additional tokenizer keyword arguments that override or extend the predefined configuration.

    Returns
    -------
    A tokenizer configured for the pretrained weights.

    Raises
    ------
    ValueError
        If neither the metadata nor the registry provides a tokenizer name.
    """

    model_metadata = registry.get_pretrained_metadata(weights)
    net_metadata = model_metadata["net"]

    tokenizer_name = net_metadata.get("tokenizer", None)
    if tokenizer_name is None:
        tokenizer_name = registry.get_default_tokenizer(net_metadata["network"])

    if tokenizer_name is None:
        raise ValueError(f"Tokenizer is not available for {weights}")

    tokenizer_kwargs = {"context_length": model_metadata["context_length"], **kwargs}
    if download is True and (hf_source := get_hf_tokenizer_source(tokenizer_name)) is not None:
        # Match the source that get_tokenizer will use after applying caller overrides
        hf_source = tokenizer_kwargs.get("source", hf_source)
        download_hf_tokenizer(hf_source)

    return get_tokenizer(tokenizer_name, **tokenizer_kwargs)
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def download_model_by_weights(
    weights: str, *, dst: Optional[str | Path] = None, file_format: FileFormatType = "pt", progress_bar: bool = True
) -> None:
    """
    Download the file of pretrained weights from the model registry

    Parameters
    ----------
    weights
        Name of the pretrained weights in the model registry.
    dst
        Destination path of the download.
        If None, the file is placed in the default models directory.
    file_format
        Format of the weights file to download.
    progress_bar
        Whether to display a progress bar during file download.

    Raises
    ------
    ValueError
        If the requested format is not listed in the weights metadata.
    """

    if settings.MODELS_DIR.exists() is False:
        logger.info(f"Creating {settings.MODELS_DIR} directory...")

        # exist_ok guards against another process creating the directory
        # between the check above and this call
        settings.MODELS_DIR.mkdir(parents=True, exist_ok=True)

    model_metadata = registry.get_pretrained_metadata(weights)
    if file_format not in model_metadata["formats"]:
        available_formats = ", ".join(model_metadata["formats"].keys())
        raise ValueError(
            f"Requested format '{file_format}' not available for {weights}, available formats are: {available_formats}"
        )

    model_file, url = lib.get_pretrained_model_url(weights, file_format)
    if dst is None:
        dst = settings.MODELS_DIR.joinpath(model_file)

    cli.download_file(url, dst, model_metadata["formats"][file_format]["sha256"], progress_bar=progress_bar)
|