birder-clip 0.0.2.dev5__tar.gz → 0.0.2.dev7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder_clip-0.0.2.dev7/PKG-INFO +151 -0
- birder_clip-0.0.2.dev7/README.md +89 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/common/lib.py +1 -1
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/common/training_cli.py +77 -3
- birder_clip-0.0.2.dev7/birder_clip/inference/data_parallel.py +118 -0
- birder_clip-0.0.2.dev7/birder_clip/inference/image_embeddings.py +63 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/inference/zero_shot.py +44 -10
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/loss/__init__.py +2 -0
- birder_clip-0.0.2.dev7/birder_clip/loss/coca.py +66 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/model_registry.py +40 -19
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/net/__init__.py +2 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/net/base.py +14 -2
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/net/clip.py +160 -28
- birder_clip-0.0.2.dev7/birder_clip/net/coca.py +463 -0
- birder_clip-0.0.2.dev7/birder_clip/net/text/__init__.py +9 -0
- birder_clip-0.0.2.dev7/birder_clip/net/text/base.py +142 -0
- birder_clip-0.0.2.dev7/birder_clip/net/text/conditioned_decoder.py +441 -0
- birder_clip-0.0.2.dev5/birder_clip/net/text/transformer.py → birder_clip-0.0.2.dev7/birder_clip/net/text/encoder.py +133 -19
- birder_clip-0.0.2.dev7/birder_clip/net/text/hf.py +154 -0
- birder_clip-0.0.2.dev7/birder_clip/net/text/prefix_decoder.py +1 -0
- birder_clip-0.0.2.dev7/birder_clip/scripts/__main__.py +25 -0
- birder_clip-0.0.2.dev7/birder_clip/scripts/embed_images.py +432 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/scripts/train.py +244 -50
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/scripts/zero_shot.py +47 -9
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/base.py +1 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/hf.py +2 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/openvision.py +3 -1
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/simple_tokenizer.py +1 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/convert_model.py +166 -6
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/list_models.py +7 -4
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/model_info.py +1 -1
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/show_iterator.py +14 -13
- birder_clip-0.0.2.dev7/birder_clip/version.py +1 -0
- birder_clip-0.0.2.dev7/birder_clip.egg-info/PKG-INFO +151 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/SOURCES.txt +12 -1
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/requires.txt +3 -3
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/requirements/_requirements-dev.txt +2 -2
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/requirements/requirements.txt +1 -1
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/tests/test_common.py +3 -3
- birder_clip-0.0.2.dev7/tests/test_inference.py +143 -0
- birder_clip-0.0.2.dev7/tests/test_loss.py +129 -0
- birder_clip-0.0.2.dev7/tests/test_model_registry.py +102 -0
- birder_clip-0.0.2.dev7/tests/test_net.py +572 -0
- birder_clip-0.0.2.dev7/tests/test_net_text.py +96 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/tests/test_tokenizers.py +29 -0
- birder_clip-0.0.2.dev5/PKG-INFO +0 -72
- birder_clip-0.0.2.dev5/README.md +0 -10
- birder_clip-0.0.2.dev5/birder_clip/net/text/__init__.py +0 -5
- birder_clip-0.0.2.dev5/birder_clip/net/text/base.py +0 -46
- birder_clip-0.0.2.dev5/birder_clip/version.py +0 -1
- birder_clip-0.0.2.dev5/birder_clip.egg-info/PKG-INFO +0 -72
- birder_clip-0.0.2.dev5/tests/test_loss.py +0 -47
- birder_clip-0.0.2.dev5/tests/test_model_registry.py +0 -74
- birder_clip-0.0.2.dev5/tests/test_net.py +0 -147
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/LICENSE +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/common/fs_ops.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/common/training_utils.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/conf/settings.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/data/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/csv.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/fake.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/data/datasets/webdataset.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/loss/contrastive.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/model_registry/manifest.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/py.typed +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tokenizers/registry.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/__init__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/__main__.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/download_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip/tools/stats.py +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev5 → birder_clip-0.0.2.dev7}/tests/test_datasets.py +0 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: birder_clip
|
|
3
|
+
Version: 0.0.2.dev7
|
|
4
|
+
Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
|
|
5
|
+
Author: Ofer Hasson
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://gitlab.com/birder/birder-clip
|
|
8
|
+
Project-URL: Issues, https://gitlab.com/birder/birder-clip/-/issues
|
|
9
|
+
Project-URL: Changelog, https://gitlab.com/birder/birder-clip/-/blob/main/CHANGELOG.md
|
|
10
|
+
Keywords: computer-vision,clip,image-text,pytorch,deep-learning,artificial intelligence
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Education
|
|
15
|
+
Classifier: Operating System :: OS Independent
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
20
|
+
Classifier: Topic :: Software Development
|
|
21
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
22
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
23
|
+
Classifier: Typing :: Typed
|
|
24
|
+
Requires-Python: >=3.11
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
License-File: LICENSE
|
|
27
|
+
Requires-Dist: birder>=0.6.0
|
|
28
|
+
Requires-Dist: ftfy>=6.3.1
|
|
29
|
+
Requires-Dist: regex>=2025.7.29
|
|
30
|
+
Requires-Dist: tqdm>=4.67.0
|
|
31
|
+
Requires-Dist: webdataset>=0.2.111
|
|
32
|
+
Requires-Dist: huggingface_hub
|
|
33
|
+
Requires-Dist: transformers
|
|
34
|
+
Requires-Dist: torch>=2.10.0
|
|
35
|
+
Requires-Dist: torchvision
|
|
36
|
+
Provides-Extra: dev
|
|
37
|
+
Requires-Dist: bandit~=1.9.4; extra == "dev"
|
|
38
|
+
Requires-Dist: black~=26.5.0; extra == "dev"
|
|
39
|
+
Requires-Dist: build~=1.5.0; extra == "dev"
|
|
40
|
+
Requires-Dist: bumpver~=2026.1132; extra == "dev"
|
|
41
|
+
Requires-Dist: coverage~=7.14.2; extra == "dev"
|
|
42
|
+
Requires-Dist: debugpy; extra == "dev"
|
|
43
|
+
Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
|
|
44
|
+
Requires-Dist: flake8~=7.3.0; extra == "dev"
|
|
45
|
+
Requires-Dist: invoke~=3.0.3; extra == "dev"
|
|
46
|
+
Requires-Dist: ipython; extra == "dev"
|
|
47
|
+
Requires-Dist: isort~=8.0.1; extra == "dev"
|
|
48
|
+
Requires-Dist: mkdocs~=1.6.1; extra == "dev"
|
|
49
|
+
Requires-Dist: mkdocs-exclude~=1.0.2; extra == "dev"
|
|
50
|
+
Requires-Dist: mypy~=2.1.0; extra == "dev"
|
|
51
|
+
Requires-Dist: parameterized~=0.9.0; extra == "dev"
|
|
52
|
+
Requires-Dist: pylint~=4.0.6; extra == "dev"
|
|
53
|
+
Requires-Dist: pytest; extra == "dev"
|
|
54
|
+
Requires-Dist: requests~=2.34.2; extra == "dev"
|
|
55
|
+
Requires-Dist: safetensors~=0.7.0; extra == "dev"
|
|
56
|
+
Requires-Dist: setuptools; extra == "dev"
|
|
57
|
+
Requires-Dist: twine~=6.2.0; extra == "dev"
|
|
58
|
+
Requires-Dist: types-requests~=2.33.0; extra == "dev"
|
|
59
|
+
Requires-Dist: urllib3~=2.7.0; extra == "dev"
|
|
60
|
+
Requires-Dist: wheel; extra == "dev"
|
|
61
|
+
Dynamic: license-file
|
|
62
|
+
|
|
63
|
+
# Birder CLIP
|
|
64
|
+
|
|
65
|
+
Birder CLIP is an early-stage Birder extension for CLIP-style image-text models, focused on practical inference and fine-tuning workflows.
|
|
66
|
+
|
|
67
|
+
- [Introduction](#introduction)
|
|
68
|
+
- [Setup](#setup)
|
|
69
|
+
- [Getting Started](#getting-started)
|
|
70
|
+
- [Training](#training)
|
|
71
|
+
- [Project Status and Contributions](#project-status-and-contributions)
|
|
72
|
+
- [Licenses](#licenses)
|
|
73
|
+
- [Acknowledgments](#acknowledgments)
|
|
74
|
+
|
|
75
|
+
## Introduction
|
|
76
|
+
|
|
77
|
+
Birder CLIP extends [Birder](https://gitlab.com/birder/birder) with image-text models for zero-shot classification, image-text retrieval-style workflows, caption generation and related multimodal computer vision tasks.
|
|
78
|
+
|
|
79
|
+
The project is aimed at image-text modeling rather than general vision-language model (VLM) chat or instruction-following systems.
|
|
80
|
+
It currently includes CLIP-style components, tokenizers, model registry utilities, inference scripts and training code.
|
|
81
|
+
Full training is supported, but for large-scale CLIP pretraining you are probably better served by [OpenCLIP](https://github.com/mlfoundations/open_clip).
|
|
82
|
+
|
|
83
|
+
## Setup
|
|
84
|
+
|
|
85
|
+
1. Ensure your environment meets the minimum requirements:
|
|
86
|
+
- Python 3.11 or newer
|
|
87
|
+
- PyTorch 2.10 or newer (installed for your hardware/driver stack)
|
|
88
|
+
- Birder 0.6.0 or newer
|
|
89
|
+
|
|
90
|
+
1. Install the latest Birder CLIP version:
|
|
91
|
+
|
|
92
|
+
```sh
|
|
93
|
+
pip install birder-clip
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
## Getting Started
|
|
97
|
+
|
|
98
|
+
List available image-text models:
|
|
99
|
+
|
|
100
|
+
```sh
|
|
101
|
+
python -m birder_clip.tools list-models --image-text
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
List available pretrained weights:
|
|
105
|
+
|
|
106
|
+
```sh
|
|
107
|
+
python -m birder_clip.tools list-models --pretrained --verbose
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
Run zero-shot classification on a directory of images:
|
|
111
|
+
|
|
112
|
+
```sh
|
|
113
|
+
python -m birder_clip.scripts.zero_shot -n laion_clip_vit_l14 --classes eagle hawk falcon --template-set default --gpu data/images
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
For detailed options, run:
|
|
117
|
+
|
|
118
|
+
```sh
|
|
119
|
+
python -m birder_clip.scripts.zero_shot --help
|
|
120
|
+
python -m birder_clip.tools --help
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
## Training
|
|
124
|
+
|
|
125
|
+
Birder CLIP includes training support for image-text datasets in CSV and WebDataset formats, including CLIP, CoCa and LiT-style workflows.
|
|
126
|
+
|
|
127
|
+
## Project Status and Contributions
|
|
128
|
+
|
|
129
|
+
Birder CLIP is an early alpha project. APIs, model names, checkpoints, training recipes and command-line options may change without notice.
|
|
130
|
+
|
|
131
|
+
This is currently a personal project in active development. Suggestions, bug reports and feedback are welcome through the project's issue tracker, but the project is not yet stable enough for broad external contributions.
|
|
132
|
+
|
|
133
|
+
## Licenses
|
|
134
|
+
|
|
135
|
+
The code in this project is primarily licensed under Apache 2.0. See [LICENSE](LICENSE) for details.
|
|
136
|
+
|
|
137
|
+
Some model implementations, pretrained weights, tokenizers and converted artifacts may be derived from or depend on projects and datasets with their own licenses and usage restrictions.
|
|
138
|
+
|
|
139
|
+
**You are responsible for ensuring compliance with all applicable licenses, including the conditions of the licenses of any dependencies.**
|
|
140
|
+
|
|
141
|
+
### Disclaimer
|
|
142
|
+
|
|
143
|
+
If you intend to use Birder CLIP, its pretrained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
|
|
144
|
+
|
|
145
|
+
It's the user's responsibility to ensure that their use of this project, including any pretrained weights or datasets, complies with all applicable licenses and legal requirements.
|
|
146
|
+
|
|
147
|
+
## Acknowledgments
|
|
148
|
+
|
|
149
|
+
Birder CLIP owes much to the work of others in computer vision, image-text representation learning and open-source machine learning.
|
|
150
|
+
|
|
151
|
+
Special thanks to the [OpenCLIP](https://github.com/mlfoundations/open_clip) project, which serves as the main reference implementation and inspiration for much of the CLIP-style modeling and training work here. As with Birder, this project stands on the shoulders of many open-source projects, papers and datasets. If an attribution is missing, please open an issue.
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# Birder CLIP
|
|
2
|
+
|
|
3
|
+
Birder CLIP is an early-stage Birder extension for CLIP-style image-text models, focused on practical inference and fine-tuning workflows.
|
|
4
|
+
|
|
5
|
+
- [Introduction](#introduction)
|
|
6
|
+
- [Setup](#setup)
|
|
7
|
+
- [Getting Started](#getting-started)
|
|
8
|
+
- [Training](#training)
|
|
9
|
+
- [Project Status and Contributions](#project-status-and-contributions)
|
|
10
|
+
- [Licenses](#licenses)
|
|
11
|
+
- [Acknowledgments](#acknowledgments)
|
|
12
|
+
|
|
13
|
+
## Introduction
|
|
14
|
+
|
|
15
|
+
Birder CLIP extends [Birder](https://gitlab.com/birder/birder) with image-text models for zero-shot classification, image-text retrieval-style workflows, caption generation and related multimodal computer vision tasks.
|
|
16
|
+
|
|
17
|
+
The project is aimed at image-text modeling rather than general vision-language model (VLM) chat or instruction-following systems.
|
|
18
|
+
It currently includes CLIP-style components, tokenizers, model registry utilities, inference scripts and training code.
|
|
19
|
+
Full training is supported, but for large-scale CLIP pretraining you are probably better served by [OpenCLIP](https://github.com/mlfoundations/open_clip).
|
|
20
|
+
|
|
21
|
+
## Setup
|
|
22
|
+
|
|
23
|
+
1. Ensure your environment meets the minimum requirements:
|
|
24
|
+
- Python 3.11 or newer
|
|
25
|
+
- PyTorch 2.10 or newer (installed for your hardware/driver stack)
|
|
26
|
+
- Birder 0.6.0 or newer
|
|
27
|
+
|
|
28
|
+
1. Install the latest Birder CLIP version:
|
|
29
|
+
|
|
30
|
+
```sh
|
|
31
|
+
pip install birder-clip
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
## Getting Started
|
|
35
|
+
|
|
36
|
+
List available image-text models:
|
|
37
|
+
|
|
38
|
+
```sh
|
|
39
|
+
python -m birder_clip.tools list-models --image-text
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
List available pretrained weights:
|
|
43
|
+
|
|
44
|
+
```sh
|
|
45
|
+
python -m birder_clip.tools list-models --pretrained --verbose
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
Run zero-shot classification on a directory of images:
|
|
49
|
+
|
|
50
|
+
```sh
|
|
51
|
+
python -m birder_clip.scripts.zero_shot -n laion_clip_vit_l14 --classes eagle hawk falcon --template-set default --gpu data/images
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
For detailed options, run:
|
|
55
|
+
|
|
56
|
+
```sh
|
|
57
|
+
python -m birder_clip.scripts.zero_shot --help
|
|
58
|
+
python -m birder_clip.tools --help
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Training
|
|
62
|
+
|
|
63
|
+
Birder CLIP includes training support for image-text datasets in CSV and WebDataset formats, including CLIP, CoCa and LiT-style workflows.
|
|
64
|
+
|
|
65
|
+
## Project Status and Contributions
|
|
66
|
+
|
|
67
|
+
Birder CLIP is an early alpha project. APIs, model names, checkpoints, training recipes and command-line options may change without notice.
|
|
68
|
+
|
|
69
|
+
This is currently a personal project in active development. Suggestions, bug reports and feedback are welcome through the project's issue tracker, but the project is not yet stable enough for broad external contributions.
|
|
70
|
+
|
|
71
|
+
## Licenses
|
|
72
|
+
|
|
73
|
+
The code in this project is primarily licensed under Apache 2.0. See [LICENSE](LICENSE) for details.
|
|
74
|
+
|
|
75
|
+
Some model implementations, pretrained weights, tokenizers and converted artifacts may be derived from or depend on projects and datasets with their own licenses and usage restrictions.
|
|
76
|
+
|
|
77
|
+
**You are responsible for ensuring compliance with all applicable licenses, including the conditions of the licenses of any dependencies.**
|
|
78
|
+
|
|
79
|
+
### Disclaimer
|
|
80
|
+
|
|
81
|
+
If you intend to use Birder CLIP, its pretrained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
|
|
82
|
+
|
|
83
|
+
It's the user's responsibility to ensure that their use of this project, including any pretrained weights or datasets, complies with all applicable licenses and legal requirements.
|
|
84
|
+
|
|
85
|
+
## Acknowledgments
|
|
86
|
+
|
|
87
|
+
Birder CLIP owes much to the work of others in computer vision, image-text representation learning and open-source machine learning.
|
|
88
|
+
|
|
89
|
+
Special thanks to the [OpenCLIP](https://github.com/mlfoundations/open_clip) project, which serves as the main reference implementation and inspiration for much of the CLIP-style modeling and training work here. As with Birder, this project stands on the shoulders of many open-source projects, papers and datasets. If an attribution is missing, please open an issue.
|
|
@@ -43,7 +43,7 @@ def get_image_text_network_name(
|
|
|
43
43
|
parts = [network]
|
|
44
44
|
if image_encoder is not None:
|
|
45
45
|
parts.append(image_encoder)
|
|
46
|
-
if text_encoder is not None and text_encoder != "
|
|
46
|
+
if text_encoder is not None and text_encoder != "transformer_encoder":
|
|
47
47
|
parts.append(text_encoder)
|
|
48
48
|
|
|
49
49
|
if registry.exists(network) is True:
|
|
@@ -21,6 +21,11 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
|
21
21
|
parser.add_argument("-n", "--network", type=str, help="the image-text network to train")
|
|
22
22
|
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
23
23
|
parser.add_argument("--image-encoder", type=str, help="the image encoder to use")
|
|
24
|
+
parser.add_argument(
|
|
25
|
+
"--image-encoder-pretrained",
|
|
26
|
+
type=str,
|
|
27
|
+
help="pretrained Birder image model weights path to load into the image encoder",
|
|
28
|
+
)
|
|
24
29
|
parser.add_argument("--text-encoder", type=str, help="the text encoder to use")
|
|
25
30
|
parser.add_argument("--embed-dim", type=int, metavar="N", help="shared image-text embedding dimension")
|
|
26
31
|
parser.add_argument("--tokenizer", type=str, help="the tokenizer to use")
|
|
@@ -43,7 +48,23 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
|
43
48
|
|
|
44
49
|
def add_loss_args(parser: argparse.ArgumentParser) -> None:
    """
    Register loss-selection arguments on *parser*.

    Adds a "Loss parameters" group holding ``--loss`` (``clip`` or ``coca``,
    default ``clip``) and one float scale factor, default 1.0, for each CoCa
    loss term. Negative weights are not rejected here; common_args_validation
    performs that check.
    """
    loss_group = parser.add_argument_group("Loss parameters")
    loss_group.add_argument("--loss", type=str, choices=["clip", "coca"], default="clip", help="loss function to use")

    # CoCa mixes a caption loss with a contrastive loss; each term gets its own weight flag.
    # Registration order is kept stable so --help output does not change.
    coca_weight_flags = (
        ("--coca-caption-loss-weight", "weight assigned to CoCa caption loss"),
        ("--coca-contrastive-loss-weight", "weight assigned to CoCa contrastive loss"),
    )
    for flag, description in coca_weight_flags:
        loss_group.add_argument(flag, type=float, default=1.0, help=description)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def add_freeze_args(parser: argparse.ArgumentParser) -> None:
    """
    Register parameter-freezing switches on *parser*.

    Adds a "Freeze parameters" group with a single boolean flag,
    ``--freeze-image-encoder`` (default False).
    """
    freeze_group = parser.add_argument_group("Freeze parameters")
    boolean_flag = {"action": "store_true", "default": False}
    freeze_group.add_argument(
        "--freeze-image-encoder",
        help="freeze image encoder body, leaving the projection head trainable",
        **boolean_flag,
    )
|
|
47
68
|
|
|
48
69
|
|
|
49
70
|
def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: int = 32) -> None:
|
|
@@ -66,7 +87,12 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
|
|
|
66
87
|
metavar="N",
|
|
67
88
|
help="number of iterations to accumulate gradients per optimizer step",
|
|
68
89
|
)
|
|
69
|
-
|
|
90
|
+
group.add_argument(
|
|
91
|
+
"--grad-accum-cache-negatives",
|
|
92
|
+
default=False,
|
|
93
|
+
action="store_true",
|
|
94
|
+
help="cache features so CLIP loss uses all accumulated microbatches as negatives",
|
|
95
|
+
)
|
|
70
96
|
|
|
71
97
|
|
|
72
98
|
def add_lr_wd_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -250,6 +276,8 @@ def add_data_aug_args(
|
|
|
250
276
|
)
|
|
251
277
|
group.add_argument("--ra-magnitude", type=int, default=9, help="magnitude for all the RandAugment transformations")
|
|
252
278
|
group.add_argument("--augmix-severity", type=int, default=3, help="severity of AugMix policy")
|
|
279
|
+
group.add_argument("--clip-color-jitter-prob", type=float, default=0.0, help="CLIP color jitter probability")
|
|
280
|
+
group.add_argument("--clip-gray-prob", type=float, default=0.0, help="CLIP grayscale probability")
|
|
253
281
|
group.add_argument("--resize-min-scale", type=float, default=default_min_scale, help="random resize min scale")
|
|
254
282
|
group.add_argument(
|
|
255
283
|
"--re-prob",
|
|
@@ -356,6 +384,29 @@ def add_precision_args(parser: argparse.ArgumentParser) -> None:
|
|
|
356
384
|
)
|
|
357
385
|
|
|
358
386
|
|
|
387
|
+
def add_grad_checkpointing_args(parser: argparse.ArgumentParser) -> None:
    """
    Register gradient checkpointing options on *parser*.

    Adds a "Gradient checkpointing parameters" group with:

    - ``--grad-checkpointing``: boolean switch, default False.
    - ``--grad-checkpointing-segments``: optional int (metavar ``N``), None when omitted.
    - ``--no-grad-checkpointing-preserve-rng-state``: negative flag that stores False into
      ``grad_checkpointing_preserve_rng_state``, which defaults to True.

    Cross-option consistency is not checked here.
    """
    ckpt_group = parser.add_argument_group("Gradient checkpointing parameters")

    # (flag, add_argument keyword options), registered in this exact order
    option_specs = (
        (
            "--grad-checkpointing",
            {
                "default": False,
                "action": "store_true",
                "help": "enable gradient checkpointing for supported models",
            },
        ),
        (
            "--grad-checkpointing-segments",
            {
                "type": int,
                "metavar": "N",
                "help": "number of checkpoint segments to request from supported models",
            },
        ),
        (
            # Negative flag: RNG state preservation stays on unless explicitly disabled
            "--no-grad-checkpointing-preserve-rng-state",
            {
                "dest": "grad_checkpointing_preserve_rng_state",
                "default": True,
                "action": "store_false",
                "help": "disable RNG state preservation during gradient checkpointing recomputation",
            },
        ),
    )
    for flag, options in option_specs:
        ckpt_group.add_argument(flag, **options)
|
|
408
|
+
|
|
409
|
+
|
|
359
410
|
def add_compile_args(parser: argparse.ArgumentParser) -> None:
|
|
360
411
|
group = parser.add_argument_group("Compilation parameters")
|
|
361
412
|
group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
|
|
@@ -578,8 +629,12 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
578
629
|
raise cli.ValidationError("--load-states requires --resume-epoch to be set")
|
|
579
630
|
if args.load_scheduler is True and args.resume_epoch is None:
|
|
580
631
|
raise cli.ValidationError("--load-scheduler requires --resume-epoch to be set")
|
|
581
|
-
if
|
|
632
|
+
if args.pretrained is True and args.resume_epoch is not None:
|
|
582
633
|
raise cli.ValidationError("--pretrained cannot be used with --resume-epoch")
|
|
634
|
+
if args.image_encoder_pretrained is not None and args.resume_epoch is not None:
|
|
635
|
+
raise cli.ValidationError("--image-encoder-pretrained cannot be used with --resume-epoch")
|
|
636
|
+
if args.pretrained is True and args.image_encoder_pretrained is not None:
|
|
637
|
+
raise cli.ValidationError("--image-encoder-pretrained cannot be used with --pretrained")
|
|
583
638
|
|
|
584
639
|
if args.freeze_bn is True and args.sync_bn is True:
|
|
585
640
|
raise cli.ValidationError("--freeze-bn cannot be used with --sync-bn")
|
|
@@ -605,10 +660,29 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
605
660
|
raise cli.ValidationError("--context-length must be positive")
|
|
606
661
|
if args.grad_accum_steps < 1:
|
|
607
662
|
raise cli.ValidationError("--grad-accum-steps must be >= 1")
|
|
663
|
+
if args.grad_accum_cache_negatives is True and args.grad_accum_steps == 1:
|
|
664
|
+
raise cli.ValidationError("--grad-accum-cache-negatives requires --grad-accum-steps greater than 1")
|
|
665
|
+
if args.grad_accum_cache_negatives is True and args.loss == "coca":
|
|
666
|
+
raise cli.ValidationError("--grad-accum-cache-negatives is only supported with --loss clip")
|
|
667
|
+
|
|
668
|
+
if args.coca_caption_loss_weight < 0.0:
|
|
669
|
+
raise cli.ValidationError("--coca-caption-loss-weight must be non-negative")
|
|
670
|
+
if args.coca_contrastive_loss_weight < 0.0:
|
|
671
|
+
raise cli.ValidationError("--coca-contrastive-loss-weight must be non-negative")
|
|
672
|
+
|
|
673
|
+
# EMA
|
|
608
674
|
if args.model_ema_steps < 1:
|
|
609
675
|
raise cli.ValidationError("--model-ema-steps must be >= 1")
|
|
610
676
|
|
|
677
|
+
# Gradient checkpointing args
|
|
678
|
+
if args.grad_checkpointing_segments is not None and args.grad_checkpointing_segments < 1:
|
|
679
|
+
raise cli.ValidationError("--grad-checkpointing-segments must be >= 1")
|
|
680
|
+
if args.grad_checkpointing_segments is not None and args.grad_checkpointing is False:
|
|
681
|
+
raise cli.ValidationError("--grad-checkpointing-segments requires --grad-checkpointing")
|
|
682
|
+
|
|
611
683
|
if args.distributed_mode == "fsdp":
|
|
684
|
+
if args.grad_checkpointing is True:
|
|
685
|
+
raise cli.ValidationError("--grad-checkpointing cannot be used with --distributed-mode fsdp")
|
|
612
686
|
if args.sync_bn is True:
|
|
613
687
|
raise cli.ValidationError("--sync-bn cannot be used with --distributed-mode fsdp")
|
|
614
688
|
if args.find_unused_parameters is True:
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Inference-optimized multi-GPU parallelization for image-text models
|
|
3
|
+
|
|
4
|
+
This module provides ZeroShotInferenceDataParallel, a CLIP-style zero-shot
|
|
5
|
+
specialization of Birder's InferenceDataParallel.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from birder.inference.data_parallel import InferenceDataParallel
|
|
12
|
+
|
|
13
|
+
from birder_clip.inference.zero_shot import ZeroShotInference
|
|
14
|
+
from birder_clip.net.base import BaseNet
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ZeroShotInferenceDataParallel(InferenceDataParallel):
|
|
18
|
+
"""
|
|
19
|
+
Distributes zero-shot image inference batches across multiple GPUs
|
|
20
|
+
|
|
21
|
+
This wrapper scatters the image batch across devices and keeps a replicated
|
|
22
|
+
copy of the zero-shot text embeddings on each device. Each replica computes
|
|
23
|
+
image embeddings and zero-shot logits locally before outputs are gathered.
|
|
24
|
+
|
|
25
|
+
Important
|
|
26
|
+
---------
|
|
27
|
+
This class assumes the model is already configured for inference mode
|
|
28
|
+
(i.e., loaded with inference=True in load_model or manually set to eval mode
|
|
29
|
+
with requires_grad=False on all parameters).
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
module: BaseNet,
|
|
35
|
+
text_embeddings: torch.Tensor,
|
|
36
|
+
device_ids: Optional[list[int]] = None,
|
|
37
|
+
output_device: Optional[int | str | torch.device] = None,
|
|
38
|
+
compile_replicas: bool = False,
|
|
39
|
+
compile_methods: Optional[list[str]] = None,
|
|
40
|
+
compile_mode: Optional[str] = None,
|
|
41
|
+
) -> None:
|
|
42
|
+
if compile_methods is None:
|
|
43
|
+
compile_methods = ["encode_image", "forward_logits"]
|
|
44
|
+
|
|
45
|
+
super().__init__(
|
|
46
|
+
module,
|
|
47
|
+
device_ids=device_ids,
|
|
48
|
+
output_device=output_device,
|
|
49
|
+
compile_replicas=compile_replicas,
|
|
50
|
+
compile_methods=compile_methods,
|
|
51
|
+
compile_mode=compile_mode,
|
|
52
|
+
)
|
|
53
|
+
self.set_text_embeddings(text_embeddings)
|
|
54
|
+
|
|
55
|
+
def set_text_embeddings(self, text_embeddings: torch.Tensor) -> None:
|
|
56
|
+
if text_embeddings.ndim != 2:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"text_embeddings must be a 2D tensor of shape (num_classes, embedding_size), "
|
|
59
|
+
f"got shape {text_embeddings.size()}"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
self.text_embeddings = [
|
|
63
|
+
text_embeddings.to(f"cuda:{device_id}", non_blocking=True) for device_id in self.device_ids
|
|
64
|
+
]
|
|
65
|
+
self.inference_modules = [
|
|
66
|
+
ZeroShotInference(replica, embeddings) for replica, embeddings in zip(self.replicas, self.text_embeddings)
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
def forward( # type: ignore[override] # pylint: disable=arguments-differ
|
|
70
|
+
self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False
|
|
71
|
+
) -> torch.Tensor:
|
|
72
|
+
"""
|
|
73
|
+
Run zero-shot inference distributed across GPUs
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
inputs
|
|
78
|
+
Input image batch to process.
|
|
79
|
+
tta
|
|
80
|
+
Run inference with oversampling.
|
|
81
|
+
return_logits
|
|
82
|
+
If True, return raw logits instead of probabilities after softmax.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
if len(self.device_ids) == 1:
|
|
86
|
+
output = self.inference_modules[0](
|
|
87
|
+
inputs,
|
|
88
|
+
tta=tta,
|
|
89
|
+
return_logits=return_logits,
|
|
90
|
+
)
|
|
91
|
+
return self._gather([output])
|
|
92
|
+
|
|
93
|
+
scattered = self._scatter(inputs, {})
|
|
94
|
+
|
|
95
|
+
outputs = []
|
|
96
|
+
for inference, (input_chunk, _), device_id in zip(self.inference_modules, scattered, self.device_ids):
|
|
97
|
+
if input_chunk is not None and input_chunk.size(0) > 0:
|
|
98
|
+
with torch.cuda.device(device_id):
|
|
99
|
+
output = inference(
|
|
100
|
+
input_chunk,
|
|
101
|
+
tta=tta,
|
|
102
|
+
return_logits=return_logits,
|
|
103
|
+
)
|
|
104
|
+
outputs.append(output)
|
|
105
|
+
else:
|
|
106
|
+
outputs.append(None)
|
|
107
|
+
|
|
108
|
+
return self._gather(outputs)
|
|
109
|
+
|
|
110
|
+
def __repr__(self) -> str:
    """Describe the device layout and the text-embedding shape held by this wrapper."""
    body_lines = [
        f" devices={self.device_ids},\n",
        f" output_device={self.output_device},\n",
        f" src_device={self.src_device},\n",
        f" text_embeddings_shape={tuple(self.text_embeddings[0].shape)}\n",
    ]
    return "ZeroShotInferenceDataParallel(\n" + "".join(body_lines) + ")"
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import numpy.typing as npt
|
|
7
|
+
import torch
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from birder_clip.net.base import BaseNet
|
|
12
|
+
|
|
13
|
+
DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32]]


def infer_dataloader_iter(
    device: torch.device,
    net: BaseNet,
    dataloader: DataLoader,
    model_dtype: torch.dtype = torch.float32,
    amp: bool = False,
    amp_dtype: Optional[torch.dtype] = None,
    num_samples: Optional[int] = None,
    chunk_size: Optional[float] = None,
) -> Iterator[DataloaderInferenceResult]:
    """
    Compute image embeddings for every batch of a dataloader and yield them in chunks

    Parameters
    ----------
    device
        Device to run inference on. ``net`` is moved there (in place) together with ``model_dtype``.
    net
        Model exposing ``encode_image(inputs, normalize=True)``.
    dataloader
        Iterable of ``(file_paths, inputs, targets)`` batches; targets are ignored.
    model_dtype
        Dtype applied to both the model and the input batches.
    amp
        Enable ``torch.amp.autocast`` for the forward pass.
    amp_dtype
        Autocast dtype (only meaningful when ``amp`` is True).
    num_samples
        Total sample count, used only for the progress bar.
    chunk_size
        Minimum number of samples accumulated before a chunk is yielded. A chunk may exceed this by
        up to one batch. None means everything is yielded as a single chunk at the end.

    Yields
    ------
    Tuples of (sample_paths, embeddings), where embeddings is a float32 array concatenated along
    axis 0 in the same order as sample_paths.
    """

    if chunk_size is None:
        chunk_size = float("inf")

    net.to(device, dtype=model_dtype)
    embeddings_list: list[npt.NDArray[np.float32]] = []
    sample_paths: list[str] = []
    sample_count = 0
    with tqdm(total=num_samples, initial=0, unit="images", unit_scale=True, leave=False) as progress:
        for file_paths, inputs, _targets in dataloader:
            batch_size = inputs.size(0)

            # Inference. Autograd is disabled locally: gradients are never needed here, and
            # Tensor.numpy() refuses tensors that require grad. Keeping no_grad scoped to the
            # forward pass avoids leaking a grad-mode change to the caller across `yield`.
            inputs = inputs.to(device, dtype=model_dtype)
            with torch.no_grad(), torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
                batch_embeddings = net.encode_image(inputs, normalize=True)

            embeddings_list.append(batch_embeddings.cpu().float().numpy())

            # Set sample list
            sample_paths.extend(file_paths)

            # Update progress bar
            progress.update(n=batch_size)

            # Yield results when we reach chunk_size
            sample_count += batch_size
            if sample_count >= chunk_size:
                # Suspend the progress bar while the consumer handles the chunk (it may write to the terminal)
                with tqdm.external_write_mode(file=sys.stderr):
                    yield (sample_paths, np.concatenate(embeddings_list, axis=0))

                # Reset for next chunk
                embeddings_list = []
                sample_paths = []
                sample_count = 0

    # Flush the trailing partial chunk (if any)
    if embeddings_list:
        yield (sample_paths, np.concatenate(embeddings_list, axis=0))
|
|
@@ -13,18 +13,57 @@ from collections.abc import Callable
|
|
|
13
13
|
from collections.abc import Iterator
|
|
14
14
|
from collections.abc import Sequence
|
|
15
15
|
from typing import Optional
|
|
16
|
+
from typing import Protocol
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
import numpy.typing as npt
|
|
19
20
|
import torch
|
|
20
21
|
import torch.nn.functional as F
|
|
21
22
|
from torch.utils.data import DataLoader
|
|
23
|
+
from torchvision.transforms import v2
|
|
24
|
+
from torchvision.transforms.v2.functional import five_crop
|
|
22
25
|
from tqdm import tqdm
|
|
23
26
|
|
|
24
27
|
from birder_clip.net.base import BaseNet
|
|
25
28
|
from birder_clip.tokenizers.base import Tokenizer
|
|
26
29
|
|
|
27
30
|
|
|
31
|
+
class ZeroShotInferenceModule(Protocol):
    """
    Structural type for callables that score an image batch against class text embeddings

    Implementations accept an input batch plus keyword-only ``tta`` and ``return_logits`` flags and
    return a tensor of scores.
    """

    def __call__(
        self,
        inputs: torch.Tensor,
        *,
        tta: bool = False,
        return_logits: bool = False,
    ) -> torch.Tensor:
        """Score ``inputs``; see the implementing class for the exact semantics."""
        ...
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ZeroShotInference:
    """
    Score image batches against a fixed set of class text embeddings

    Parameters
    ----------
    net
        Model exposing ``encode_image(inputs, normalize=True)`` and
        ``forward_logits(image_embeddings, text_embeddings)``.
    text_embeddings
        Class text embeddings. Inputs are moved to this tensor's device before inference.
    tta_crop_ratio
        Fraction of the input height/width used for each of the five TTA crops. Defaults to 0.8.

    Raises
    ------
    ValueError
        If ``tta_crop_ratio`` is not in (0, 1].
    """

    def __init__(self, net: BaseNet, text_embeddings: torch.Tensor, tta_crop_ratio: float = 0.8) -> None:
        if not 0.0 < tta_crop_ratio <= 1.0:
            raise ValueError(f"tta_crop_ratio must be in (0, 1], got {tta_crop_ratio}")

        self.net = net
        self.text_embeddings = text_embeddings
        self.tta_crop_ratio = tta_crop_ratio

    def _score(self, inputs: torch.Tensor, return_logits: bool) -> torch.Tensor:
        # Encode one batch and return either raw logits or softmax probabilities over classes
        image_embeddings = self.net.encode_image(inputs, normalize=True)
        logits = self.net.forward_logits(image_embeddings, self.text_embeddings)
        if return_logits is True:
            return logits

        return F.softmax(logits, dim=-1)

    def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor:
        """
        Run zero-shot inference on a batch

        Parameters
        ----------
        inputs
            Image batch, unpacked as (N, C, H, W).
        tta
            If True, average the outputs over five crops (four corners and center), each resized
            back to (H, W) with antialiased bicubic interpolation.
        return_logits
            If True, return raw logits instead of probabilities after softmax. With TTA, the
            per-crop logits (or probabilities) are averaged.
        """

        inputs = inputs.to(self.text_embeddings.device, non_blocking=True)
        if tta is not True:
            return self._score(inputs, return_logits)

        _, _, height, width = inputs.size()
        crop_h = int(height * self.tta_crop_ratio)
        crop_w = int(width * self.tta_crop_ratio)
        resize = v2.Resize((height, width), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
        outs = [self._score(resize(crop), return_logits) for crop in five_crop(inputs, size=[crop_h, crop_w])]

        return torch.stack(outs).mean(dim=0)
|
|
65
|
+
|
|
66
|
+
|
|
28
67
|
def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
    """
    Expand every template with every class name

    The result is grouped by class: all templates for the first class come first, then all
    templates for the second class, and so on. Each template is filled via ``str.format``.
    """

    prompts: list[str] = []
    for name in class_names:
        prompts.extend(tpl.format(name) for tpl in templates)

    return prompts
|
|
30
69
|
|
|
@@ -66,9 +105,9 @@ DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArra
|
|
|
66
105
|
|
|
67
106
|
def infer_dataloader_iter(
|
|
68
107
|
device: torch.device,
|
|
69
|
-
net:
|
|
108
|
+
net: ZeroShotInferenceModule,
|
|
70
109
|
dataloader: DataLoader,
|
|
71
|
-
|
|
110
|
+
tta: bool = False,
|
|
72
111
|
return_logits: bool = False,
|
|
73
112
|
model_dtype: torch.dtype = torch.float32,
|
|
74
113
|
amp: bool = False,
|
|
@@ -80,7 +119,6 @@ def infer_dataloader_iter(
|
|
|
80
119
|
if chunk_size is None:
|
|
81
120
|
chunk_size = float("inf")
|
|
82
121
|
|
|
83
|
-
net.to(device, dtype=model_dtype)
|
|
84
122
|
out_list: list[npt.NDArray[np.float32]] = []
|
|
85
123
|
labels_list: list[npt.NDArray[np.int64]] = []
|
|
86
124
|
sample_paths: list[str] = []
|
|
@@ -90,14 +128,10 @@ def infer_dataloader_iter(
|
|
|
90
128
|
batch_size = inputs.size(0)
|
|
91
129
|
|
|
92
130
|
# Inference
|
|
93
|
-
inputs = inputs.to(
|
|
131
|
+
inputs = inputs.to(dtype=model_dtype)
|
|
94
132
|
with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if return_logits is True:
|
|
98
|
-
out = logits.cpu().float().numpy()
|
|
99
|
-
else:
|
|
100
|
-
out = F.softmax(logits, dim=-1).cpu().float().numpy()
|
|
133
|
+
out = net(inputs, return_logits=return_logits, tta=tta)
|
|
134
|
+
out = out.cpu().float().numpy()
|
|
101
135
|
|
|
102
136
|
out_list.append(out)
|
|
103
137
|
|