birder-clip 0.0.2.dev6__tar.gz → 0.0.2.dev7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/PKG-INFO +2 -2
  2. birder_clip-0.0.2.dev7/birder_clip/inference/data_parallel.py +118 -0
  3. birder_clip-0.0.2.dev7/birder_clip/inference/image_embeddings.py +63 -0
  4. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/inference/zero_shot.py +40 -33
  5. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/clip.py +23 -0
  6. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/coca.py +66 -0
  7. birder_clip-0.0.2.dev7/birder_clip/scripts/__main__.py +25 -0
  8. birder_clip-0.0.2.dev7/birder_clip/scripts/embed_images.py +432 -0
  9. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/scripts/zero_shot.py +28 -7
  10. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/convert_model.py +165 -5
  11. birder_clip-0.0.2.dev7/birder_clip/version.py +1 -0
  12. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/PKG-INFO +2 -2
  13. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/SOURCES.txt +5 -0
  14. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/requires.txt +1 -1
  15. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/requirements/_requirements-dev.txt +1 -1
  16. birder_clip-0.0.2.dev7/tests/test_inference.py +143 -0
  17. birder_clip-0.0.2.dev6/birder_clip/version.py +0 -1
  18. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/LICENSE +0 -0
  19. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/README.md +0 -0
  20. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/__init__.py +0 -0
  21. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/__init__.py +0 -0
  22. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/fs_ops.py +0 -0
  23. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/lib.py +0 -0
  24. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/training_cli.py +0 -0
  25. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/common/training_utils.py +0 -0
  26. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/conf/__init__.py +0 -0
  27. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/conf/settings.py +0 -0
  28. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/__init__.py +0 -0
  29. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/__init__.py +0 -0
  30. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/csv.py +0 -0
  31. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/fake.py +0 -0
  32. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/webdataset.py +0 -0
  33. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/inference/__init__.py +0 -0
  34. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/inference/zero_shot_templates.py +0 -0
  35. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/loss/__init__.py +0 -0
  36. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/loss/coca.py +0 -0
  37. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/loss/contrastive.py +0 -0
  38. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/__init__.py +0 -0
  39. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/manifest.py +0 -0
  40. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/model_registry.py +0 -0
  41. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/__init__.py +0 -0
  42. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/base.py +0 -0
  43. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/__init__.py +0 -0
  44. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/base.py +0 -0
  45. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/conditioned_decoder.py +0 -0
  46. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/encoder.py +0 -0
  47. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/hf.py +0 -0
  48. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/net/text/prefix_decoder.py +0 -0
  49. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/py.typed +0 -0
  50. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/scripts/__init__.py +0 -0
  51. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/scripts/train.py +0 -0
  52. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/__init__.py +0 -0
  53. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/base.py +0 -0
  54. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  55. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/hf.py +0 -0
  56. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/openvision.py +0 -0
  57. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/registry.py +0 -0
  58. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
  59. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/__init__.py +0 -0
  60. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/__main__.py +0 -0
  61. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/download_tokenizer.py +0 -0
  62. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/list_models.py +0 -0
  63. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/model_info.py +0 -0
  64. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/show_iterator.py +0 -0
  65. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip/tools/stats.py +0 -0
  66. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/dependency_links.txt +0 -0
  67. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/top_level.txt +0 -0
  68. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/pyproject.toml +0 -0
  69. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/requirements/requirements.txt +0 -0
  70. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/setup.cfg +0 -0
  71. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_common.py +0 -0
  72. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_datasets.py +0 -0
  73. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_loss.py +0 -0
  74. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_model_registry.py +0 -0
  75. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_net.py +0 -0
  76. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_net_text.py +0 -0
  77. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev7}/tests/test_tokenizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder_clip
3
- Version: 0.0.2.dev6
3
+ Version: 0.0.2.dev7
4
4
  Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -38,7 +38,7 @@ Requires-Dist: bandit~=1.9.4; extra == "dev"
38
38
  Requires-Dist: black~=26.5.0; extra == "dev"
39
39
  Requires-Dist: build~=1.5.0; extra == "dev"
40
40
  Requires-Dist: bumpver~=2026.1132; extra == "dev"
41
- Requires-Dist: coverage~=7.14.1; extra == "dev"
41
+ Requires-Dist: coverage~=7.14.2; extra == "dev"
42
42
  Requires-Dist: debugpy; extra == "dev"
