birder-clip 0.0.2.dev5__tar.gz → 0.0.2.dev6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. birder_clip-0.0.2.dev6/PKG-INFO +151 -0
  2. birder_clip-0.0.2.dev6/README.md +89 -0
  3. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/common/lib.py +1 -1
  4. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/common/training_cli.py +77 -3
  5. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/inference/zero_shot.py +34 -7
  6. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/loss/__init__.py +2 -0
  7. birder_clip-0.0.2.dev6/birder_clip/loss/coca.py +66 -0
  8. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/model_registry/model_registry.py +40 -19
  9. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/net/__init__.py +2 -0
  10. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/net/base.py +14 -2
  11. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/net/clip.py +137 -28
  12. birder_clip-0.0.2.dev6/birder_clip/net/coca.py +397 -0
  13. birder_clip-0.0.2.dev6/birder_clip/net/text/__init__.py +9 -0
  14. birder_clip-0.0.2.dev6/birder_clip/net/text/base.py +142 -0
  15. birder_clip-0.0.2.dev6/birder_clip/net/text/conditioned_decoder.py +441 -0
  16. birder_clip-0.0.2.dev5/birder_clip/net/text/transformer.py → birder_clip-0.0.2.dev6/birder_clip/net/text/encoder.py +133 -19
  17. birder_clip-0.0.2.dev6/birder_clip/net/text/hf.py +154 -0
  18. birder_clip-0.0.2.dev6/birder_clip/net/text/prefix_decoder.py +1 -0
  19. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/scripts/train.py +244 -50
  20. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/scripts/zero_shot.py +19 -2
  21. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tokenizers/base.py +1 -0
  22. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tokenizers/hf.py +2 -0
  23. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tokenizers/openvision.py +3 -1
  24. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tokenizers/simple_tokenizer.py +1 -0
  25. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/convert_model.py +1 -1
  26. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/list_models.py +7 -4
  27. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/model_info.py +1 -1
  28. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/show_iterator.py +14 -13
  29. birder_clip-0.0.2.dev6/birder_clip/version.py +1 -0
  30. birder_clip-0.0.2.dev6/birder_clip.egg-info/PKG-INFO +151 -0
  31. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip.egg-info/SOURCES.txt +7 -1
  32. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip.egg-info/requires.txt +2 -2
  33. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/requirements/_requirements-dev.txt +1 -1
  34. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/requirements/requirements.txt +1 -1
  35. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/tests/test_common.py +3 -3
  36. birder_clip-0.0.2.dev6/tests/test_loss.py +129 -0
  37. birder_clip-0.0.2.dev6/tests/test_model_registry.py +102 -0
  38. birder_clip-0.0.2.dev6/tests/test_net.py +572 -0
  39. birder_clip-0.0.2.dev6/tests/test_net_text.py +96 -0
  40. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/tests/test_tokenizers.py +29 -0
  41. birder_clip-0.0.2.dev5/PKG-INFO +0 -72
  42. birder_clip-0.0.2.dev5/README.md +0 -10
  43. birder_clip-0.0.2.dev5/birder_clip/net/text/__init__.py +0 -5
  44. birder_clip-0.0.2.dev5/birder_clip/net/text/base.py +0 -46
  45. birder_clip-0.0.2.dev5/birder_clip/version.py +0 -1
  46. birder_clip-0.0.2.dev5/birder_clip.egg-info/PKG-INFO +0 -72
  47. birder_clip-0.0.2.dev5/tests/test_loss.py +0 -47
  48. birder_clip-0.0.2.dev5/tests/test_model_registry.py +0 -74
  49. birder_clip-0.0.2.dev5/tests/test_net.py +0 -147
  50. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/LICENSE +0 -0
  51. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/__init__.py +0 -0
  52. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/common/__init__.py +0 -0
  53. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/common/fs_ops.py +0 -0
  54. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/common/training_utils.py +0 -0
  55. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/conf/__init__.py +0 -0
  56. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/conf/settings.py +0 -0
  57. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/data/__init__.py +0 -0
  58. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/data/datasets/__init__.py +0 -0
  59. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/data/datasets/csv.py +0 -0
  60. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/data/datasets/fake.py +0 -0
  61. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/data/datasets/webdataset.py +0 -0
  62. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/inference/__init__.py +0 -0
  63. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/inference/zero_shot_templates.py +0 -0
  64. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/loss/contrastive.py +0 -0
  65. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/model_registry/__init__.py +0 -0
  66. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/model_registry/manifest.py +0 -0
  67. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/py.typed +0 -0
  68. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/scripts/__init__.py +0 -0
  69. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tokenizers/__init__.py +0 -0
  70. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  71. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tokenizers/registry.py +0 -0
  72. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/__init__.py +0 -0
  73. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/__main__.py +0 -0
  74. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/download_tokenizer.py +0 -0
  75. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip/tools/stats.py +0 -0
  76. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip.egg-info/dependency_links.txt +0 -0
  77. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/birder_clip.egg-info/top_level.txt +0 -0
  78. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/pyproject.toml +0 -0
  79. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/setup.cfg +0 -0
  80. {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev6}/tests/test_datasets.py +0 -0
