birder-clip 0.0.2.dev6__tar.gz → 0.0.2.dev7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/PKG-INFO +2 -2
- birder_clip-0.0.2.dev7/birder_clip/inference/data_parallel.py +118 -0
- birder_clip-0.0.2.dev7/birder_clip/inference/image_embeddings.py +63 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/inference/zero_shot.py +40 -33
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/clip.py +23 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/coca.py +66 -0
- birder_clip-0.0.2.dev7/birder_clip/scripts/__main__.py +25 -0
- birder_clip-0.0.2.dev7/birder_clip/scripts/embed_images.py +432 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/scripts/zero_shot.py +28 -7
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/convert_model.py +165 -5
- birder_clip-0.0.2.dev7/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/PKG-INFO +2 -2
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/SOURCES.txt +5 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/requires.txt +1 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/requirements/_requirements-dev.txt +1 -1
- birder_clip-0.0.2.dev7/tests/test_inference.py +143 -0
- birder_clip-0.0.2.dev6/birder_clip/version.py +0 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/LICENSE +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/README.md +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/fs_ops.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/lib.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/training_cli.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/training_utils.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/conf/settings.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/csv.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/fake.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/webdataset.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/loss/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/loss/coca.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/loss/contrastive.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/manifest.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/model_registry.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/base.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/base.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/conditioned_decoder.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/encoder.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/hf.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/prefix_decoder.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/py.typed +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/scripts/train.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/base.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/hf.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/openvision.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/registry.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/__main__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/download_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/list_models.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/model_info.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/show_iterator.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/stats.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/requirements/requirements.txt +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_common.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_datasets.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_loss.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_model_registry.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_net.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_net_text.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_tokenizers.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder_clip
|
|
3
|
-
Version: 0.0.2.
|
|
3
|
+
Version: 0.0.2.dev7
|
|
4
4
|
Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -38,7 +38,7 @@ Requires-Dist: bandit~=1.9.4; extra == "dev"
|
|
|
38
38
|
Requires-Dist: black~=26.5.0; extra == "dev"
|
|
39
39
|
Requires-Dist: build~=1.5.0; extra == "dev"
|
|
40
40
|
Requires-Dist: bumpver~=2026.1132; extra == "dev"
|
|
41
|
-
Requires-Dist: coverage~=7.14.
|
|
41
|
+
Requires-Dist: coverage~=7.14.2; extra == "dev"
|
|
42
42
|
Requires-Dist: debugpy; extra == "dev"
|
|
43
43
|
Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
|
|
44
44
|
Requires-Dist: flake8~=7.3.0; extra == "dev"
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Inference-optimized multi-GPU parallelization for image-text models
|
|
3
|
+
|
|
4
|
+
This module provides ZeroShotInferenceDataParallel, a CLIP-style zero-shot
|
|
5
|
+
specialization of Birder's InferenceDataParallel.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from birder.inference.data_parallel import InferenceDataParallel
|
|
12
|
+
|
|
13
|
+
from birder_clip.inference.zero_shot import ZeroShotInference
|
|
14
|
+
from birder_clip.net.base import BaseNet
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ZeroShotInferenceDataParallel(InferenceDataParallel):
|
|
18
|
+
"""
|
|
19
|
+
Distributes zero-shot image inference batches across multiple GPUs
|
|
20
|
+
|
|
21
|
+
This wrapper scatters the image batch across devices and keeps a replicated
|
|
22
|
+
copy of the zero-shot text embeddings on each device. Each replica computes
|
|
23
|
+
image embeddings and zero-shot logits locally before outputs are gathered.
|
|
24
|
+
|
|
25
|
+
Important
|
|
26
|
+
---------
|
|
27
|
+
This class assumes the model is already configured for inference mode
|
|
28
|
+
(i.e., loaded with inference=True in load_model or manually set to eval mode
|
|
29
|
+
with requires_grad=False on all parameters).
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
module: BaseNet,
|
|
35
|
+
text_embeddings: torch.Tensor,
|
|
36
|
+
device_ids: Optional[list[int]] = None,
|
|
37
|
+
output_device: Optional[int | str | torch.device] = None,
|
|
38
|
+
compile_replicas: bool = False,
|
|
39
|
+
compile_methods: Optional[list[str]] = None,
|
|
40
|
+
compile_mode: Optional[str] = None,
|
|
41
|
+
) -> None:
|
|
42
|
+
if compile_methods is None:
|
|
43
|
+
compile_methods = ["encode_image", "forward_logits"]
|
|
44
|
+
|
|
45
|
+
super().__init__(
|
|
46
|
+
module,
|
|
47
|
+
device_ids=device_ids,
|
|
48
|
+
output_device=output_device,
|
|
49
|
+
compile_replicas=compile_replicas,
|
|
50
|
+
compile_methods=compile_methods,
|
|
51
|
+
compile_mode=compile_mode,
|
|
52
|
+
)
|
|
53
|
+
self.set_text_embeddings(text_embeddings)
|
|
54
|
+
|
|
55
|
+
def set_text_embeddings(self, text_embeddings: torch.Tensor) -> None:
|
|
56
|
+
if text_embeddings.ndim != 2:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"text_embeddings must be a 2D tensor of shape (num_classes, embedding_size), "
|
|
59
|
+
f"got shape {text_embeddings.size()}"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
self.text_embeddings = [
|
|
63
|
+
text_embeddings.to(f"cuda:{device_id}", non_blocking=True) for device_id in self.device_ids
|
|
64
|
+
]
|
|
65
|
+
self.inference_modules = [
|
|
66
|
+
ZeroShotInference(replica, embeddings) for replica, embeddings in zip(self.replicas, self.text_embeddings)
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
def forward( # type: ignore[override] # pylint: disable=arguments-differ
|
|
70
|
+
self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False
|
|
71
|
+
) -> torch.Tensor:
|
|
72
|
+
"""
|
|
73
|
+
Run zero-shot inference distributed across GPUs
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
inputs
|
|
78
|
+
Input image batch to process.
|
|
79
|
+
tta
|
|
80
|
+
Run inference with oversampling.
|
|
81
|
+
return_logits
|
|
82
|
+
If True, return raw logits instead of probabilities after softmax.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
if len(self.device_ids) == 1:
|
|
86
|
+
output = self.inference_modules[0](
|
|
87
|
+
inputs,
|
|
88
|
+
tta=tta,
|
|
89
|
+
return_logits=return_logits,
|
|
90
|
+
)
|
|
91
|
+
return self._gather([output])
|
|
92
|
+
|
|
93
|
+
scattered = self._scatter(inputs, {})
|
|
94
|
+
|
|
95
|
+
outputs = []
|
|
96
|
+
for inference, (input_chunk, _), device_id in zip(self.inference_modules, scattered, self.device_ids):
|
|
97
|
+
if input_chunk is not None and input_chunk.size(0) > 0:
|
|
98
|
+
with torch.cuda.device(device_id):
|
|
99
|
+
output = inference(
|
|
100
|
+
input_chunk,
|
|
101
|
+
tta=tta,
|
|
102
|
+
return_logits=return_logits,
|
|
103
|
+
)
|
|
104
|
+
outputs.append(output)
|
|
105
|
+
else:
|
|
106
|
+
outputs.append(None)
|
|
107
|
+
|
|
108
|
+
return self._gather(outputs)
|
|
109
|
+
|
|
110
|
+
def __repr__(self) -> str:
|
|
111
|
+
return (
|
|
112
|
+
f"ZeroShotInferenceDataParallel(\n"
|
|
113
|
+
f" devices={self.device_ids},\n"
|
|
114
|
+
f" output_device={self.output_device},\n"
|
|
115
|
+
f" src_device={self.src_device},\n"
|
|
116
|
+
f" text_embeddings_shape={tuple(self.text_embeddings[0].shape)}\n"
|
|
117
|
+
f")"
|
|
118
|
+
)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import numpy.typing as npt
|
|
7
|
+
import torch
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from birder_clip.net.base import BaseNet
|
|
12
|
+
|
|
13
|
+
DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def infer_dataloader_iter(
|
|
17
|
+
device: torch.device,
|
|
18
|
+
net: BaseNet,
|
|
19
|
+
dataloader: DataLoader,
|
|
20
|
+
model_dtype: torch.dtype = torch.float32,
|
|
21
|
+
amp: bool = False,
|
|
22
|
+
amp_dtype: Optional[torch.dtype] = None,
|
|
23
|
+
num_samples: Optional[int] = None,
|
|
24
|
+
chunk_size: Optional[float] = None,
|
|
25
|
+
) -> Iterator[DataloaderInferenceResult]:
|
|
26
|
+
if chunk_size is None:
|
|
27
|
+
chunk_size = float("inf")
|
|
28
|
+
|
|
29
|
+
net.to(device, dtype=model_dtype)
|
|
30
|
+
embeddings_list: list[npt.NDArray[np.float32]] = []
|
|
31
|
+
sample_paths: list[str] = []
|
|
32
|
+
sample_count = 0
|
|
33
|
+
with tqdm(total=num_samples, initial=0, unit="images", unit_scale=True, leave=False) as progress:
|
|
34
|
+
for file_paths, inputs, _targets in dataloader:
|
|
35
|
+
batch_size = inputs.size(0)
|
|
36
|
+
|
|
37
|
+
# Inference
|
|
38
|
+
inputs = inputs.to(device, dtype=model_dtype)
|
|
39
|
+
with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
|
|
40
|
+
embeddings = net.encode_image(inputs, normalize=True)
|
|
41
|
+
embeddings = embeddings.cpu().float().numpy()
|
|
42
|
+
|
|
43
|
+
embeddings_list.append(embeddings)
|
|
44
|
+
|
|
45
|
+
# Set sample list
|
|
46
|
+
sample_paths.extend(file_paths)
|
|
47
|
+
|
|
48
|
+
# Update progress bar
|
|
49
|
+
progress.update(n=batch_size)
|
|
50
|
+
|
|
51
|
+
# Yield results when we reach chunk_size
|
|
52
|
+
sample_count += batch_size
|
|
53
|
+
if sample_count >= chunk_size:
|
|
54
|
+
with tqdm.external_write_mode(file=sys.stderr):
|
|
55
|
+
yield (sample_paths, np.concatenate(embeddings_list, axis=0))
|
|
56
|
+
|
|
57
|
+
# Reset for next chunk
|
|
58
|
+
embeddings_list = []
|
|
59
|
+
sample_paths = []
|
|
60
|
+
sample_count = 0
|
|
61
|
+
|
|
62
|
+
if len(embeddings_list) > 0:
|
|
63
|
+
yield (sample_paths, np.concatenate(embeddings_list, axis=0))
|
|
@@ -13,6 +13,7 @@ from collections.abc import Callable
|
|
|
13
13
|
from collections.abc import Iterator
|
|
14
14
|
from collections.abc import Sequence
|
|
15
15
|
from typing import Optional
|
|
16
|
+
from typing import Protocol
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
import numpy.typing as npt
|
|
@@ -27,6 +28,42 @@ from birder_clip.net.base import BaseNet
|
|
|
27
28
|
from birder_clip.tokenizers.base import Tokenizer
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
class ZeroShotInferenceModule(Protocol):
|
|
32
|
+
def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor: ...
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ZeroShotInference:
|
|
36
|
+
def __init__(self, net: BaseNet, text_embeddings: torch.Tensor) -> None:
|
|
37
|
+
self.net = net
|
|
38
|
+
self.text_embeddings = text_embeddings
|
|
39
|
+
|
|
40
|
+
def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor:
|
|
41
|
+
inputs = inputs.to(self.text_embeddings.device, non_blocking=True)
|
|
42
|
+
if tta is True:
|
|
43
|
+
_, _, H, W = inputs.size()
|
|
44
|
+
crop_h = int(H * 0.8)
|
|
45
|
+
crop_w = int(W * 0.8)
|
|
46
|
+
tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
|
|
47
|
+
t = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
|
|
48
|
+
outs = []
|
|
49
|
+
for tta_input in tta_inputs:
|
|
50
|
+
image_embeddings = self.net.encode_image(t(tta_input), normalize=True)
|
|
51
|
+
logits = self.net.forward_logits(image_embeddings, self.text_embeddings)
|
|
52
|
+
if return_logits is True:
|
|
53
|
+
outs.append(logits)
|
|
54
|
+
else:
|
|
55
|
+
outs.append(F.softmax(logits, dim=-1))
|
|
56
|
+
|
|
57
|
+
return torch.stack(outs).mean(dim=0)
|
|
58
|
+
|
|
59
|
+
image_embeddings = self.net.encode_image(inputs, normalize=True)
|
|
60
|
+
logits = self.net.forward_logits(image_embeddings, self.text_embeddings)
|
|
61
|
+
if return_logits is True:
|
|
62
|
+
return logits
|
|
63
|
+
|
|
64
|
+
return F.softmax(logits, dim=-1)
|
|
65
|
+
|
|
66
|
+
|
|
30
67
|
def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
|
|
31
68
|
return [template.format(class_name) for class_name in class_names for template in templates]
|
|
32
69
|
|
|
@@ -66,39 +103,10 @@ def build_class_text_embeddings(
|
|
|
66
103
|
DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]]
|
|
67
104
|
|
|
68
105
|
|
|
69
|
-
def infer_batch(
|
|
70
|
-
net: BaseNet, inputs: torch.Tensor, text_embeddings: torch.Tensor, tta: bool = False, return_logits: bool = False
|
|
71
|
-
) -> torch.Tensor:
|
|
72
|
-
if tta is True:
|
|
73
|
-
_, _, H, W = inputs.size()
|
|
74
|
-
crop_h = int(H * 0.8)
|
|
75
|
-
crop_w = int(W * 0.8)
|
|
76
|
-
tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
|
|
77
|
-
t = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
|
|
78
|
-
outs = []
|
|
79
|
-
for tta_input in tta_inputs:
|
|
80
|
-
image_embeddings = net.encode_image(t(tta_input), normalize=True)
|
|
81
|
-
logits = net.forward_logits(image_embeddings, text_embeddings)
|
|
82
|
-
if return_logits is True:
|
|
83
|
-
outs.append(logits)
|
|
84
|
-
else:
|
|
85
|
-
outs.append(F.softmax(logits, dim=-1))
|
|
86
|
-
|
|
87
|
-
return torch.stack(outs).mean(dim=0)
|
|
88
|
-
|
|
89
|
-
image_embeddings = net.encode_image(inputs, normalize=True)
|
|
90
|
-
logits = net.forward_logits(image_embeddings, text_embeddings)
|
|
91
|
-
if return_logits is True:
|
|
92
|
-
return logits
|
|
93
|
-
|
|
94
|
-
return F.softmax(logits, dim=-1)
|
|
95
|
-
|
|
96
|
-
|
|
97
106
|
def infer_dataloader_iter(
|
|
98
107
|
device: torch.device,
|
|
99
|
-
net:
|
|
108
|
+
net: ZeroShotInferenceModule,
|
|
100
109
|
dataloader: DataLoader,
|
|
101
|
-
text_embeddings: torch.Tensor,
|
|
102
110
|
tta: bool = False,
|
|
103
111
|
return_logits: bool = False,
|
|
104
112
|
model_dtype: torch.dtype = torch.float32,
|
|
@@ -111,7 +119,6 @@ def infer_dataloader_iter(
|
|
|
111
119
|
if chunk_size is None:
|
|
112
120
|
chunk_size = float("inf")
|
|
113
121
|
|
|
114
|
-
net.to(device, dtype=model_dtype)
|
|
115
122
|
out_list: list[npt.NDArray[np.float32]] = []
|
|
116
123
|
labels_list: list[npt.NDArray[np.int64]] = []
|
|
117
124
|
sample_paths: list[str] = []
|
|
@@ -121,9 +128,9 @@ def infer_dataloader_iter(
|
|
|
121
128
|
batch_size = inputs.size(0)
|
|
122
129
|
|
|
123
130
|
# Inference
|
|
124
|
-
inputs = inputs.to(
|
|
131
|
+
inputs = inputs.to(dtype=model_dtype)
|
|
125
132
|
with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
|
|
126
|
-
out =
|
|
133
|
+
out = net(inputs, return_logits=return_logits, tta=tta)
|
|
127
134
|
out = out.cpu().float().numpy()
|
|
128
135
|
|
|
129
136
|
out_list.append(out)
|
|
@@ -557,6 +557,29 @@ registry.register_model_config(
|
|
|
557
557
|
},
|
|
558
558
|
)
|
|
559
559
|
|
|
560
|
+
# EVA02 CLIP - https://arxiv.org/abs/2303.15389
|
|
561
|
+
registry.register_model_config(
|
|
562
|
+
"eva02_clip_l14",
|
|
563
|
+
CLIP,
|
|
564
|
+
config={
|
|
565
|
+
"image": {
|
|
566
|
+
"network": "rope_i_vit_l14_nf_swiglu_c1",
|
|
567
|
+
"config": {"drop_path_rate": 0.0},
|
|
568
|
+
"size": (336, 336),
|
|
569
|
+
},
|
|
570
|
+
"text": {
|
|
571
|
+
"network": "transformer_encoder",
|
|
572
|
+
"config": {
|
|
573
|
+
"hidden_dim": 768,
|
|
574
|
+
"num_heads": 12,
|
|
575
|
+
"output_dim": 768,
|
|
576
|
+
},
|
|
577
|
+
},
|
|
578
|
+
"embed_dim": 768,
|
|
579
|
+
"tokenizer": "openai_clip_bpe",
|
|
580
|
+
},
|
|
581
|
+
)
|
|
582
|
+
|
|
560
583
|
# MetaCLIP - https://arxiv.org/abs/2309.16671
|
|
561
584
|
registry.register_model_config(
|
|
562
585
|
"metaclip_v1_fullcc25b_b16",
|
|
@@ -329,6 +329,38 @@ registry.register_model_config(
|
|
|
329
329
|
},
|
|
330
330
|
)
|
|
331
331
|
|
|
332
|
+
registry.register_model_config(
|
|
333
|
+
"laion_coca_vit_b32",
|
|
334
|
+
CoCa,
|
|
335
|
+
config={
|
|
336
|
+
"image": {
|
|
337
|
+
"network": "vit_b32_pn",
|
|
338
|
+
"size": (224, 224),
|
|
339
|
+
},
|
|
340
|
+
"text": {
|
|
341
|
+
"network": "transformer_encoder",
|
|
342
|
+
"config": {
|
|
343
|
+
"hidden_dim": 512,
|
|
344
|
+
"num_heads": 8,
|
|
345
|
+
"output_dim": 512,
|
|
346
|
+
"class_token": True,
|
|
347
|
+
"norm_after_pool": True,
|
|
348
|
+
"pool_type": "last",
|
|
349
|
+
},
|
|
350
|
+
"context_length": 76,
|
|
351
|
+
},
|
|
352
|
+
"decoder": {
|
|
353
|
+
"network": "conditioned_decoder",
|
|
354
|
+
"config": {
|
|
355
|
+
"num_heads": 8,
|
|
356
|
+
},
|
|
357
|
+
"context_length": 76,
|
|
358
|
+
},
|
|
359
|
+
"visual_pool": {},
|
|
360
|
+
"embed_dim": 512,
|
|
361
|
+
"tokenizer": "openai_clip_bpe",
|
|
362
|
+
},
|
|
363
|
+
)
|
|
332
364
|
registry.register_model_config(
|
|
333
365
|
"laion_coca_vit_l14",
|
|
334
366
|
CoCa,
|
|
@@ -362,6 +394,40 @@ registry.register_model_config(
|
|
|
362
394
|
},
|
|
363
395
|
)
|
|
364
396
|
|
|
397
|
+
# LAION CoCa - laion/CoCa-ViT-B-32-laion2B-s13B-b90k
|
|
398
|
+
registry.register_model_config(
|
|
399
|
+
"laion_coca_vit_b32_openclip_legacy",
|
|
400
|
+
OpenCLIPLegacyCoCa,
|
|
401
|
+
config={
|
|
402
|
+
"image": {
|
|
403
|
+
"network": "vit_b32_pn",
|
|
404
|
+
"size": (224, 224),
|
|
405
|
+
},
|
|
406
|
+
"text": {
|
|
407
|
+
"network": "transformer_encoder",
|
|
408
|
+
"config": {
|
|
409
|
+
"hidden_dim": 512,
|
|
410
|
+
"num_heads": 8,
|
|
411
|
+
"output_dim": 512,
|
|
412
|
+
"class_token": True,
|
|
413
|
+
"norm_after_pool": True,
|
|
414
|
+
"pool_type": "last",
|
|
415
|
+
},
|
|
416
|
+
"context_length": 76,
|
|
417
|
+
},
|
|
418
|
+
"decoder": {
|
|
419
|
+
"network": "conditioned_decoder",
|
|
420
|
+
"config": {
|
|
421
|
+
"num_heads": 8,
|
|
422
|
+
},
|
|
423
|
+
"context_length": 76,
|
|
424
|
+
},
|
|
425
|
+
"visual_pool": {},
|
|
426
|
+
"embed_dim": 512,
|
|
427
|
+
"tokenizer": "openai_clip_bpe",
|
|
428
|
+
},
|
|
429
|
+
)
|
|
430
|
+
|
|
365
431
|
# LAION CoCa - laion/CoCa-ViT-L-14-laion2B-s13B-b90k
|
|
366
432
|
registry.register_model_config(
|
|
367
433
|
"laion_coca_vit_l14_openclip_legacy",
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import pkgutil
|
|
2
|
+
|
|
3
|
+
import birder_clip.scripts
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def list_scripts() -> list[str]:
|
|
7
|
+
scripts = []
|
|
8
|
+
for _, name, is_pkg in pkgutil.iter_modules(birder_clip.scripts.__path__):
|
|
9
|
+
if name.startswith("_") is True: # Skip private modules
|
|
10
|
+
continue
|
|
11
|
+
|
|
12
|
+
if is_pkg is False:
|
|
13
|
+
scripts.append(name)
|
|
14
|
+
|
|
15
|
+
return scripts
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def main() -> None:
|
|
19
|
+
print("Available birder scripts:")
|
|
20
|
+
for script in list_scripts():
|
|
21
|
+
print(f" birder_clip.scripts.{script}")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
if __name__ == "__main__":
|
|
25
|
+
main()
|