43
43
  Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
44
44
  Requires-Dist: flake8~=7.3.0; extra == "dev"
@@ -0,0 +1,118 @@
1
+ """
2
+ Inference-optimized multi-GPU parallelization for image-text models
3
+
4
+ This module provides ZeroShotInferenceDataParallel, a CLIP-style zero-shot
5
+ specialization of Birder's InferenceDataParallel.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from birder.inference.data_parallel import InferenceDataParallel
12
+
13
+ from birder_clip.inference.zero_shot import ZeroShotInference
14
+ from birder_clip.net.base import BaseNet
15
+
16
+
17
class ZeroShotInferenceDataParallel(InferenceDataParallel):
    """
    Distributes zero-shot image inference batches across multiple GPUs

    The image batch is split over the configured devices, while every device keeps its own
    copy of the zero-shot text embeddings. Each replica encodes its slice of the images and
    scores it against its local text embeddings; the per-device outputs are then gathered.

    Important
    ---------
    This class assumes the model is already configured for inference mode
    (i.e., loaded with inference=True in load_model or manually set to eval mode
    with requires_grad=False on all parameters).
    """

    def __init__(
        self,
        module: BaseNet,
        text_embeddings: torch.Tensor,
        device_ids: Optional[list[int]] = None,
        output_device: Optional[int | str | torch.device] = None,
        compile_replicas: bool = False,
        compile_methods: Optional[list[str]] = None,
        compile_mode: Optional[str] = None,
    ) -> None:
        # By default only the two methods used on the zero-shot path are compiled
        methods_to_compile = ["encode_image", "forward_logits"] if compile_methods is None else compile_methods

        super().__init__(
            module,
            device_ids=device_ids,
            output_device=output_device,
            compile_replicas=compile_replicas,
            compile_methods=methods_to_compile,
            compile_mode=compile_mode,
        )
        self.set_text_embeddings(text_embeddings)

    def set_text_embeddings(self, text_embeddings: torch.Tensor) -> None:
        """
        Replace the class text embeddings and rebuild the per-device zero-shot wrappers

        A copy of text_embeddings is placed on every device in self.device_ids, and one
        ZeroShotInference is created per (replica, device copy) pair.

        Raises
        ------
        ValueError
            If text_embeddings is not a 2D tensor.
        """

        if text_embeddings.ndim != 2:
            raise ValueError(
                f"text_embeddings must be a 2D tensor of shape (num_classes, embedding_size), "
                f"got shape {text_embeddings.size()}"
            )

        device_copies: list[torch.Tensor] = []
        for gpu_index in self.device_ids:
            device_copies.append(text_embeddings.to(f"cuda:{gpu_index}", non_blocking=True))

        self.text_embeddings = device_copies
        # map() over two iterables stops at the shorter one, pairing replicas with device copies
        self.inference_modules = list(map(ZeroShotInference, self.replicas, device_copies))

    def forward(  # type: ignore[override]  # pylint: disable=arguments-differ
        self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False
    ) -> torch.Tensor:
        """
        Run zero-shot inference distributed across GPUs

        Parameters
        ----------
        inputs
            Input image batch to process.
        tta
            Run inference with oversampling.
        return_logits
            If True, return raw logits instead of probabilities after softmax.
        """

        if len(self.device_ids) == 1:
            # Single device: skip scattering; ZeroShotInference moves the inputs to its own device
            lone_output = self.inference_modules[0](inputs, tta=tta, return_logits=return_logits)
            return self._gather([lone_output])

        per_device_outputs: list[Optional[torch.Tensor]] = []
        chunks = self._scatter(inputs, {})
        for runner, (chunk, _chunk_kwargs), gpu_index in zip(self.inference_modules, chunks, self.device_ids):
            # NOTE(review): presumably _scatter can hand some devices a None/empty chunk (e.g. a batch
            # smaller than the device count); those devices contribute a None placeholder
            if chunk is None or chunk.size(0) == 0:
                per_device_outputs.append(None)
                continue

            with torch.cuda.device(gpu_index):
                chunk_output = runner(chunk, tta=tta, return_logits=return_logits)

            per_device_outputs.append(chunk_output)

        return self._gather(per_device_outputs)

    def __repr__(self) -> str:
        repr_lines = (
            "ZeroShotInferenceDataParallel(",
            f" devices={self.device_ids},",
            f" output_device={self.output_device},",
            f" src_device={self.src_device},",
            f" text_embeddings_shape={tuple(self.text_embeddings[0].shape)}",
            ")",
        )
        return "\n".join(repr_lines)
@@ -0,0 +1,63 @@
1
+ import sys
2
+ from collections.abc import Iterator
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from birder_clip.net.base import BaseNet
12
+
13
+ DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32]]
14
+
15
+
16
def infer_dataloader_iter(
    device: torch.device,
    net: BaseNet,
    dataloader: DataLoader,
    model_dtype: torch.dtype = torch.float32,
    amp: bool = False,
    amp_dtype: Optional[torch.dtype] = None,
    num_samples: Optional[int] = None,
    chunk_size: Optional[float] = None,
) -> Iterator[DataloaderInferenceResult]:
    """
    Compute normalized image embeddings for every batch of a dataloader, yielding them in chunks

    Parameters
    ----------
    device
        Device the model and the input batches are moved to.
    net
        Image-text model. Only net.encode_image is used, called with normalize=True.
        net.to(device, dtype=model_dtype) is called before iterating.
    dataloader
        Iterable of (file_paths, inputs, targets) batches. The targets are ignored.
    model_dtype
        Dtype used for the model and the input batches.
    amp
        Run the forward pass under torch.amp.autocast.
    amp_dtype
        Autocast dtype. Only relevant when amp is True.
    num_samples
        Expected total number of images. Used only as the progress bar total.
    chunk_size
        Minimum number of samples to accumulate before yielding. Batches are never split, so a
        chunk may exceed this value. None means everything is yielded as a single chunk.

    Yields
    ------
    Tuples of (file_paths, embeddings). embeddings is a float32 array built by concatenating the
    per-batch encode_image outputs along axis 0. A trailing partial chunk is yielded only when it
    is non-empty.
    """

    if chunk_size is None:
        chunk_size = float("inf")

    net.to(device, dtype=model_dtype)
    embeddings_list: list[npt.NDArray[np.float32]] = []
    sample_paths: list[str] = []
    sample_count = 0
    with tqdm(total=num_samples, initial=0, unit="images", unit_scale=True, leave=False) as progress:
        for file_paths, inputs, _targets in dataloader:
            batch_size = inputs.size(0)

            # Inference.
            # inference_mode stops autograd from recording the forward pass, so .numpy() also works
            # when the model's parameters still require grad. The context is deliberately scoped to
            # this batch: grad mode is thread-local state, and holding it open across the yield below
            # would leak it into the consumer's code.
            inputs = inputs.to(device, dtype=model_dtype)
            with torch.inference_mode(), torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
                embeddings = net.encode_image(inputs, normalize=True)
                embeddings = embeddings.cpu().float().numpy()

            embeddings_list.append(embeddings)

            # Set sample list
            sample_paths.extend(file_paths)

            # Update progress bar
            progress.update(n=batch_size)

            # Yield results when we reach chunk_size
            sample_count += batch_size
            if sample_count >= chunk_size:
                # Suspend the progress bar while the consumer handles the chunk, so anything it
                # writes to stderr is not interleaved with the bar
                with tqdm.external_write_mode(file=sys.stderr):
                    yield (sample_paths, np.concatenate(embeddings_list, axis=0))

                # Reset for next chunk
                embeddings_list = []
                sample_paths = []
                sample_count = 0

    if len(embeddings_list) > 0:
        yield (sample_paths, np.concatenate(embeddings_list, axis=0))
@@ -13,6 +13,7 @@ from collections.abc import Callable
13
13
  from collections.abc import Iterator
14
14
  from collections.abc import Sequence
15
15
  from typing import Optional
16
+ from typing import Protocol
16
17
 
17
18
  import numpy as np
18
19
  import numpy.typing as npt
@@ -27,6 +28,42 @@ from birder_clip.net.base import BaseNet
27
28
  from birder_clip.tokenizers.base import Tokenizer
28
29
 
29
30
 
31
class ZeroShotInferenceModule(Protocol):
    """
    Structural type for callables that turn an image batch into zero-shot class scores

    ZeroShotInference matches this signature: a tensor input plus keyword-only tta and
    return_logits flags, returning a tensor.
    """

    def __call__(
        self,
        inputs: torch.Tensor,
        *,
        tta: bool = False,
        return_logits: bool = False,
    ) -> torch.Tensor: ...
33
+
34
+
35
class ZeroShotInference:
    """
    Zero-shot image scorer built from a network and precomputed class text embeddings

    Calling an instance moves the image batch to the device of text_embeddings, encodes it
    with net.encode_image(..., normalize=True), scores it against text_embeddings with
    net.forward_logits, and returns either the raw logits or their softmax over the last
    dimension.
    """

    def __init__(self, net: BaseNet, text_embeddings: torch.Tensor) -> None:
        self.net = net
        # Presumably (num_classes, embedding_size); not validated here
        self.text_embeddings = text_embeddings

    def _score(self, images: torch.Tensor, return_logits: bool) -> torch.Tensor:
        # Raw logits when return_logits is exactly True, otherwise softmax probabilities
        image_embeddings = self.net.encode_image(images, normalize=True)
        logits = self.net.forward_logits(image_embeddings, self.text_embeddings)
        if return_logits is True:
            return logits

        return F.softmax(logits, dim=-1)

    def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor:
        """
        Score an image batch against the class text embeddings

        Parameters
        ----------
        inputs
            Image batch. The TTA path unpacks it as four dimensions, the last two being
            height and width.
        tta
            When True, score five crops of 80% height/width (via five_crop), each resized
            back to the original height/width with antialiased bicubic interpolation, and
            average the per-crop results.
        return_logits
            When True, return raw logits (averaged over crops under TTA); otherwise softmax
            probabilities (likewise averaged).
        """

        inputs = inputs.to(self.text_embeddings.device, non_blocking=True)
        if tta is not True:
            return self._score(inputs, return_logits)

        _, _, H, W = inputs.size()
        crops = five_crop(inputs, size=[int(H * 0.8), int(W * 0.8)])
        resize_back = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
        per_crop_scores = [self._score(resize_back(crop), return_logits) for crop in crops]
        return torch.stack(per_crop_scores).mean(dim=0)
65
+
66
+
30
67
  def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
31
68
  return [template.format(class_name) for class_name in class_names for template in templates]
32
69
 
@@ -66,39 +103,10 @@ def build_class_text_embeddings(
66
103
  DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]]
67
104
 
68
105
 
69
- def infer_batch(
70
- net: BaseNet, inputs: torch.Tensor, text_embeddings: torch.Tensor, tta: bool = False, return_logits: bool = False
71
- ) -> torch.Tensor:
72
- if tta is True:
73
- _, _, H, W = inputs.size()
74
- crop_h = int(H * 0.8)
75
- crop_w = int(W * 0.8)
76
- tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
77
- t = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
78
- outs = []
79
- for tta_input in tta_inputs:
80
- image_embeddings = net.encode_image(t(tta_input), normalize=True)
81
- logits = net.forward_logits(image_embeddings, text_embeddings)
82
- if return_logits is True:
83
- outs.append(logits)
84
- else:
85
- outs.append(F.softmax(logits, dim=-1))
86
-
87
- return torch.stack(outs).mean(dim=0)
88
-
89
- image_embeddings = net.encode_image(inputs, normalize=True)
90
- logits = net.forward_logits(image_embeddings, text_embeddings)
91
- if return_logits is True:
92
- return logits
93
-
94
- return F.softmax(logits, dim=-1)
95
-
96
-
97
106
  def infer_dataloader_iter(
98
107
  device: torch.device,
99
- net: BaseNet,
108
+ net: ZeroShotInferenceModule,
100
109
  dataloader: DataLoader,
101
- text_embeddings: torch.Tensor,
102
110
  tta: bool = False,
103
111
  return_logits: bool = False,
104
112
  model_dtype: torch.dtype = torch.float32,
@@ -111,7 +119,6 @@ def infer_dataloader_iter(
111
119
  if chunk_size is None:
112
120
  chunk_size = float("inf")
113
121
 
114
- net.to(device, dtype=model_dtype)
115
122
  out_list: list[npt.NDArray[np.float32]] = []
116
123
  labels_list: list[npt.NDArray[np.int64]] = []
117
124
  sample_paths: list[str] = []
@@ -121,9 +128,9 @@ def infer_dataloader_iter(
121
128
  batch_size = inputs.size(0)
122
129
 
123
130
  # Inference
124
- inputs = inputs.to(device, dtype=model_dtype)
131
+ inputs = inputs.to(dtype=model_dtype)
125
132
  with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
126
- out = infer_batch(net, inputs, text_embeddings, return_logits=return_logits, tta=tta)
133
+ out = net(inputs, return_logits=return_logits, tta=tta)
127
134
  out = out.cpu().float().numpy()
128
135
 
129
136
  out_list.append(out)
@@ -557,6 +557,29 @@ registry.register_model_config(
557
557
  },
558
558
  )
559
559
 
560
# EVA02 CLIP - https://arxiv.org/abs/2303.15389
# NOTE(review): the registered name carries no resolution suffix while the input size is 336x336;
# confirm this is meant to match the 336px EVA02-CLIP-L/14 weights
registry.register_model_config(
    "eva02_clip_l14",
    CLIP,
    config={
        "image": {
            "network": "rope_i_vit_l14_nf_swiglu_c1",
            "config": {"drop_path_rate": 0.0},
            "size": (336, 336),
        },
        "text": {
            "network": "transformer_encoder",
            "config": {"hidden_dim": 768, "num_heads": 12, "output_dim": 768},
        },
        "embed_dim": 768,
        "tokenizer": "openai_clip_bpe",
    },
)
582
+
560
583
  # MetaCLIP - https://arxiv.org/abs/2309.16671
561
584
  registry.register_model_config(
562
585
  "metaclip_v1_fullcc25b_b16",
@@ -329,6 +329,38 @@ registry.register_model_config(
329
329
  },
330
330
  )
331
331
 
332
# LAION CoCa ViT-B/32 on the native CoCa class; the same hyperparameters are registered for the
# OpenCLIPLegacyCoCa class as "laion_coca_vit_b32_openclip_legacy"
registry.register_model_config(
    "laion_coca_vit_b32",
    CoCa,
    config={
        "image": {"network": "vit_b32_pn", "size": (224, 224)},
        "text": {
            "network": "transformer_encoder",
            "config": {
                "hidden_dim": 512,
                "num_heads": 8,
                "output_dim": 512,
                "class_token": True,
                "norm_after_pool": True,
                "pool_type": "last",
            },
            "context_length": 76,
        },
        "decoder": {"network": "conditioned_decoder", "config": {"num_heads": 8}, "context_length": 76},
        "visual_pool": {},
        "embed_dim": 512,
        "tokenizer": "openai_clip_bpe",
    },
)

332
364
  registry.register_model_config(
333
365
  "laion_coca_vit_l14",
334
366
  CoCa,
@@ -362,6 +394,40 @@ registry.register_model_config(
362
394
  },
363
395
  )
364
396
 
397
# LAION CoCa - laion/CoCa-ViT-B-32-laion2B-s13B-b90k
registry.register_model_config(
    "laion_coca_vit_b32_openclip_legacy",
    OpenCLIPLegacyCoCa,
    config={
        "image": {"network": "vit_b32_pn", "size": (224, 224)},
        "text": {
            "network": "transformer_encoder",
            "config": {
                "hidden_dim": 512,
                "num_heads": 8,
                "output_dim": 512,
                "class_token": True,
                "norm_after_pool": True,
                "pool_type": "last",
            },
            "context_length": 76,
        },
        "decoder": {"network": "conditioned_decoder", "config": {"num_heads": 8}, "context_length": 76},
        "visual_pool": {},
        "embed_dim": 512,
        "tokenizer": "openai_clip_bpe",
    },
)
430
+
365
431
  # LAION CoCa - laion/CoCa-ViT-L-14-laion2B-s13B-b90k
366
432
  registry.register_model_config(
367
433
  "laion_coca_vit_l14_openclip_legacy",
@@ -0,0 +1,25 @@
1
+ import pkgutil
2
+
3
+ import birder_clip.scripts
4
+
5
+
6
def list_scripts() -> list[str]:
    """
    Return the names of the public, non-package modules found in birder_clip.scripts

    Modules whose name starts with an underscore (e.g. __main__) are treated as private and
    left out.
    """

    return [
        module_name
        for _finder, module_name, ispkg in pkgutil.iter_modules(birder_clip.scripts.__path__)
        if not module_name.startswith("_") and ispkg is False
    ]
16
+
17
+
18
def main() -> None:
    """Print a header followed by one line per script module returned by list_scripts()"""

    header = "Available birder scripts:"
    print(header)
    for module_name in list_scripts():
        print(f" birder_clip.scripts.{module_name}")


if __name__ == "__main__":
    main()