@@ -0,0 +1,151 @@
1
+ Metadata-Version: 2.4
2
+ Name: birder_clip
3
+ Version: 0.0.2.dev6
4
+ Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
+ Author: Ofer Hasson
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://gitlab.com/birder/birder-clip
8
+ Project-URL: Issues, https://gitlab.com/birder/birder-clip/-/issues
9
+ Project-URL: Changelog, https://gitlab.com/birder/birder-clip/-/blob/main/CHANGELOG.md
10
+ Keywords: computer-vision,clip,image-text,pytorch,deep-learning,artificial intelligence
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Education
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Topic :: Scientific/Engineering
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: Topic :: Scientific/Engineering :: Image Recognition
20
+ Classifier: Topic :: Software Development
21
+ Classifier: Topic :: Software Development :: Libraries
22
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
23
+ Classifier: Typing :: Typed
24
+ Requires-Python: >=3.11
25
+ Description-Content-Type: text/markdown
26
+ License-File: LICENSE
27
+ Requires-Dist: birder>=0.6.0
28
+ Requires-Dist: ftfy>=6.3.1
29
+ Requires-Dist: regex>=2025.7.29
30
+ Requires-Dist: tqdm>=4.67.0
31
+ Requires-Dist: webdataset>=0.2.111
32
+ Requires-Dist: huggingface_hub
33
+ Requires-Dist: transformers
34
+ Requires-Dist: torch>=2.10.0
35
+ Requires-Dist: torchvision
36
+ Provides-Extra: dev
37
+ Requires-Dist: bandit~=1.9.4; extra == "dev"
38
+ Requires-Dist: black~=26.5.0; extra == "dev"
39
+ Requires-Dist: build~=1.5.0; extra == "dev"
40
+ Requires-Dist: bumpver~=2026.1132; extra == "dev"
41
+ Requires-Dist: coverage~=7.14.1; extra == "dev"
42
+ Requires-Dist: debugpy; extra == "dev"
43
+ Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
44
+ Requires-Dist: flake8~=7.3.0; extra == "dev"
45
+ Requires-Dist: invoke~=3.0.3; extra == "dev"
46
+ Requires-Dist: ipython; extra == "dev"
47
+ Requires-Dist: isort~=8.0.1; extra == "dev"
48
+ Requires-Dist: mkdocs~=1.6.1; extra == "dev"
49
+ Requires-Dist: mkdocs-exclude~=1.0.2; extra == "dev"
50
+ Requires-Dist: mypy~=2.1.0; extra == "dev"
51
+ Requires-Dist: parameterized~=0.9.0; extra == "dev"
52
+ Requires-Dist: pylint~=4.0.6; extra == "dev"
53
+ Requires-Dist: pytest; extra == "dev"
54
+ Requires-Dist: requests~=2.34.2; extra == "dev"
55
+ Requires-Dist: safetensors~=0.7.0; extra == "dev"
56
+ Requires-Dist: setuptools; extra == "dev"
57
+ Requires-Dist: twine~=6.2.0; extra == "dev"
58
+ Requires-Dist: types-requests~=2.33.0; extra == "dev"
59
+ Requires-Dist: urllib3~=2.7.0; extra == "dev"
60
+ Requires-Dist: wheel; extra == "dev"
61
+ Dynamic: license-file
62
+
63
+ # Birder CLIP
64
+
65
+ Birder CLIP is an early-stage Birder extension for CLIP-style image-text models, focused on practical inference and fine-tuning workflows.
66
+
67
+ - [Introduction](#introduction)
68
+ - [Setup](#setup)
69
+ - [Getting Started](#getting-started)
70
+ - [Training](#training)
71
+ - [Project Status and Contributions](#project-status-and-contributions)
72
+ - [Licenses](#licenses)
73
+ - [Acknowledgments](#acknowledgments)
74
+
75
+ ## Introduction
76
+
77
+ Birder CLIP extends [Birder](https://gitlab.com/birder/birder) with image-text models for zero-shot classification, image-text retrieval-style workflows, caption generation and related multimodal computer vision tasks.
78
+
79
+ The project is aimed at image-text modeling rather than general vision-language model (VLM) chat or instruction-following systems.
80
+ It currently includes CLIP-style components, tokenizers, model registry utilities, inference scripts and training code.
81
+ Full training is supported, but for large-scale CLIP pretraining you are probably better served by [OpenCLIP](https://github.com/mlfoundations/open_clip).
82
+
83
+ ## Setup
84
+
85
+ 1. Ensure your environment meets the minimum requirements:
86
+ - Python 3.11 or newer
87
+ - PyTorch 2.10 or newer (installed for your hardware/driver stack)
88
+ - Birder 0.6.0 or newer
89
+
90
+ 1. Install the latest Birder CLIP version:
91
+
92
+ ```sh
93
+ pip install birder-clip
94
+ ```
95
+
96
+ ## Getting Started
97
+
98
+ List available image-text models:
99
+
100
+ ```sh
101
+ python -m birder_clip.tools list-models --image-text
102
+ ```
103
+
104
+ List available pretrained weights:
105
+
106
+ ```sh
107
+ python -m birder_clip.tools list-models --pretrained --verbose
108
+ ```
109
+
110
+ Run zero-shot classification on a directory of images:
111
+
112
+ ```sh
113
+ python -m birder_clip.scripts.zero_shot -n laion_clip_vit_l14 --classes eagle hawk falcon --template-set default --gpu data/images
114
+ ```
115
+
116
+ For detailed options, run:
117
+
118
+ ```sh
119
+ python -m birder_clip.scripts.zero_shot --help
120
+ python -m birder_clip.tools --help
121
+ ```
122
+
123
+ ## Training
124
+
125
+ Birder CLIP includes training support for image-text datasets in CSV and WebDataset formats, including CLIP, CoCa and LiT-style workflows.
126
+
127
+ ## Project Status and Contributions
128
+
129
+ Birder CLIP is an early alpha project. APIs, model names, checkpoints, training recipes and command-line options may change without notice.
130
+
131
+ This is currently a personal project in active development. Suggestions, bug reports and feedback are welcome through the project's issue tracker, but the project is not yet stable enough for broad external contributions.
132
+
133
+ ## Licenses
134
+
135
+ The code in this project is primarily licensed under Apache 2.0. See [LICENSE](LICENSE) for details.
136
+
137
+ Some model implementations, pretrained weights, tokenizers and converted artifacts may be derived from or depend on projects and datasets with their own licenses and usage restrictions.
138
+
139
+ **You are responsible for ensuring compliance with all licenses and conditions of any dependent licenses.**
140
+
141
+ ### Disclaimer
142
+
143
+ If you intend to use Birder CLIP, its pretrained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
144
+
145
+ It's the user's responsibility to ensure that their use of this project, including any pretrained weights or datasets, complies with all applicable licenses and legal requirements.
146
+
147
+ ## Acknowledgments
148
+
149
+ Birder CLIP owes much to the work of others in computer vision, image-text representation learning and open-source machine learning.
150
+
151
+ Special thanks to the [OpenCLIP](https://github.com/mlfoundations/open_clip) project, which serves as the main reference implementation and inspiration for much of the CLIP-style modeling and training work here. The same principle as in Birder applies: this project stands on the shoulders of many open-source projects, papers and datasets. If an attribution is missing, please open an issue.
@@ -0,0 +1,89 @@
1
+ # Birder CLIP
2
+
3
+ Birder CLIP is an early-stage Birder extension for CLIP-style image-text models, focused on practical inference and fine-tuning workflows.
4
+
5
+ - [Introduction](#introduction)
6
+ - [Setup](#setup)
7
+ - [Getting Started](#getting-started)
8
+ - [Training](#training)
9
+ - [Project Status and Contributions](#project-status-and-contributions)
10
+ - [Licenses](#licenses)
11
+ - [Acknowledgments](#acknowledgments)
12
+
13
+ ## Introduction
14
+
15
+ Birder CLIP extends [Birder](https://gitlab.com/birder/birder) with image-text models for zero-shot classification, image-text retrieval-style workflows, caption generation and related multimodal computer vision tasks.
16
+
17
+ The project is aimed at image-text modeling rather than general vision-language model (VLM) chat or instruction-following systems.
18
+ It currently includes CLIP-style components, tokenizers, model registry utilities, inference scripts and training code.
19
+ Full training is supported, but for large-scale CLIP pretraining you are probably better served by [OpenCLIP](https://github.com/mlfoundations/open_clip).
20
+
21
+ ## Setup
22
+
23
+ 1. Ensure your environment meets the minimum requirements:
24
+ - Python 3.11 or newer
25
+ - PyTorch 2.10 or newer (installed for your hardware/driver stack)
26
+ - Birder 0.6.0 or newer
27
+
28
+ 1. Install the latest Birder CLIP version:
29
+
30
+ ```sh
31
+ pip install birder-clip
32
+ ```
33
+
34
+ ## Getting Started
35
+
36
+ List available image-text models:
37
+
38
+ ```sh
39
+ python -m birder_clip.tools list-models --image-text
40
+ ```
41
+
42
+ List available pretrained weights:
43
+
44
+ ```sh
45
+ python -m birder_clip.tools list-models --pretrained --verbose
46
+ ```
47
+
48
+ Run zero-shot classification on a directory of images:
49
+
50
+ ```sh
51
+ python -m birder_clip.scripts.zero_shot -n laion_clip_vit_l14 --classes eagle hawk falcon --template-set default --gpu data/images
52
+ ```
53
+
54
+ For detailed options, run:
55
+
56
+ ```sh
57
+ python -m birder_clip.scripts.zero_shot --help
58
+ python -m birder_clip.tools --help
59
+ ```
60
+
61
+ ## Training
62
+
63
+ Birder CLIP includes training support for image-text datasets in CSV and WebDataset formats, including CLIP, CoCa and LiT-style workflows.
64
+
65
+ ## Project Status and Contributions
66
+
67
+ Birder CLIP is an early alpha project. APIs, model names, checkpoints, training recipes and command-line options may change without notice.
68
+
69
+ This is currently a personal project in active development. Suggestions, bug reports and feedback are welcome through the project's issue tracker, but the project is not yet stable enough for broad external contributions.
70
+
71
+ ## Licenses
72
+
73
+ The code in this project is primarily licensed under Apache 2.0. See [LICENSE](LICENSE) for details.
74
+
75
+ Some model implementations, pretrained weights, tokenizers and converted artifacts may be derived from or depend on projects and datasets with their own licenses and usage restrictions.
76
+
77
+ **You are responsible for ensuring compliance with all licenses and conditions of any dependent licenses.**
78
+
79
+ ### Disclaimer
80
+
81
+ If you intend to use Birder CLIP, its pretrained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
82
+
83
+ It's the user's responsibility to ensure that their use of this project, including any pretrained weights or datasets, complies with all applicable licenses and legal requirements.
84
+
85
+ ## Acknowledgments
86
+
87
+ Birder CLIP owes much to the work of others in computer vision, image-text representation learning and open-source machine learning.
88
+
89
+ Special thanks to the [OpenCLIP](https://github.com/mlfoundations/open_clip) project, which serves as the main reference implementation and inspiration for much of the CLIP-style modeling and training work here. The same principle as in Birder applies: this project stands on the shoulders of many open-source projects, papers and datasets. If an attribution is missing, please open an issue.
@@ -43,7 +43,7 @@ def get_image_text_network_name(
43
43
  parts = [network]
44
44
  if image_encoder is not None:
45
45
  parts.append(image_encoder)
46
- if text_encoder is not None and text_encoder != "text_transformer":
46
+ if text_encoder is not None and text_encoder != "transformer_encoder":
47
47
  parts.append(text_encoder)
48
48
 
49
49
  if registry.exists(network) is True:
@@ -21,6 +21,11 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
21
21
  parser.add_argument("-n", "--network", type=str, help="the image-text network to train")
22
22
  parser.add_argument("-t", "--tag", type=str, help="add model tag")
23
23
  parser.add_argument("--image-encoder", type=str, help="the image encoder to use")
24
+ parser.add_argument(
25
+ "--image-encoder-pretrained",
26
+ type=str,
27
+ help="pretrained Birder image model weights path to load into the image encoder",
28
+ )
24
29
  parser.add_argument("--text-encoder", type=str, help="the text encoder to use")
25
30
  parser.add_argument("--embed-dim", type=int, metavar="N", help="shared image-text embedding dimension")
26
31
  parser.add_argument("--tokenizer", type=str, help="the tokenizer to use")
@@ -43,7 +48,23 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
43
48
 
44
49
  def add_loss_args(parser: argparse.ArgumentParser) -> None:
45
50
  group = parser.add_argument_group("Loss parameters")
46
- group.add_argument("--loss", type=str, choices=["clip"], default="clip", help="loss function to use")
51
+ group.add_argument("--loss", type=str, choices=["clip", "coca"], default="clip", help="loss function to use")
52
+ group.add_argument(
53
+ "--coca-caption-loss-weight", type=float, default=1.0, help="weight assigned to CoCa caption loss"
54
+ )
55
+ group.add_argument(
56
+ "--coca-contrastive-loss-weight", type=float, default=1.0, help="weight assigned to CoCa contrastive loss"
57
+ )
58
+
59
+
60
+ def add_freeze_args(parser: argparse.ArgumentParser) -> None:
61
+ group = parser.add_argument_group("Freeze parameters")
62
+ group.add_argument(
63
+ "--freeze-image-encoder",
64
+ default=False,
65
+ action="store_true",
66
+ help="freeze image encoder body, leaving the projection head trainable",
67
+ )
47
68
 
48
69
 
49
70
  def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: int = 32) -> None:
@@ -66,7 +87,12 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
66
87
  metavar="N",
67
88
  help="number of iterations to accumulate gradients per optimizer step",
68
89
  )
69
- # NOTE: Add flag for negative sample caching in grad accum mode
90
+ group.add_argument(
91
+ "--grad-accum-cache-negatives",
92
+ default=False,
93
+ action="store_true",
94
+ help="cache features so CLIP loss uses all accumulated microbatches as negatives",
95
+ )
70
96
 
71
97
 
72
98
  def add_lr_wd_args(parser: argparse.ArgumentParser) -> None:
@@ -250,6 +276,8 @@ def add_data_aug_args(
250
276
  )
251
277
  group.add_argument("--ra-magnitude", type=int, default=9, help="magnitude for all the RandAugment transformations")
252
278
  group.add_argument("--augmix-severity", type=int, default=3, help="severity of AugMix policy")
279
+ group.add_argument("--clip-color-jitter-prob", type=float, default=0.0, help="CLIP color jitter probability")
280
+ group.add_argument("--clip-gray-prob", type=float, default=0.0, help="CLIP grayscale probability")
253
281
  group.add_argument("--resize-min-scale", type=float, default=default_min_scale, help="random resize min scale")
254
282
  group.add_argument(
255
283
  "--re-prob",
@@ -356,6 +384,29 @@ def add_precision_args(parser: argparse.ArgumentParser) -> None:
356
384
  )
357
385
 
358
386
 
387
+ def add_grad_checkpointing_args(parser: argparse.ArgumentParser) -> None:
388
+ group = parser.add_argument_group("Gradient checkpointing parameters")
389
+ group.add_argument(
390
+ "--grad-checkpointing",
391
+ default=False,
392
+ action="store_true",
393
+ help="enable gradient checkpointing for supported models",
394
+ )
395
+ group.add_argument(
396
+ "--grad-checkpointing-segments",
397
+ type=int,
398
+ metavar="N",
399
+ help="number of checkpoint segments to request from supported models",
400
+ )
401
+ group.add_argument(
402
+ "--no-grad-checkpointing-preserve-rng-state",
403
+ dest="grad_checkpointing_preserve_rng_state",
404
+ default=True,
405
+ action="store_false",
406
+ help="disable RNG state preservation during gradient checkpointing recomputation",
407
+ )
408
+
409
+
359
410
  def add_compile_args(parser: argparse.ArgumentParser) -> None:
360
411
  group = parser.add_argument_group("Compilation parameters")
361
412
  group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
@@ -578,8 +629,12 @@ def common_args_validation(args: argparse.Namespace) -> None:
578
629
  raise cli.ValidationError("--load-states requires --resume-epoch to be set")
579
630
  if args.load_scheduler is True and args.resume_epoch is None:
580
631
  raise cli.ValidationError("--load-scheduler requires --resume-epoch to be set")
581
- if hasattr(args, "pretrained") is True and args.pretrained is True and args.resume_epoch is not None:
632
+ if args.pretrained is True and args.resume_epoch is not None:
582
633
  raise cli.ValidationError("--pretrained cannot be used with --resume-epoch")
634
+ if args.image_encoder_pretrained is not None and args.resume_epoch is not None:
635
+ raise cli.ValidationError("--image-encoder-pretrained cannot be used with --resume-epoch")
636
+ if args.pretrained is True and args.image_encoder_pretrained is not None:
637
+ raise cli.ValidationError("--image-encoder-pretrained cannot be used with --pretrained")
583
638
 
584
639
  if args.freeze_bn is True and args.sync_bn is True:
585
640
  raise cli.ValidationError("--freeze-bn cannot be used with --sync-bn")
@@ -605,10 +660,29 @@ def common_args_validation(args: argparse.Namespace) -> None:
605
660
  raise cli.ValidationError("--context-length must be positive")
606
661
  if args.grad_accum_steps < 1:
607
662
  raise cli.ValidationError("--grad-accum-steps must be >= 1")
663
+ if args.grad_accum_cache_negatives is True and args.grad_accum_steps == 1:
664
+ raise cli.ValidationError("--grad-accum-cache-negatives requires --grad-accum-steps greater than 1")
665
+ if args.grad_accum_cache_negatives is True and args.loss == "coca":
666
+ raise cli.ValidationError("--grad-accum-cache-negatives is only supported with --loss clip")
667
+
668
+ if args.coca_caption_loss_weight < 0.0:
669
+ raise cli.ValidationError("--coca-caption-loss-weight must be non-negative")
670
+ if args.coca_contrastive_loss_weight < 0.0:
671
+ raise cli.ValidationError("--coca-contrastive-loss-weight must be non-negative")
672
+
673
+ # EMA
608
674
  if args.model_ema_steps < 1:
609
675
  raise cli.ValidationError("--model-ema-steps must be >= 1")
610
676
 
677
+ # Gradient checkpointing args
678
+ if args.grad_checkpointing_segments is not None and args.grad_checkpointing_segments < 1:
679
+ raise cli.ValidationError("--grad-checkpointing-segments must be >= 1")
680
+ if args.grad_checkpointing_segments is not None and args.grad_checkpointing is False:
681
+ raise cli.ValidationError("--grad-checkpointing-segments requires --grad-checkpointing")
682
+
611
683
  if args.distributed_mode == "fsdp":
684
+ if args.grad_checkpointing is True:
685
+ raise cli.ValidationError("--grad-checkpointing cannot be used with --distributed-mode fsdp")
612
686
  if args.sync_bn is True:
613
687
  raise cli.ValidationError("--sync-bn cannot be used with --distributed-mode fsdp")
614
688
  if args.find_unused_parameters is True:
@@ -19,6 +19,8 @@ import numpy.typing as npt
19
19
  import torch
20
20
  import torch.nn.functional as F
21
21
  from torch.utils.data import DataLoader
22
+ from torchvision.transforms import v2
23
+ from torchvision.transforms.v2.functional import five_crop
22
24
  from tqdm import tqdm
23
25
 
24
26
  from birder_clip.net.base import BaseNet
@@ -64,11 +66,40 @@ def build_class_text_embeddings(
64
66
  DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]]
65
67
 
66
68
 
69
+ def infer_batch(
70
+ net: BaseNet, inputs: torch.Tensor, text_embeddings: torch.Tensor, tta: bool = False, return_logits: bool = False
71
+ ) -> torch.Tensor:
72
+ if tta is True:
73
+ _, _, H, W = inputs.size()
74
+ crop_h = int(H * 0.8)
75
+ crop_w = int(W * 0.8)
76
+ tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
77
+ t = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
78
+ outs = []
79
+ for tta_input in tta_inputs:
80
+ image_embeddings = net.encode_image(t(tta_input), normalize=True)
81
+ logits = net.forward_logits(image_embeddings, text_embeddings)
82
+ if return_logits is True:
83
+ outs.append(logits)
84
+ else:
85
+ outs.append(F.softmax(logits, dim=-1))
86
+
87
+ return torch.stack(outs).mean(dim=0)
88
+
89
+ image_embeddings = net.encode_image(inputs, normalize=True)
90
+ logits = net.forward_logits(image_embeddings, text_embeddings)
91
+ if return_logits is True:
92
+ return logits
93
+
94
+ return F.softmax(logits, dim=-1)
95
+
96
+
67
97
  def infer_dataloader_iter(
68
98
  device: torch.device,
69
- net: BaseNet | torch.ScriptModule,
99
+ net: BaseNet,
70
100
  dataloader: DataLoader,
71
101
  text_embeddings: torch.Tensor,
102
+ tta: bool = False,
72
103
  return_logits: bool = False,
73
104
  model_dtype: torch.dtype = torch.float32,
74
105
  amp: bool = False,
@@ -92,12 +123,8 @@ def infer_dataloader_iter(
92
123
  # Inference
93
124
  inputs = inputs.to(device, dtype=model_dtype)
94
125
  with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
95
- image_embeddings = net.encode_image(inputs, normalize=True)
96
- logits = net.forward_logits(image_embeddings, text_embeddings)
97
- if return_logits is True:
98
- out = logits.cpu().float().numpy()
99
- else:
100
- out = F.softmax(logits, dim=-1).cpu().float().numpy()
126
+ out = infer_batch(net, inputs, text_embeddings, return_logits=return_logits, tta=tta)
127
+ out = out.cpu().float().numpy()
101
128
 
102
129
  out_list.append(out)
103
130
 
@@ -1,5 +1,7 @@
1
+ from birder_clip.loss.coca import CoCaLoss
1
2
  from birder_clip.loss.contrastive import CLIPLoss
2
3
 
3
4
  __all__ = [
5
+ "CoCaLoss",
4
6
  "CLIPLoss",
5
7
  ]
@@ -0,0 +1,66 @@
1
+ """
2
+ CoCa loss, adapted from
3
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py
4
+
5
+ Paper "CoCa: Contrastive Captioners are Image-Text Foundation Models",
6
+ https://arxiv.org/abs/2205.01917
7
+
8
+ Generated by gpt-5.5 xhigh.
9
+ """
10
+
11
+ # Reference license: MIT
12
+
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ from birder_clip.loss.contrastive import CLIPLoss
19
+
20
+
21
+ class CoCaLoss(torch.nn.Module):
22
+ """
23
+ CoCa contrastive and captioning loss
24
+
25
+ Combines the symmetric CLIP contrastive objective with autoregressive
26
+ captioning cross entropy over decoder logits.
27
+ """
28
+
29
+ def __init__(
30
+ self, *, caption_loss_weight: float = 1.0, clip_loss_weight: float = 1.0, pad_token_id: int = 0
31
+ ) -> None:
32
+ super().__init__()
33
+ self.caption_loss_weight = caption_loss_weight
34
+ self.clip_loss_weight = clip_loss_weight
35
+ self.pad_token_id = pad_token_id
36
+ self.clip_loss = CLIPLoss()
37
+
38
+ def forward(
39
+ self,
40
+ image_features: torch.Tensor,
41
+ text_features: torch.Tensor,
42
+ logits: torch.Tensor,
43
+ texts: torch.Tensor,
44
+ logit_scale: torch.Tensor,
45
+ logit_bias: Optional[torch.Tensor] = None,
46
+ ) -> dict[str, torch.Tensor]:
47
+ if self.clip_loss_weight != 0.0:
48
+ losses = self.clip_loss(image_features, text_features, logit_scale, logit_bias=logit_bias)
49
+ contrastive_loss = losses["contrastive_loss"] * self.clip_loss_weight
50
+ else:
51
+ contrastive_loss = logits.new_zeros(())
52
+
53
+ if self.caption_loss_weight != 0.0:
54
+ caption_loss = F.cross_entropy(
55
+ logits[:, :-1].permute(0, 2, 1),
56
+ texts[:, 1:],
57
+ ignore_index=self.pad_token_id,
58
+ )
59
+ caption_loss = caption_loss * self.caption_loss_weight
60
+ else:
61
+ caption_loss = logits.new_zeros(())
62
+
63
+ return {
64
+ "contrastive_loss": contrastive_loss,
65
+ "caption_loss": caption_loss,
66
+ }
@@ -15,38 +15,44 @@ from birder_clip.tokenizers.registry import get_tokenizer_info
15
15
 
16
16
  if TYPE_CHECKING is True:
17
17
  from birder_clip.net.base import BaseNet # pylint: disable=cyclic-import
18
- from birder_clip.net.text.base import TextBaseNet # pylint: disable=cyclic-import
18
+ from birder_clip.net.text.base import TextDecoderBaseNet # pylint: disable=cyclic-import
19
+ from birder_clip.net.text.base import TextEncoderBaseNet # pylint: disable=cyclic-import
19
20
 
20
- NetType: TypeAlias = type[BaseNet] | type[TextBaseNet]
21
+ BaseNetObjType: TypeAlias = BaseNet | TextEncoderBaseNet | TextDecoderBaseNet
22
+ BaseNetType: TypeAlias = type[BaseNet] | type[TextEncoderBaseNet] | type[TextDecoderBaseNet]
21
23
 
22
24
 
23
25
  class Task(StrEnum):
24
26
  IMAGE_TEXT = "image_text"
25
- TEXT = "text"
27
+ TEXT_ENCODER = "text_encoder"
28
+ TEXT_DECODER = "text_decoder"
26
29
 
27
30
 
28
31
  class ModelRegistry:
29
32
  def __init__(self) -> None:
30
- self.registered_configs: dict[str, "NetType"] = {}
33
+ self.registered_configs: dict[str, "BaseNetType"] = {}
31
34
  self._image_text_nets: dict[str, type["BaseNet"]] = {}
32
- self._text_nets: dict[str, type[TextBaseNet]] = {}
35
+ self._text_encoder_nets: dict[str, type[TextEncoderBaseNet]] = {}
36
+ self._text_decoder_nets: dict[str, type[TextDecoderBaseNet]] = {}
33
37
  self._pretrained_nets = manifest.REGISTRY_MANIFEST
34
38
 
35
39
  @property
36
- def all_nets(self) -> dict[str, "NetType"]:
37
- return {**self._image_text_nets, **self._text_nets}
40
+ def all_nets(self) -> dict[str, "BaseNetType"]:
41
+ return {**self._image_text_nets, **self._text_encoder_nets, **self._text_decoder_nets}
38
42
 
39
- def _get_models_for_task(self, task: Task) -> dict[str, "NetType"]:
43
+ def _get_models_for_task(self, task: Task) -> dict[str, "BaseNetType"]:
40
44
  if task == Task.IMAGE_TEXT:
41
- nets: dict[str, "NetType"] = self._image_text_nets
42
- elif task == Task.TEXT:
43
- nets = self._text_nets
45
+ nets: dict[str, "BaseNetType"] = self._image_text_nets
46
+ elif task == Task.TEXT_ENCODER:
47
+ nets = self._text_encoder_nets
48
+ elif task == Task.TEXT_DECODER:
49
+ nets = self._text_decoder_nets
44
50
  else:
45
51
  raise ValueError(f"Unsupported model task: {task}")
46
52
 
47
53
  return nets
48
54
 
49
- def _register_model(self, name: str, net_type: "NetType") -> None:
55
+ def _register_model(self, name: str, net_type: "BaseNetType") -> None:
50
56
  name_key = name.lower()
51
57
  task = Task(net_type.task)
52
58
  nets = self._get_models_for_task(task)
@@ -57,7 +63,9 @@ class ModelRegistry:
57
63
 
58
64
  nets[name_key] = net_type
59
65
 
60
- def register_model_config(self, name: str, net_type: "NetType", *, config: Optional[dict[str, Any]] = None) -> None:
66
+ def register_model_config(
67
+ self, name: str, net_type: "BaseNetType", *, config: Optional[dict[str, Any]] = None
68
+ ) -> None:
61
69
  name_key = name.lower()
62
70
  registered_net_type = type(name, (net_type,), {"config": config})
63
71
  self._register_model(name_key, registered_net_type)
@@ -180,17 +188,30 @@ class ModelRegistry:
180
188
 
181
189
  return metadata
182
190
 
183
- def text_factory(
191
+ def text_encoder_factory(
184
192
  self, name: str, *, config: Optional[dict[str, Any]] = None, context_length: Optional[int] = None
185
- ) -> "TextBaseNet":
193
+ ) -> "TextEncoderBaseNet":
186
194
  name_key = name.lower()
187
- return self._text_nets[name_key](config=config, context_length=context_length)
195
+ if name_key in self._text_encoder_nets:
196
+ return self._text_encoder_nets[name_key](config=config, context_length=context_length)
197
+
198
+ if name_key.startswith("hf:"):
199
+ config = {"source": name[len("hf:") :], **({} if config is None else config)}
200
+ name_key = "hf_text_encoder"
201
+
202
+ return self._text_encoder_nets[name_key](config=config, context_length=context_length)
203
+
204
+ def text_decoder_factory(
205
+ self, name: str, *, config: Optional[dict[str, Any]] = None, context_length: Optional[int] = None
206
+ ) -> "TextDecoderBaseNet":
207
+ name_key = name.lower()
208
+ return self._text_decoder_nets[name_key](config=config, context_length=context_length)
188
209
 
189
210
  def net_factory(self, name: str, *, config: Optional[dict[str, Any]] = None) -> "BaseNet":
190
211
  name_key = name.lower()
191
212
  return self._image_text_nets[name_key](config=config)
192
213
 
193
- def _metadata_type_name(self, model: "BaseNet | TextBaseNet") -> str:
214
+ def _metadata_type_name(self, model: "BaseNetObjType") -> str:
194
215
  cls = model.__class__
195
216
  bases = cls.__bases__
196
217
  if len(bases) > 1 and bases[0].__name__ == "FSDPModule":
@@ -198,14 +219,14 @@ class ModelRegistry:
198
219
 
199
220
  return cls.__name__.lower()
200
221
 
201
- def get_model_base_name(self, model: "BaseNet | TextBaseNet") -> str:
222
+ def get_model_base_name(self, model: "BaseNetObjType") -> str:
202
223
  type_name = self._metadata_type_name(model)
203
224
  if type_name in self.registered_configs:
204
225
  type_name = self.registered_configs[type_name].__bases__[0].__name__.lower()
205
226
 
206
227
  return type_name
207
228
 
208
- def get_registered_name(self, model: "BaseNet | TextBaseNet") -> Optional[str]:
229
+ def get_registered_name(self, model: "BaseNetObjType") -> Optional[str]:
209
230
  type_name = self._metadata_type_name(model)
210
231
  if type_name in self.registered_configs:
211
232
  return type_name
@@ -1,5 +1,7 @@
1
1
  from birder_clip.net.clip import CLIP
2
+ from birder_clip.net.coca import CoCa
2
3
 
3
4
  __all__ = [
4
5
  "CLIP",
6
+ "CoCa",
5
7
  ]