birder-clip 0.0.2.dev3__tar.gz → 0.0.2.dev4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/PKG-INFO +2 -2
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/lib.py +10 -2
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/training_cli.py +103 -102
- birder_clip-0.0.2.dev4/birder_clip/common/training_utils.py +61 -0
- birder_clip-0.0.2.dev4/birder_clip/data/datasets/webdataset.py +106 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/loss/contrastive.py +11 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/clip.py +19 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/text/base.py +4 -0
- birder_clip-0.0.2.dev4/birder_clip/scripts/train.py +940 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/show_iterator.py +77 -11
- birder_clip-0.0.2.dev4/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/PKG-INFO +2 -2
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/SOURCES.txt +3 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/requires.txt +1 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/requirements/requirements.txt +1 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_common.py +1 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_datasets.py +6 -0
- birder_clip-0.0.2.dev3/birder_clip/version.py +0 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/LICENSE +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/README.md +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/fs_ops.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/conf/settings.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/csv.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/fake.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/inference/zero_shot.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/loss/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/manifest.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/model_registry.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/base.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/text/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/text/transformer.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/py.typed +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/scripts/zero_shot.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/base.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/hf.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/registry.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/__main__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/download_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/requirements/_requirements-dev.txt +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_loss.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_model_registry.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_net.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_tokenizers.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder_clip
|
|
3
|
-
Version: 0.0.2.dev3
|
|
3
|
+
Version: 0.0.2.dev4
|
|
4
4
|
Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
|
|
|
24
24
|
Requires-Python: >=3.11
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
License-File: LICENSE
|
|
27
|
-
Requires-Dist: birder>=0.5.
|
|
27
|
+
Requires-Dist: birder>=0.5.4
|
|
28
28
|
Requires-Dist: ftfy>=6.3.1
|
|
29
29
|
Requires-Dist: regex>=2025.7.29
|
|
30
30
|
Requires-Dist: tqdm>=4.67.0
|
|
@@ -43,9 +43,17 @@ def get_image_text_network_name(
|
|
|
43
43
|
parts = [network]
|
|
44
44
|
if image_encoder is not None:
|
|
45
45
|
parts.append(image_encoder)
|
|
46
|
-
if text_encoder is not None:
|
|
46
|
+
if text_encoder is not None and text_encoder != "text_transformer":
|
|
47
47
|
parts.append(text_encoder)
|
|
48
|
-
|
|
48
|
+
|
|
49
|
+
if registry.exists(network) is True:
|
|
50
|
+
default_tokenizer = registry.get_default_tokenizer(network)
|
|
51
|
+
else:
|
|
52
|
+
default_tokenizer = "simple_tokenizer"
|
|
53
|
+
if default_tokenizer is None:
|
|
54
|
+
default_tokenizer = "simple_tokenizer"
|
|
55
|
+
|
|
56
|
+
if tokenizer is not None and tokenizer != default_tokenizer:
|
|
49
57
|
parts.append(tokenizer)
|
|
50
58
|
if embed_dim is not None:
|
|
51
59
|
parts.append(f"d{embed_dim}")
|
|
@@ -17,32 +17,6 @@ from birder_clip.model_registry import Task
|
|
|
17
17
|
from birder_clip.model_registry import registry
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
def add_compile_args(parser: argparse.ArgumentParser) -> None:
|
|
21
|
-
group = parser.add_argument_group("Compilation parameters")
|
|
22
|
-
group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
|
|
23
|
-
group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
|
|
24
|
-
group.add_argument(
|
|
25
|
-
"--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
|
|
26
|
-
)
|
|
27
|
-
group.add_argument(
|
|
28
|
-
"--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
|
|
29
|
-
)
|
|
30
|
-
group.add_argument(
|
|
31
|
-
"--compile-recompile-limit",
|
|
32
|
-
type=int,
|
|
33
|
-
default=torch.compiler.config.recompile_limit,
|
|
34
|
-
metavar="N",
|
|
35
|
-
help="maximum recompilations per compiled function before eager fallback",
|
|
36
|
-
)
|
|
37
|
-
group.add_argument(
|
|
38
|
-
"--compile-accumulated-recompile-limit",
|
|
39
|
-
type=int,
|
|
40
|
-
default=torch.compiler.config.accumulated_recompile_limit,
|
|
41
|
-
metavar="N",
|
|
42
|
-
help="maximum total recompilations across compiled functions",
|
|
43
|
-
)
|
|
44
|
-
|
|
45
|
-
|
|
46
20
|
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
47
21
|
parser.add_argument("-n", "--network", type=str, help="the image-text network to train")
|
|
48
22
|
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
@@ -75,9 +49,7 @@ def add_loss_args(parser: argparse.ArgumentParser) -> None:
|
|
|
75
49
|
def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: int = 32) -> None:
|
|
76
50
|
group = parser.add_argument_group("Optimization parameters")
|
|
77
51
|
group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
|
|
78
|
-
group.add_argument(
|
|
79
|
-
"--opt", type=str, choices=list(get_args(OptimizerType)), default="adamw", help="optimizer to use"
|
|
80
|
-
)
|
|
52
|
+
group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
|
|
81
53
|
group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
|
|
82
54
|
group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
|
|
83
55
|
group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
|
|
@@ -92,6 +64,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
|
|
|
92
64
|
metavar="N",
|
|
93
65
|
help="number of iterations to accumulate gradients per optimizer step",
|
|
94
66
|
)
|
|
67
|
+
# NOTE: Add flag for negative sample caching in grad accum mode
|
|
95
68
|
|
|
96
69
|
|
|
97
70
|
def add_lr_wd_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -129,14 +102,14 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
|
|
|
129
102
|
"--lr-scheduler-update",
|
|
130
103
|
type=str,
|
|
131
104
|
choices=["epoch", "step"],
|
|
132
|
-
default="step",
|
|
105
|
+
default="epoch",
|
|
133
106
|
help="when to apply learning rate scheduler update: epoch (once per epoch), step (each optimizer step)",
|
|
134
107
|
)
|
|
135
108
|
group.add_argument(
|
|
136
109
|
"--lr-scheduler",
|
|
137
110
|
type=str,
|
|
138
111
|
choices=list(get_args(SchedulerType)),
|
|
139
|
-
default="
|
|
112
|
+
default="constant",
|
|
140
113
|
help="learning rate scheduler",
|
|
141
114
|
)
|
|
142
115
|
group.add_argument(
|
|
@@ -175,15 +148,6 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
|
|
|
175
148
|
)
|
|
176
149
|
|
|
177
150
|
|
|
178
|
-
def add_input_args(parser: argparse.ArgumentParser) -> None:
|
|
179
|
-
group = parser.add_argument_group("Input parameters")
|
|
180
|
-
group.add_argument(
|
|
181
|
-
"--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
|
|
182
|
-
)
|
|
183
|
-
group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
|
|
184
|
-
group.add_argument("--context-length", type=int, metavar="N", help="text context length")
|
|
185
|
-
|
|
186
|
-
|
|
187
151
|
def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
|
|
188
152
|
group = parser.add_argument_group("Training schedule parameters")
|
|
189
153
|
group.add_argument("--epochs", type=int, default=default_epochs, metavar="N", help="number of training epochs")
|
|
@@ -204,6 +168,37 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
|
|
|
204
168
|
)
|
|
205
169
|
|
|
206
170
|
|
|
171
|
+
def add_ema_args(
|
|
172
|
+
parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
|
|
173
|
+
) -> None:
|
|
174
|
+
group = parser.add_argument_group("Exponential moving average parameters")
|
|
175
|
+
group.add_argument(
|
|
176
|
+
"--model-ema",
|
|
177
|
+
default=False,
|
|
178
|
+
action="store_true",
|
|
179
|
+
help="enable tracking exponential moving average of model parameters",
|
|
180
|
+
)
|
|
181
|
+
group.add_argument(
|
|
182
|
+
"--model-ema-steps",
|
|
183
|
+
type=int,
|
|
184
|
+
default=default_ema_steps,
|
|
185
|
+
metavar="N",
|
|
186
|
+
help="number of optimizer steps between EMA updates",
|
|
187
|
+
)
|
|
188
|
+
group.add_argument(
|
|
189
|
+
"--model-ema-decay",
|
|
190
|
+
type=float,
|
|
191
|
+
default=default_ema_decay,
|
|
192
|
+
help="decay factor for exponential moving average of model parameters",
|
|
193
|
+
)
|
|
194
|
+
group.add_argument(
|
|
195
|
+
"--model-ema-warmup",
|
|
196
|
+
type=int,
|
|
197
|
+
metavar="N",
|
|
198
|
+
help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
|
|
207
202
|
def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
|
|
208
203
|
group = parser.add_argument_group("Batch normalization parameters")
|
|
209
204
|
group.add_argument(
|
|
@@ -215,6 +210,15 @@ def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
|
|
|
215
210
|
group.add_argument("--sync-bn", default=False, action="store_true", help="use synchronized BatchNorm")
|
|
216
211
|
|
|
217
212
|
|
|
213
|
+
def add_input_args(parser: argparse.ArgumentParser) -> None:
|
|
214
|
+
group = parser.add_argument_group("Input parameters")
|
|
215
|
+
group.add_argument(
|
|
216
|
+
"--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
|
|
217
|
+
)
|
|
218
|
+
group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
|
|
219
|
+
group.add_argument("--context-length", type=int, metavar="N", help="text context length")
|
|
220
|
+
|
|
221
|
+
|
|
218
222
|
def add_data_aug_args(
|
|
219
223
|
parser: argparse.ArgumentParser,
|
|
220
224
|
default_level: int = 4,
|
|
@@ -260,7 +264,7 @@ def add_data_aug_args(
|
|
|
260
264
|
"--rgb-mode",
|
|
261
265
|
type=str,
|
|
262
266
|
choices=list(typing.get_args(RGBMode)),
|
|
263
|
-
default="
|
|
267
|
+
default="birder",
|
|
264
268
|
help="RGB mean and std to use for normalization",
|
|
265
269
|
)
|
|
266
270
|
group.add_argument(
|
|
@@ -279,67 +283,6 @@ def add_data_aug_args(
|
|
|
279
283
|
)
|
|
280
284
|
|
|
281
285
|
|
|
282
|
-
def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
|
|
283
|
-
group = parser.add_argument_group("Checkpoint parameters")
|
|
284
|
-
group.add_argument(
|
|
285
|
-
"--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
|
|
286
|
-
)
|
|
287
|
-
group.add_argument(
|
|
288
|
-
"--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
|
|
289
|
-
)
|
|
290
|
-
group.add_argument(
|
|
291
|
-
"--pretrained",
|
|
292
|
-
default=False,
|
|
293
|
-
action="store_true",
|
|
294
|
-
help="start with pretrained version of specified network (will download if not found locally)",
|
|
295
|
-
)
|
|
296
|
-
group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
|
|
297
|
-
group.add_argument(
|
|
298
|
-
"--non-strict-weights",
|
|
299
|
-
default=False,
|
|
300
|
-
action="store_true",
|
|
301
|
-
help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
|
|
302
|
-
)
|
|
303
|
-
group.add_argument(
|
|
304
|
-
"--load-states",
|
|
305
|
-
default=False,
|
|
306
|
-
action="store_true",
|
|
307
|
-
help="load optimizer, scheduler and scaler states when resuming",
|
|
308
|
-
)
|
|
309
|
-
group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
def add_ema_args(
|
|
313
|
-
parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
|
|
314
|
-
) -> None:
|
|
315
|
-
group = parser.add_argument_group("Exponential moving average parameters")
|
|
316
|
-
group.add_argument(
|
|
317
|
-
"--model-ema",
|
|
318
|
-
default=False,
|
|
319
|
-
action="store_true",
|
|
320
|
-
help="enable tracking exponential moving average of model parameters",
|
|
321
|
-
)
|
|
322
|
-
group.add_argument(
|
|
323
|
-
"--model-ema-steps",
|
|
324
|
-
type=int,
|
|
325
|
-
default=default_ema_steps,
|
|
326
|
-
metavar="N",
|
|
327
|
-
help="number of optimizer steps between EMA updates",
|
|
328
|
-
)
|
|
329
|
-
group.add_argument(
|
|
330
|
-
"--model-ema-decay",
|
|
331
|
-
type=float,
|
|
332
|
-
default=default_ema_decay,
|
|
333
|
-
help="decay factor for exponential moving average of model parameters",
|
|
334
|
-
)
|
|
335
|
-
group.add_argument(
|
|
336
|
-
"--model-ema-warmup",
|
|
337
|
-
type=int,
|
|
338
|
-
metavar="N",
|
|
339
|
-
help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
|
|
340
|
-
)
|
|
341
|
-
|
|
342
|
-
|
|
343
286
|
def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
|
|
344
287
|
group = parser.add_argument_group("Dataloader parameters")
|
|
345
288
|
group.add_argument(
|
|
@@ -405,6 +348,62 @@ def add_precision_args(parser: argparse.ArgumentParser) -> None:
|
|
|
405
348
|
)
|
|
406
349
|
|
|
407
350
|
|
|
351
|
+
def add_compile_args(parser: argparse.ArgumentParser) -> None:
|
|
352
|
+
group = parser.add_argument_group("Compilation parameters")
|
|
353
|
+
group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
|
|
354
|
+
group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
|
|
355
|
+
group.add_argument(
|
|
356
|
+
"--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
|
|
357
|
+
)
|
|
358
|
+
group.add_argument(
|
|
359
|
+
"--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
|
|
360
|
+
)
|
|
361
|
+
group.add_argument(
|
|
362
|
+
"--compile-recompile-limit",
|
|
363
|
+
type=int,
|
|
364
|
+
default=torch.compiler.config.recompile_limit,
|
|
365
|
+
metavar="N",
|
|
366
|
+
help="maximum recompilations per compiled function before eager fallback",
|
|
367
|
+
)
|
|
368
|
+
group.add_argument(
|
|
369
|
+
"--compile-accumulated-recompile-limit",
|
|
370
|
+
type=int,
|
|
371
|
+
default=torch.compiler.config.accumulated_recompile_limit,
|
|
372
|
+
metavar="N",
|
|
373
|
+
help="maximum total recompilations across compiled functions",
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
|
|
378
|
+
group = parser.add_argument_group("Checkpoint parameters")
|
|
379
|
+
group.add_argument(
|
|
380
|
+
"--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
|
|
381
|
+
)
|
|
382
|
+
group.add_argument(
|
|
383
|
+
"--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
|
|
384
|
+
)
|
|
385
|
+
group.add_argument(
|
|
386
|
+
"--pretrained",
|
|
387
|
+
default=False,
|
|
388
|
+
action="store_true",
|
|
389
|
+
help="start with pretrained version of specified network (will download if not found locally)",
|
|
390
|
+
)
|
|
391
|
+
group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
|
|
392
|
+
group.add_argument(
|
|
393
|
+
"--non-strict-weights",
|
|
394
|
+
default=False,
|
|
395
|
+
action="store_true",
|
|
396
|
+
help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
|
|
397
|
+
)
|
|
398
|
+
group.add_argument(
|
|
399
|
+
"--load-states",
|
|
400
|
+
default=False,
|
|
401
|
+
action="store_true",
|
|
402
|
+
help="load optimizer, scheduler and scaler states when resuming",
|
|
403
|
+
)
|
|
404
|
+
group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
|
|
405
|
+
|
|
406
|
+
|
|
408
407
|
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
|
409
408
|
group = parser.add_argument_group("Distributed training parameters")
|
|
410
409
|
group.add_argument("--world-size", type=int, default=1, metavar="N", help="number of distributed processes")
|
|
@@ -558,5 +557,7 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
558
557
|
raise cli.ValidationError("--embed-dim must be positive")
|
|
559
558
|
if args.context_length is not None and args.context_length <= 0:
|
|
560
559
|
raise cli.ValidationError("--context-length must be positive")
|
|
560
|
+
if args.grad_accum_steps < 1:
|
|
561
|
+
raise cli.ValidationError("--grad-accum-steps must be >= 1")
|
|
561
562
|
if args.model_ema_steps < 1:
|
|
562
563
|
raise cli.ValidationError("--model-ema-steps must be >= 1")
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.distributed as dist
|
|
9
|
+
from birder.common import training_utils as birder_training_utils
|
|
10
|
+
|
|
11
|
+
from birder_clip.common import fs_ops
|
|
12
|
+
from birder_clip.conf import settings
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
|
|
16
|
+
file_handler = logging.FileHandler(log_file_path)
|
|
17
|
+
formatter = logging.Formatter(
|
|
18
|
+
fmt="{message}",
|
|
19
|
+
style="{",
|
|
20
|
+
)
|
|
21
|
+
file_handler.setFormatter(formatter)
|
|
22
|
+
file_handler.setLevel(settings.LOG_LEVEL)
|
|
23
|
+
|
|
24
|
+
logging.getLogger("birder").addHandler(file_handler)
|
|
25
|
+
logging.getLogger("birder_clip").addHandler(file_handler)
|
|
26
|
+
|
|
27
|
+
return file_handler
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def save_training_checkpoint(
|
|
31
|
+
args: argparse.Namespace,
|
|
32
|
+
network_name: str,
|
|
33
|
+
epoch: int,
|
|
34
|
+
net: torch.nn.Module,
|
|
35
|
+
signature: Any,
|
|
36
|
+
rgb_stats: Any,
|
|
37
|
+
optimizer: Optional[torch.optim.Optimizer],
|
|
38
|
+
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
|
|
39
|
+
scaler: Optional[torch.amp.grad_scaler.GradScaler],
|
|
40
|
+
model_base: Optional[torch.nn.Module],
|
|
41
|
+
*,
|
|
42
|
+
external_config: Optional[dict[str, Any]] = None,
|
|
43
|
+
**extra_states: Optional[dict[str, Any]],
|
|
44
|
+
) -> None:
|
|
45
|
+
if birder_training_utils.is_global_primary(args) is True:
|
|
46
|
+
fs_ops.checkpoint_model(
|
|
47
|
+
network_name,
|
|
48
|
+
epoch,
|
|
49
|
+
net,
|
|
50
|
+
signature,
|
|
51
|
+
rgb_stats,
|
|
52
|
+
optimizer,
|
|
53
|
+
scheduler,
|
|
54
|
+
scaler,
|
|
55
|
+
model_base,
|
|
56
|
+
external_config=external_config,
|
|
57
|
+
**extra_states,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
if birder_training_utils.is_dist_available_and_initialized() is True:
|
|
61
|
+
dist.barrier()
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import Any
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import webdataset as wds
|
|
9
|
+
from birder.conf import settings
|
|
10
|
+
from birder.data.datasets import webdataset as birder_wds
|
|
11
|
+
|
|
12
|
+
from birder_clip.tokenizers import Tokenizer
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def decode_caption(caption: Any, caption_json_key: str = "caption") -> str:
|
|
18
|
+
if isinstance(caption, dict):
|
|
19
|
+
if caption_json_key not in caption:
|
|
20
|
+
raise ValueError(f"WebDataset JSON sample missing '{caption_json_key}' key")
|
|
21
|
+
|
|
22
|
+
caption = caption[caption_json_key]
|
|
23
|
+
|
|
24
|
+
if isinstance(caption, bytes):
|
|
25
|
+
caption = caption.decode("utf-8")
|
|
26
|
+
|
|
27
|
+
if isinstance(caption, str) is False:
|
|
28
|
+
raise TypeError(f"WebDataset caption must be a string, got {type(caption).__name__}")
|
|
29
|
+
|
|
30
|
+
return caption # type: ignore[no-any-return]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def tokenize_caption(caption: str, tokenizer: Tokenizer) -> torch.Tensor:
|
|
34
|
+
return tokenizer([caption])[0]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def make_wds_dataset(
|
|
38
|
+
wds_path: str | list[str],
|
|
39
|
+
dataset_size: int,
|
|
40
|
+
shuffle: bool,
|
|
41
|
+
samples_names: bool,
|
|
42
|
+
transform: Callable[..., torch.Tensor],
|
|
43
|
+
image_decoder: birder_wds.WDSImageDecoderSpec = "tv",
|
|
44
|
+
channels: int = settings.DEFAULT_NUM_CHANNELS,
|
|
45
|
+
tokenizer: Optional[Tokenizer] = None,
|
|
46
|
+
*,
|
|
47
|
+
caption_key: str = "txt;json", # WebDataset picks the first present key, so txt takes precedence over json
|
|
48
|
+
caption_json_key: str = "caption",
|
|
49
|
+
cache_dir: Optional[str] = None,
|
|
50
|
+
shuffle_buffer_size: Optional[int] = None,
|
|
51
|
+
shuffle_initial_size: Optional[int] = None,
|
|
52
|
+
) -> torch.utils.data.IterableDataset:
|
|
53
|
+
if shuffle is True:
|
|
54
|
+
shardshuffle = 500
|
|
55
|
+
else:
|
|
56
|
+
shardshuffle = False
|
|
57
|
+
|
|
58
|
+
dataset = wds.WebDataset(
|
|
59
|
+
wds_path, shardshuffle=shardshuffle, nodesplitter=wds.split_by_node, cache_dir=cache_dir, empty_check=False
|
|
60
|
+
)
|
|
61
|
+
if shuffle is True:
|
|
62
|
+
if shuffle_buffer_size is None:
|
|
63
|
+
shuffle_buffer_size = birder_wds.WDS_SHUFFLE_SIZE
|
|
64
|
+
if shuffle_initial_size is None:
|
|
65
|
+
shuffle_initial_size = birder_wds.WDS_INITIAL_SIZE
|
|
66
|
+
|
|
67
|
+
logger.debug(f"Using buffer size of {shuffle_buffer_size} for shuffle with {shuffle_initial_size} initial size")
|
|
68
|
+
dataset = dataset.shuffle(shuffle_buffer_size, initial=shuffle_initial_size)
|
|
69
|
+
|
|
70
|
+
return_keys = ["jpeg;jpg;png;webp"]
|
|
71
|
+
return_keys = return_keys + [caption_key]
|
|
72
|
+
if samples_names is True:
|
|
73
|
+
return_keys = ["__url__", "__key__"] + return_keys
|
|
74
|
+
|
|
75
|
+
if isinstance(image_decoder, str):
|
|
76
|
+
decoder = birder_wds.get_wds_image_decoder(image_decoder, channels)
|
|
77
|
+
else:
|
|
78
|
+
decoder = image_decoder
|
|
79
|
+
|
|
80
|
+
dataset = dataset.with_length(dataset_size, silent=True).decode(decoder).to_tuple(*return_keys)
|
|
81
|
+
|
|
82
|
+
caption_decoder = partial(decode_caption, caption_json_key=caption_json_key)
|
|
83
|
+
if samples_names is True:
|
|
84
|
+
dataset = dataset.map(birder_wds.decode_sample_name)
|
|
85
|
+
dataset = dataset.map_tuple(birder_wds.identity, transform, caption_decoder)
|
|
86
|
+
else:
|
|
87
|
+
dataset = dataset.map_tuple(transform, caption_decoder)
|
|
88
|
+
|
|
89
|
+
if tokenizer is not None:
|
|
90
|
+
text_transform = partial(tokenize_caption, tokenizer=tokenizer)
|
|
91
|
+
if samples_names is True:
|
|
92
|
+
dataset = dataset.map_tuple(birder_wds.identity, birder_wds.identity, text_transform)
|
|
93
|
+
else:
|
|
94
|
+
dataset = dataset.map_tuple(birder_wds.identity, text_transform)
|
|
95
|
+
|
|
96
|
+
return dataset
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def wds_size(wds_path: str, device: torch.device, select_suffix: str | tuple[str, ...] = ("txt", "json")) -> int:
|
|
100
|
+
return birder_wds.wds_size(wds_path, device, select_suffix=select_suffix)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def prepare_wds_args(
|
|
104
|
+
data_path: str, size: Optional[int], device: torch.device, select_suffix: str | tuple[str, ...] = ("txt", "json")
|
|
105
|
+
) -> tuple[str, int]:
|
|
106
|
+
return birder_wds.prepare_wds_args(data_path, size, device, select_suffix=select_suffix)
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
"""
|
|
2
2
|
CLIP loss, adapted from
|
|
3
3
|
https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py
|
|
4
|
+
|
|
5
|
+
Paper "Learning Transferable Visual Models From Natural Language Supervision",
|
|
6
|
+
https://arxiv.org/abs/2103.00020
|
|
4
7
|
"""
|
|
5
8
|
|
|
6
9
|
# Reference license: MIT
|
|
@@ -22,6 +25,14 @@ def gather_features(features: torch.Tensor) -> torch.Tensor:
|
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class CLIPLoss(torch.nn.Module):
|
|
28
|
+
"""
|
|
29
|
+
CLIP symmetric contrastive loss
|
|
30
|
+
|
|
31
|
+
Implements the bidirectional InfoNCE objective from CLIP: image features
|
|
32
|
+
classify their matching text features, and text features classify their
|
|
33
|
+
matching image features, using the batch as negatives.
|
|
34
|
+
"""
|
|
35
|
+
|
|
25
36
|
def forward(
|
|
26
37
|
self,
|
|
27
38
|
image_features: torch.Tensor,
|
|
@@ -246,6 +246,10 @@ registry.register_model_config(
|
|
|
246
246
|
},
|
|
247
247
|
)
|
|
248
248
|
|
|
249
|
+
|
|
250
|
+
# Weights
|
|
251
|
+
####################
|
|
252
|
+
|
|
249
253
|
registry.register_weights(
|
|
250
254
|
"openai_clip_vit_l14",
|
|
251
255
|
{
|
|
@@ -261,3 +265,18 @@ registry.register_weights(
|
|
|
261
265
|
"net": {"network": "openai_clip_vit_l14"},
|
|
262
266
|
},
|
|
263
267
|
)
|
|
268
|
+
registry.register_weights(
|
|
269
|
+
"pe_core_b16",
|
|
270
|
+
{
|
|
271
|
+
"description": "RoPEi ViT b16 image encoder pretrained by Meta FAIR using CLIP",
|
|
272
|
+
"resolution": (224, 224),
|
|
273
|
+
"context_length": 32,
|
|
274
|
+
"formats": {
|
|
275
|
+
"pt": {
|
|
276
|
+
"file_size": 1707.8,
|
|
277
|
+
"sha256": "11453d4a36fad6dbd802ec9fa35375ce0ae8b7b156a5ca45c0e87587df05290f",
|
|
278
|
+
}
|
|
279
|
+
},
|
|
280
|
+
"net": {"network": "pe_core_b16"},
|
|
281
|
+
},
|
|
282
|
+
)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import copy
|
|
2
|
+
import logging
|
|
2
3
|
from typing import Any
|
|
3
4
|
from typing import Optional
|
|
4
5
|
|
|
@@ -7,6 +8,8 @@ from torch import nn
|
|
|
7
8
|
|
|
8
9
|
from birder_clip.model_registry import Task
|
|
9
10
|
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
class TextBaseNet(nn.Module):
|
|
12
15
|
default_context_length = 77
|
|
@@ -38,4 +41,5 @@ class TextBaseNet(nn.Module):
|
|
|
38
41
|
if new_context_length == self.context_length:
|
|
39
42
|
return
|
|
40
43
|
|
|
44
|
+
logger.info(f"Adjusting model context length from {self.context_length} to {new_context_length}")
|
|
41
45
|
self.context_length = new_context_length
|