birder-clip 0.0.2.dev1__tar.gz → 0.0.2.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/PKG-INFO +1 -1
- birder_clip-0.0.2.dev2/birder_clip/common/training_cli.py +450 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/conf/settings.py +5 -0
- birder_clip-0.0.2.dev2/birder_clip/data/datasets/csv.py +90 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/inference/zero_shot.py +6 -2
- birder_clip-0.0.2.dev2/birder_clip/net/clip.py +238 -0
- birder_clip-0.0.2.dev2/birder_clip/py.typed +0 -0
- birder_clip-0.0.2.dev2/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/scripts/zero_shot.py +233 -11
- birder_clip-0.0.2.dev2/birder_clip/tokenizers/hf.py +43 -0
- birder_clip-0.0.2.dev2/birder_clip/tokenizers/registry.py +53 -0
- birder_clip-0.0.2.dev2/birder_clip/tools/__init__.py +0 -0
- birder_clip-0.0.2.dev2/birder_clip/tools/__main__.py +30 -0
- birder_clip-0.0.2.dev2/birder_clip/tools/download_tokenizer.py +82 -0
- birder_clip-0.0.2.dev2/birder_clip/tools/show_iterator.py +172 -0
- birder_clip-0.0.2.dev2/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip.egg-info/PKG-INFO +1 -1
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip.egg-info/SOURCES.txt +9 -0
- birder_clip-0.0.2.dev1/birder_clip/net/clip.py +0 -91
- birder_clip-0.0.2.dev1/birder_clip/tokenizers/registry.py +0 -29
- birder_clip-0.0.2.dev1/birder_clip/version.py +0 -1
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/LICENSE +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/README.md +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/common/fs_ops.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/common/lib.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev1/birder_clip/inference → birder_clip-0.0.2.dev2/birder_clip/data}/__init__.py +0 -0
- {birder_clip-0.0.2.dev1/birder_clip/scripts → birder_clip-0.0.2.dev2/birder_clip/data/datasets}/__init__.py +0 -0
- /birder_clip-0.0.2.dev1/birder_clip/py.typed → /birder_clip-0.0.2.dev2/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/model_registry/model_registry.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/net/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/net/base.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/net/text/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/net/text/base.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/net/text/transformer.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/tokenizers/__init__.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/tokenizers/base.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip/tokenizers/openai_clip_bpe.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip.egg-info/requires.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/requirements/_requirements-dev.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/requirements/requirements.txt +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/tests/test_common.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/tests/test_model_registry.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/tests/test_net.py +0 -0
- {birder_clip-0.0.2.dev1 → birder_clip-0.0.2.dev2}/tests/test_tokenizers.py +0 -0
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
from typing import get_args
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from birder.common import cli
|
|
7
|
+
from birder.common.training_utils import OptimizerType
|
|
8
|
+
from birder.common.training_utils import SchedulerType
|
|
9
|
+
from birder.conf import settings
|
|
10
|
+
|
|
11
|
+
from birder_clip.model_registry import Task
|
|
12
|
+
from birder_clip.model_registry import registry
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def add_model_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the model selection and model configuration options on the parser

    The plain selectors (network, tag, encoders, embedding dimension, tokenizer) are
    added first, followed by the three configuration override options.
    """

    register = parser.add_argument

    register("-n", "--network", type=str, help="the image-text network to train")
    register("-t", "--tag", type=str, help="add model tag")
    register("--image-encoder", type=str, help="the image encoder to use")
    register("--text-encoder", type=str, help="the text encoder to use")
    register("--embed-dim", type=int, metavar="N", help="shared image-text embedding dimension")
    register("--tokenizer", type=str, help="the tokenizer to use")

    # The override options differ only in the flag name and the subject named in the help text
    config_overrides = (
        ("--model-config", "override the model default configuration, accepts key-value pairs or JSON"),
        ("--image-encoder-config", "override the image encoder configuration, accepts key-value pairs or JSON"),
        ("--text-encoder-config", "override the text encoder configuration, accepts key-value pairs or JSON"),
    )
    for flag, description in config_overrides:
        register(flag, action=cli.FlexibleDictAction, help=description)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def add_loss_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the boolean loss switches under the "Loss parameters" group

    Both switches are off unless given on the command line.
    """

    loss_group = parser.add_argument_group("Loss parameters")
    switches = (
        ("--local-loss", "calculate loss with local features against gathered global features"),
        ("--gather-with-grad", "enable gradient-preserving distributed feature gather"),
    )
    for flag, description in switches:
        loss_group.add_argument(flag, action="store_true", default=False, help=description)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: int = 64) -> None:
    """
    Register the optimizer related options under the "Optimization parameters" group

    Parameters
    ----------
    parser
        Parser to extend.
    default_batch_size
        Default value used for --batch-size.
    """

    opt_group = parser.add_argument_group("Optimization parameters")
    add = opt_group.add_argument

    # Valid optimizer names are taken from the members of the OptimizerType alias
    optimizer_names = list(get_args(OptimizerType))

    add("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
    add("--opt", type=str, choices=optimizer_names, default="adamw", help="optimizer to use")
    add("--opt-fused", action="store_true", default=False, help="use fused optimizer implementation")
    add("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
    add("--nesterov", action="store_true", default=False, help="use nesterov momentum")

    # Optional hyper-parameters, left as None when not supplied
    add("--opt-eps", type=float, help="optimizer epsilon (None to use the optimizer default)")
    add("--opt-betas", type=float, nargs="+", help="optimizer betas (None to use the optimizer default)")
    add("--opt-alpha", type=float, help="optimizer alpha (None to use the optimizer default)")
    add("--clip-grad-norm", type=float, help="the maximum gradient norm")

    add(
        "--grad-accum-steps",
        type=int,
        default=1,
        metavar="N",
        help="number of iterations to accumulate gradients per optimizer step",
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def add_lr_wd_args(parser: argparse.ArgumentParser) -> None:
    """
    Register learning rate, weight decay and layer-wise scaling options

    All options are placed in the "Learning rate and regularization parameters" group.
    """

    lr_group = parser.add_argument_group("Learning rate and regularization parameters")
    add = lr_group.add_argument

    # Learning rate
    add("--lr", type=float, default=5.0e-4, metavar="LR", help="base learning rate")
    add("--bias-lr", type=float, metavar="LR", help="learning rate of biases")
    add("--lr-scale", type=int, help="reference batch size for LR scaling, if provided, LR will be scaled accordingly")
    add("--lr-scale-type", type=str, choices=["linear", "sqrt"], default="linear", help="learning rate scaling type")

    # Weight decay
    add("--wd", type=float, default=0.2, metavar="WD", help="weight decay")
    optional_decays = (
        ("--norm-wd", "weight decay for Normalization layers"),
        ("--bias-weight-decay", "weight decay for bias parameters of all layers"),
        ("--transformer-embedding-decay", "weight decay for embedding parameters for vision transformer models"),
    )
    for flag, description in optional_decays:
        add(flag, type=float, metavar="WD", help=description)

    add(
        "--custom-layer-wd",
        action=cli.FlexibleDictAction,
        metavar="LAYER=WD",
        help="custom weight decay for specific layers by name (e.g., logit_scale=0.0)",
    )

    # Layer-wise learning rate decay
    layer_decay_options = (
        ("--layer-decay", "layer-wise learning rate decay (LLRD)"),
        ("--layer-decay-min-scale", "minimum layer scale factor clamp value"),
        ("--layer-decay-no-opt-scale", "layer scale threshold below which parameters are frozen"),
    )
    for flag, description in layer_decay_options:
        add(flag, type=float, help=description)

    add(
        "--custom-layer-lr-scale",
        action=cli.FlexibleDictAction,
        metavar="LAYER=SCALE",
        help="custom lr_scale for specific layers by name (e.g., image_encoder=0.5,text_encoder=1.0)",
    )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the learning rate scheduler options

    All options are placed in the "Learning rate scheduler parameters" group.
    """

    sched_group = parser.add_argument_group("Learning rate scheduler parameters")
    add = sched_group.add_argument

    # Valid scheduler names are taken from the members of the SchedulerType alias
    scheduler_names = list(get_args(SchedulerType))

    add(
        "--lr-scheduler-update",
        type=str,
        choices=["epoch", "step"],
        default="step",
        help="when to apply learning rate scheduler update: epoch (once per epoch), step (each optimizer step)",
    )
    add("--lr-scheduler", type=str, choices=scheduler_names, default="cosine", help="learning rate scheduler")
    add(
        "--lr-step-size",
        type=int,
        default=40,
        metavar="N",
        help="decrease lr every N epochs/steps (relative to after warmup, step scheduler only)",
    )
    add(
        "--lr-steps",
        type=int,
        nargs="+",
        help="absolute epoch/step milestones when to decrease lr (multistep scheduler only)",
    )

    # Float valued scheduler knobs: (flag, default, help)
    float_options = (
        ("--lr-step-gamma", 0.75, "multiplicative factor of learning rate decay (for step scheduler only)"),
        ("--lr-cosine-min", 0.0, "minimum learning rate (for cosine annealing scheduler only)"),
        ("--lr-power", 1.0, "power of the polynomial (for polynomial scheduler only)"),
        ("--lr-warmup-decay", 0.01, "multiplicative factor for learning rate at the start of warmup"),
    )
    for flag, default_value, description in float_options:
        add(flag, type=float, default=default_value, help=description)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 32) -> None:
    """
    Register the options that describe the length and phases of a training run

    Parameters
    ----------
    parser
        Parser to extend.
    default_epochs
        Default value used for --epochs.
    """

    schedule_group = parser.add_argument_group("Training schedule parameters")
    schedule_group.add_argument(
        "--epochs", type=int, default=default_epochs, metavar="N", help="number of training epochs"
    )

    # The remaining options are optional integers that stay None when not supplied
    optional_counts = (
        ("--stop-epoch", "epoch to stop the training at (multi stage training)"),
        ("--steps-per-epoch", "virtual epoch length in steps, leave unset to use the full dataset"),
        ("--warmup-epochs", "number of warmup epochs"),
        ("--warmup-steps", "number of warmup optimizer steps"),
        ("--cooldown-epochs", "number of cooldown epochs (linear to zero)"),
        ("--cooldown-steps", "number of cooldown optimizer steps (linear to zero)"),
    )
    for flag, description in optional_counts:
        schedule_group.add_argument(flag, type=int, metavar="N", help=description)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def add_input_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the options describing the model inputs (image channels, image size, text context length)

    The default number of channels is read from settings.DEFAULT_NUM_CHANNELS.
    """

    input_group = parser.add_argument_group("Input parameters")
    add = input_group.add_argument

    add(
        "--channels",
        type=int,
        default=settings.DEFAULT_NUM_CHANNELS,
        metavar="N",
        help="no. of image channels",
    )
    add("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
    add("--force-context-length", type=int, metavar="N", help="override text context length")
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the dataloader options under the "Dataloader parameters" group

    The default number of workers is a quarter of the CPU count, clamped to the
    range [4, 12].
    """

    group = parser.add_argument_group("Dataloader parameters")

    # os.cpu_count() returns None when the CPU count cannot be determined,
    # fall back to 1 so the arithmetic below never sees None
    cpu_count = os.cpu_count() or 1
    default_num_workers = min(12, max(cpu_count // 4, 4))

    group.add_argument(
        "-j",
        "--num-workers",
        type=int,
        default=default_num_workers,
        metavar="N",
        help="number of preprocessing workers",
    )
    group.add_argument(
        "--prefetch-factor", type=int, metavar="N", help="number of batches loaded in advance by each worker"
    )
    # Pinning is on by default, the flag turns it off (stored as args.pin_memory)
    group.add_argument(
        "--no-pin-memory",
        dest="pin_memory",
        default=True,
        action="store_false",
        help="disable memory pinning in dataloaders",
    )
    group.add_argument(
        "--persistent-workers",
        default=False,
        action="store_true",
        help="keep dataloader worker processes alive between epochs",
    )
    group.add_argument("--drop-last", default=False, action="store_true", help="drop the last incomplete batch")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def add_precision_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the numeric precision options under the "Precision parameters" group

    Covers the model dtype, the AMP switch and its dtype, and the fast matmul switch.
    """

    precision_group = parser.add_argument_group("Precision parameters")
    add = precision_group.add_argument

    amp_dtypes = ["float16", "bfloat16"]
    model_dtypes = ["float32", "float16", "bfloat16"]

    add("--model-dtype", type=str, choices=model_dtypes, default="float32", help="model dtype to use")
    add(
        "--amp",
        action="store_true",
        default=False,
        help="enable automatic mixed precision (AMP) training via torch.amp",
    )
    add(
        "--amp-dtype",
        type=str,
        choices=amp_dtypes,
        default="float16",
        help="whether to use float16 or bfloat16 for mixed precision",
    )
    add(
        "--fast-matmul",
        action="store_true",
        default=False,
        help="use fast matrix multiplication (affects precision)",
    )
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def add_compile_args(parser: argparse.ArgumentParser) -> None:
|
|
258
|
+
group = parser.add_argument_group("Compilation parameters")
|
|
259
|
+
group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
|
|
260
|
+
group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
|
|
261
|
+
group.add_argument(
|
|
262
|
+
"--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
|
|
263
|
+
)
|
|
264
|
+
group.add_argument(
|
|
265
|
+
"--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
|
|
266
|
+
)
|
|
267
|
+
group.add_argument(
|
|
268
|
+
"--compile-recompile-limit",
|
|
269
|
+
type=int,
|
|
270
|
+
default=torch.compiler.config.recompile_limit,
|
|
271
|
+
metavar="N",
|
|
272
|
+
help="maximum recompilations per compiled function before eager fallback",
|
|
273
|
+
)
|
|
274
|
+
group.add_argument(
|
|
275
|
+
"--compile-accumulated-recompile-limit",
|
|
276
|
+
type=int,
|
|
277
|
+
default=torch.compiler.config.accumulated_recompile_limit,
|
|
278
|
+
metavar="N",
|
|
279
|
+
help="maximum total recompilations across compiled functions",
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
    """
    Register the checkpoint saving and resuming options under the "Checkpoint parameters" group

    Parameters
    ----------
    parser
        Parser to extend.
    default_save_frequency
        Default value used for --save-frequency.
    """

    ckpt_group = parser.add_argument_group("Checkpoint parameters")
    add = ckpt_group.add_argument

    add("--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving")
    add("--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)")
    add("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")

    # Boolean switches controlling what gets loaded when resuming
    resume_switches = (
        (
            "--non-strict-weights",
            "allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
        ),
        ("--load-states", "load optimizer, scheduler and scaler states when resuming"),
        ("--load-scheduler", "load only scheduler when resuming"),
    )
    for flag, description in resume_switches:
        add(flag, action="store_true", default=False, help=description)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the distributed training options under the "Distributed training parameters" group
    """

    dist_group = parser.add_argument_group("Distributed training parameters")
    add = dist_group.add_argument

    add("--world-size", type=int, default=1, metavar="N", help="number of distributed processes")
    add("--local-rank", type=int, metavar="N", help="local rank")
    add("--dist-url", type=str, default="env://", help="URL used to initialize distributed training")
    add("--dist-backend", type=str, default="nccl", help="distributed backend")

    ddp_switches = (
        (
            "--find-unused-parameters",
            "enable searching for unused parameters in DistributedDataParallel (may impact performance)",
        ),
        ("--no-broadcast-buffers", "disable broadcasting of buffers from rank 0 in distributed training"),
    )
    for flag, description in ddp_switches:
        add(flag, action="store_true", default=False, help=description)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def add_logging_and_debug_args(parser: argparse.ArgumentParser, default_log_interval: int = 50) -> None:
    """
    Register the logging and debugging options under the "Logging and debugging parameters" group

    Parameters
    ----------
    parser
        Parser to extend.
    default_log_interval
        Default value used for --log-interval.
    """

    debug_group = parser.add_argument_group("Logging and debugging parameters")
    add = debug_group.add_argument

    def add_switches(switches: tuple[tuple[str, str], ...]) -> None:
        # Every switch is a plain store_true flag that defaults to False
        for flag, description in switches:
            add(flag, action="store_true", default=False, help=description)

    add(
        "--experiment",
        "--exp",
        type=str,
        metavar="NAME",
        help="experiment name for logging (creates dedicated directory for the run)",
    )
    add(
        "--log-interval",
        type=int,
        default=default_log_interval,
        metavar="N",
        help="how many iterations between summary writes",
    )
    add_switches(
        (
            ("--grad-anomaly-detection", "enable the autograd anomaly detection (for debugging)"),
            ("--use-deterministic-algorithms", "use only deterministic algorithms"),
            ("--plot-lr", "plot learning rate and exit (skip training)"),
            ("--no-summary", "don't print model summary"),
            ("--non-interactive", "force non-interactive mode (disables progress bars)"),
        )
    )
    add("--seed", type=int, help="set random seed for better reproducibility (affects torch, numpy and random)")
    add_switches(
        (
            ("--cpu", "use cpu (mostly for testing)"),
            ("--use-fake-data", "use fake data instead of real dataset"),
        )
    )
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def add_training_data_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the training data source options

    Two argument groups are created: one titled "Training data parameters" holding the
    WebDataset options, and an untitled one (description "CSV") holding the CSV paths.
    """

    # WebDataset options
    wds_group = parser.add_argument_group("Training data parameters", description="WebDataset")
    add_wds = wds_group.add_argument

    add_wds("--wds", action="store_true", default=False, help="use webdataset for training")
    add_wds("--wds-info", type=str, action="append", metavar="FILE", help="one or more wds info file paths")
    add_wds("--wds-cache-dir", type=str, metavar="DIR", help="webdataset cache directory")
    add_wds("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
    add_wds("--wds-val-size", type=int, metavar="N", help="size of the wds validation set")
    add_wds("--wds-training-split", type=str, default="training", metavar="NAME", help="wds dataset train split")
    add_wds("--wds-val-split", type=str, default="validation", metavar="NAME", help="wds dataset validation split")
    add_wds(
        "--wds-extra-shuffle",
        action="store_true",
        default=False,
        help="enable cross-worker batch shuffling after batching",
    )

    # CSV options
    csv_group = parser.add_argument_group(description="CSV")
    csv_group.add_argument(
        "--data-path", nargs="*", help="training CSV file paths (required columns: image_path, caption)"
    )
    csv_group.add_argument(
        "--val-path", nargs="*", help="validation CSV file paths (required columns: image_path, caption)"
    )
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def common_args_validation(args: argparse.Namespace) -> None:
    """
    Validate cross-option constraints of the parsed training arguments

    Checks are evaluated in order and the first violated one is reported.

    Parameters
    ----------
    args
        Parsed arguments, expected to carry the attributes read below.

    Raises
    ------
    cli.ValidationError
        On the first violated constraint.
    """

    # Model
    if args.network is None:
        raise cli.ValidationError("--network is required")
    if registry.exists(args.network, task=Task.IMAGE_TEXT) is False:
        raise cli.ValidationError(f"--network {args.network} not supported, see list-models tool for available options")

    # Training schedule
    if args.stop_epoch is not None and args.stop_epoch > args.epochs:
        # The check allows stop_epoch == epochs, so the message states exactly that bound
        raise cli.ValidationError(
            f"--stop-epoch must not be greater than the total number of epochs ({args.epochs}), "
            f"got {args.stop_epoch}"
        )
    if args.warmup_epochs is not None and args.warmup_steps is not None:
        raise cli.ValidationError("--warmup-epochs cannot be used with --warmup-steps")
    if args.cooldown_epochs is not None and args.cooldown_steps is not None:
        raise cli.ValidationError("--cooldown-epochs cannot be used with --cooldown-steps")
    if args.lr_scheduler_update != "step" and args.warmup_steps is not None:
        raise cli.ValidationError(
            "--warmup-steps can only be used when --lr-scheduler-update is 'step', "
            f"but it is set to '{args.lr_scheduler_update}'"
        )
    if args.lr_scheduler_update != "step" and args.cooldown_steps is not None:
        raise cli.ValidationError(
            "--cooldown-steps can only be used when --lr-scheduler-update is 'step', "
            f"but it is set to '{args.lr_scheduler_update}'"
        )

    # Compilation
    if args.compile_fullgraph is True and args.compile is False:
        raise cli.ValidationError("--compile-fullgraph requires --compile")
    if args.compile_mode is not None and args.compile is False:
        raise cli.ValidationError("--compile-mode requires --compile")

    # Checkpoint
    if args.load_states is True and args.resume_epoch is None:
        raise cli.ValidationError("--load-states requires --resume-epoch to be set")
    if args.load_scheduler is True and args.resume_epoch is None:
        raise cli.ValidationError("--load-scheduler requires --resume-epoch to be set")

    # Data
    # "not args.data_path" covers both None (flag absent) and an empty list
    # (flag given without values), neither of which provides any data
    if args.wds is False and not args.data_path and args.use_fake_data is False:
        raise cli.ValidationError("Must provide at least one data source, --data-path or --wds")
    if args.wds is True and args.data_path is not None and len(args.data_path) > 1:
        raise cli.ValidationError(f"--wds can have at most 1 --data-path, got {len(args.data_path)}")
    if args.use_fake_data is True and args.wds is True:
        raise cli.ValidationError("--use-fake-data cannot be used with --wds")
    if args.persistent_workers is True and args.num_workers == 0:
        raise cli.ValidationError("--persistent-workers requires --num-workers to be greater than 0")

    # Precision and model shape
    if args.amp is True and args.model_dtype != "float32":
        raise cli.ValidationError("--amp can only be used with --model-dtype float32")
    if args.embed_dim is not None and args.embed_dim <= 0:
        raise cli.ValidationError("--embed-dim must be positive")
    if args.force_context_length is not None and args.force_context_length <= 0:
        raise cli.ValidationError("--force-context-length must be positive")
|
|
@@ -2,6 +2,11 @@ import logging.config
|
|
|
2
2
|
import os
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
+
from birder.conf import settings as birder_settings
|
|
6
|
+
|
|
7
|
+
# Paths
|
|
8
|
+
TOKENIZERS_DIR = birder_settings.MODELS_DIR.joinpath("tokenizers")
|
|
9
|
+
|
|
5
10
|
# Logging
|
|
6
11
|
# https://docs.python.org/3/library/logging.config.html
|
|
7
12
|
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from bisect import bisect_right
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import polars as pl
|
|
8
|
+
import torch
|
|
9
|
+
from birder.data.datasets.directory import tv_rgb_loader
|
|
10
|
+
|
|
11
|
+
from birder_clip.tokenizers import Tokenizer
|
|
12
|
+
|
|
13
|
+
IMAGE_PATH_COLUMN = "image_path"
|
|
14
|
+
CAPTION_COLUMN = "caption"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _resolve_image_path(path: str, csv_dir: Path) -> str:
    """
    Resolve an image path taken from a CSV row

    An absolute path is returned unchanged, anything else is joined onto csv_dir
    and returned as a string.
    """

    candidate = Path(path)
    if not candidate.is_absolute():
        return str(csv_dir / candidate)

    return path
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ImageTextCsvDataset(torch.utils.data.Dataset):
    """
    Image-caption dataset backed by one or more CSV files

    Every CSV must provide the IMAGE_PATH_COLUMN and CAPTION_COLUMN columns. Relative
    image paths are resolved against the directory of the CSV file the row came from.
    """

    def __init__(
        self,
        csv_paths: list[str | Path],
        transforms: Optional[Callable[..., Any]] = None,
        tokenizer: Optional[Tokenizer] = None,
        loader: Callable[[str], Any] = tv_rgb_loader,
    ) -> None:
        """
        Parameters
        ----------
        csv_paths
            CSV files to read, rows are concatenated in the given order.
        transforms
            Optional callable applied to the loaded image.
        tokenizer
            Optional tokenizer applied to the caption, when None the raw caption is returned.
        loader
            Callable that loads an image from a path string.
        """

        super().__init__()
        self.transforms = transforms
        self.tokenizer = tokenizer
        self.loader = loader
        self.csv_dirs: list[Path] = []
        # Cumulative row counts, csv_offsets[i] is the first global index past the rows of CSV i
        self.csv_offsets: list[int] = []

        frames = []
        length = 0
        for csv_path in csv_paths:
            csv_path = Path(csv_path).expanduser().absolute()
            frame = pl.read_csv(
                csv_path,
                columns=[IMAGE_PATH_COLUMN, CAPTION_COLUMN],
                schema_overrides={
                    IMAGE_PATH_COLUMN: pl.String,
                    CAPTION_COLUMN: pl.String,
                },
            )
            frames.append(frame)
            length += len(frame)
            self.csv_offsets.append(length)
            self.csv_dirs.append(csv_path.parent)

        frame = pl.concat(frames)
        self.paths = frame.get_column(IMAGE_PATH_COLUMN)
        self.captions = frame.get_column(CAPTION_COLUMN)

    def __getitem__(self, index: int) -> tuple[str, Any, Any]:
        """
        Return (path, image, text) for the given row

        path is the value as written in the CSV (not the resolved path), image is the
        loaded and optionally transformed image, text is the tokenizer output for the
        caption or the raw caption when no tokenizer is set.

        Raises
        ------
        IndexError
            If index is outside [-len(self), len(self)).
        """

        size = len(self)
        if index < -size or index >= size:
            raise IndexError(f"index {index} is out of range for dataset of size {size}")

        # Normalize negative indices, otherwise the bisect lookup below would map
        # them to the first CSV while the row itself comes from the end of the data
        if index < 0:
            index += size

        path = self.paths[index]
        csv_idx = bisect_right(self.csv_offsets, index)
        image_path = _resolve_image_path(path, self.csv_dirs[csv_idx])
        caption = self.captions[index]
        image = self.loader(image_path)
        if self.transforms is not None:
            image = self.transforms(image)

        if self.tokenizer is not None:
            # The tokenizer is called with a batch of one and the single result is taken
            text = self.tokenizer([caption])[0]
        else:
            text = caption

        return (path, image, text)

    def __len__(self) -> int:
        """Return the total number of rows across all CSV files"""

        return len(self.paths)

    def __repr__(self) -> str:
        """Return a multi-line summary with the size, transforms and tokenizer"""

        head = "Dataset " + self.__class__.__name__
        body = [f"Number of data points: {self.__len__()}"]
        if self.transforms is not None:
            body += [repr(self.transforms)]
        if self.tokenizer is not None:
            body += [repr(self.tokenizer)]

        lines = [head] + ["    " + line for line in body]

        return "\n".join(lines)
|
|
@@ -9,6 +9,7 @@ class embedding again.
|
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
11
|
from collections.abc import Sequence
|
|
12
|
+
from typing import Optional
|
|
12
13
|
|
|
13
14
|
import torch
|
|
14
15
|
import torch.nn.functional as F
|
|
@@ -22,7 +23,7 @@ def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list
|
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
def build_class_text_embeddings(
|
|
25
|
-
|
|
26
|
+
net: BaseNet,
|
|
26
27
|
tokenizer: Tokenizer,
|
|
27
28
|
class_names: Sequence[str],
|
|
28
29
|
templates: Sequence[str],
|
|
@@ -30,6 +31,8 @@ def build_class_text_embeddings(
|
|
|
30
31
|
device: torch.device,
|
|
31
32
|
context_length: int | None = None,
|
|
32
33
|
batch_size: int | None = None,
|
|
34
|
+
amp: bool = False,
|
|
35
|
+
amp_dtype: Optional[torch.dtype] = None,
|
|
33
36
|
) -> torch.Tensor:
|
|
34
37
|
num_templates = len(templates)
|
|
35
38
|
if batch_size is None:
|
|
@@ -41,7 +44,8 @@ def build_class_text_embeddings(
|
|
|
41
44
|
batch_class_names = class_names[start : start + batch_size]
|
|
42
45
|
prompts = render_prompts(batch_class_names, templates)
|
|
43
46
|
tokens = tokenizer(prompts, context_length=context_length).to(device)
|
|
44
|
-
|
|
47
|
+
with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
|
|
48
|
+
class_embeddings = net.encode_text(tokens, normalize=True)
|
|
45
49
|
|
|
46
50
|
class_embeddings = class_embeddings.reshape(len(batch_class_names), num_templates, -1).mean(dim=1)
|
|
47
51
|
class_embeddings = F.normalize(class_embeddings, dim=-1)
|