birder-clip 0.0.2.dev3__tar.gz → 0.0.2.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/PKG-INFO +2 -2
  2. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/lib.py +10 -2
  3. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/training_cli.py +103 -102
  4. birder_clip-0.0.2.dev4/birder_clip/common/training_utils.py +61 -0
  5. birder_clip-0.0.2.dev4/birder_clip/data/datasets/webdataset.py +106 -0
  6. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/loss/contrastive.py +11 -0
  7. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/clip.py +19 -0
  8. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/text/base.py +4 -0
  9. birder_clip-0.0.2.dev4/birder_clip/scripts/train.py +940 -0
  10. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/show_iterator.py +77 -11
  11. birder_clip-0.0.2.dev4/birder_clip/version.py +1 -0
  12. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/PKG-INFO +2 -2
  13. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/SOURCES.txt +3 -0
  14. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/requires.txt +1 -1
  15. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/requirements/requirements.txt +1 -1
  16. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_common.py +1 -1
  17. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_datasets.py +6 -0
  18. birder_clip-0.0.2.dev3/birder_clip/version.py +0 -1
  19. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/LICENSE +0 -0
  20. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/README.md +0 -0
  21. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/__init__.py +0 -0
  22. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/__init__.py +0 -0
  23. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/common/fs_ops.py +0 -0
  24. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/conf/__init__.py +0 -0
  25. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/conf/settings.py +0 -0
  26. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/__init__.py +0 -0
  27. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/__init__.py +0 -0
  28. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/csv.py +0 -0
  29. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/data/datasets/fake.py +0 -0
  30. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/inference/__init__.py +0 -0
  31. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/inference/zero_shot.py +0 -0
  32. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/inference/zero_shot_templates.py +0 -0
  33. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/loss/__init__.py +0 -0
  34. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/__init__.py +0 -0
  35. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/manifest.py +0 -0
  36. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/model_registry/model_registry.py +0 -0
  37. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/__init__.py +0 -0
  38. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/base.py +0 -0
  39. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/text/__init__.py +0 -0
  40. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/net/text/transformer.py +0 -0
  41. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/py.typed +0 -0
  42. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/scripts/__init__.py +0 -0
  43. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/scripts/zero_shot.py +0 -0
  44. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/__init__.py +0 -0
  45. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/base.py +0 -0
  46. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  47. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/hf.py +0 -0
  48. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/registry.py +0 -0
  49. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
  50. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/__init__.py +0 -0
  51. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/__main__.py +0 -0
  52. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip/tools/download_tokenizer.py +0 -0
  53. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/dependency_links.txt +0 -0
  54. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/birder_clip.egg-info/top_level.txt +0 -0
  55. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/pyproject.toml +0 -0
  56. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/requirements/_requirements-dev.txt +0 -0
  57. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/setup.cfg +0 -0
  58. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_loss.py +0 -0
  59. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_model_registry.py +0 -0
  60. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_net.py +0 -0
  61. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev4}/tests/test_tokenizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder_clip
3
- Version: 0.0.2.dev3
3
+ Version: 0.0.2.dev4
4
4
  Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
24
24
  Requires-Python: >=3.11
25
25
  Description-Content-Type: text/markdown
26
26
  License-File: LICENSE
27
- Requires-Dist: birder>=0.5.2
27
+ Requires-Dist: birder>=0.5.4
28
28
  Requires-Dist: ftfy>=6.3.1
29
29
  Requires-Dist: regex>=2025.7.29
30
30
  Requires-Dist: tqdm>=4.67.0
@@ -43,9 +43,17 @@ def get_image_text_network_name(
43
43
  parts = [network]
44
44
  if image_encoder is not None:
45
45
  parts.append(image_encoder)
46
- if text_encoder is not None:
46
+ if text_encoder is not None and text_encoder != "text_transformer":
47
47
  parts.append(text_encoder)
48
- if tokenizer is not None:
48
+
49
+ if registry.exists(network) is True:
50
+ default_tokenizer = registry.get_default_tokenizer(network)
51
+ else:
52
+ default_tokenizer = "simple_tokenizer"
53
+ if default_tokenizer is None:
54
+ default_tokenizer = "simple_tokenizer"
55
+
56
+ if tokenizer is not None and tokenizer != default_tokenizer:
49
57
  parts.append(tokenizer)
50
58
  if embed_dim is not None:
51
59
  parts.append(f"d{embed_dim}")
@@ -17,32 +17,6 @@ from birder_clip.model_registry import Task
17
17
  from birder_clip.model_registry import registry
18
18
 
19
19
 
20
- def add_compile_args(parser: argparse.ArgumentParser) -> None:
21
- group = parser.add_argument_group("Compilation parameters")
22
- group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
23
- group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
24
- group.add_argument(
25
- "--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
26
- )
27
- group.add_argument(
28
- "--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
29
- )
30
- group.add_argument(
31
- "--compile-recompile-limit",
32
- type=int,
33
- default=torch.compiler.config.recompile_limit,
34
- metavar="N",
35
- help="maximum recompilations per compiled function before eager fallback",
36
- )
37
- group.add_argument(
38
- "--compile-accumulated-recompile-limit",
39
- type=int,
40
- default=torch.compiler.config.accumulated_recompile_limit,
41
- metavar="N",
42
- help="maximum total recompilations across compiled functions",
43
- )
44
-
45
-
46
20
  def add_model_args(parser: argparse.ArgumentParser) -> None:
47
21
  parser.add_argument("-n", "--network", type=str, help="the image-text network to train")
48
22
  parser.add_argument("-t", "--tag", type=str, help="add model tag")
@@ -75,9 +49,7 @@ def add_loss_args(parser: argparse.ArgumentParser) -> None:
75
49
  def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: int = 32) -> None:
76
50
  group = parser.add_argument_group("Optimization parameters")
77
51
  group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
78
- group.add_argument(
79
- "--opt", type=str, choices=list(get_args(OptimizerType)), default="adamw", help="optimizer to use"
80
- )
52
+ group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
81
53
  group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
82
54
  group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
83
55
  group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
@@ -92,6 +64,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
92
64
  metavar="N",
93
65
  help="number of iterations to accumulate gradients per optimizer step",
94
66
  )
67
+ # NOTE: Add flag for negative sample caching in grad accum mode
95
68
 
96
69
 
97
70
  def add_lr_wd_args(parser: argparse.ArgumentParser) -> None:
@@ -129,14 +102,14 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
129
102
  "--lr-scheduler-update",
130
103
  type=str,
131
104
  choices=["epoch", "step"],
132
- default="step",
105
+ default="epoch",
133
106
  help="when to apply learning rate scheduler update: epoch (once per epoch), step (each optimizer step)",
134
107
  )
135
108
  group.add_argument(
136
109
  "--lr-scheduler",
137
110
  type=str,
138
111
  choices=list(get_args(SchedulerType)),
139
- default="cosine",
112
+ default="constant",
140
113
  help="learning rate scheduler",
141
114
  )
142
115
  group.add_argument(
@@ -175,15 +148,6 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
175
148
  )
176
149
 
177
150
 
178
- def add_input_args(parser: argparse.ArgumentParser) -> None:
179
- group = parser.add_argument_group("Input parameters")
180
- group.add_argument(
181
- "--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
182
- )
183
- group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
184
- group.add_argument("--context-length", type=int, metavar="N", help="text context length")
185
-
186
-
187
151
  def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
188
152
  group = parser.add_argument_group("Training schedule parameters")
189
153
  group.add_argument("--epochs", type=int, default=default_epochs, metavar="N", help="number of training epochs")
@@ -204,6 +168,37 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
204
168
  )
205
169
 
206
170
 
def add_ema_args(
    parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
) -> None:
    """
    Register the exponential moving average (EMA) options on a parser

    The options are placed in their own "Exponential moving average parameters"
    argument group. The parser is modified in place.

    Parameters
    ----------
    parser
        Parser to extend.
    default_ema_steps
        Default value for ``--model-ema-steps``.
    default_ema_decay
        Default value for ``--model-ema-decay``.
    """

    group = parser.add_argument_group("Exponential moving average parameters")
    group.add_argument(
        "--model-ema",
        default=False,
        action="store_true",
        help="enable tracking exponential moving average of model parameters",
    )
    group.add_argument(
        "--model-ema-steps",
        type=int,
        default=default_ema_steps,
        metavar="N",
        help="number of optimizer steps between EMA updates",
    )
    group.add_argument(
        "--model-ema-decay",
        type=float,
        default=default_ema_decay,
        help="decay factor for exponential moving average of model parameters",
    )

    # No default on purpose: None lets the caller tell "not given" apart from an explicit 0
    group.add_argument(
        "--model-ema-warmup",
        type=int,
        metavar="N",
        help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
    )
200
+
201
+
207
202
  def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
208
203
  group = parser.add_argument_group("Batch normalization parameters")
209
204
  group.add_argument(
@@ -215,6 +210,15 @@ def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
215
210
  group.add_argument("--sync-bn", default=False, action="store_true", help="use synchronized BatchNorm")
216
211
 
217
212
 
def add_input_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the model input options (image channels, image size, text context length)

    The options are placed in an "Input parameters" argument group. The parser
    is modified in place.

    Parameters
    ----------
    parser
        Parser to extend.
    """

    group = parser.add_argument_group("Input parameters")
    group.add_argument(
        "--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
    )

    # nargs="+" accepts one or more integers; the (H, W) metavar only shapes the help text
    group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
    group.add_argument("--context-length", type=int, metavar="N", help="text context length")
220
+
221
+
218
222
  def add_data_aug_args(
219
223
  parser: argparse.ArgumentParser,
220
224
  default_level: int = 4,
@@ -260,7 +264,7 @@ def add_data_aug_args(
260
264
  "--rgb-mode",
261
265
  type=str,
262
266
  choices=list(typing.get_args(RGBMode)),
263
- default="clip",
267
+ default="birder",
264
268
  help="RGB mean and std to use for normalization",
265
269
  )
266
270
  group.add_argument(
@@ -279,67 +283,6 @@ def add_data_aug_args(
279
283
  )
280
284
 
281
285
 
282
- def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
283
- group = parser.add_argument_group("Checkpoint parameters")
284
- group.add_argument(
285
- "--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
286
- )
287
- group.add_argument(
288
- "--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
289
- )
290
- group.add_argument(
291
- "--pretrained",
292
- default=False,
293
- action="store_true",
294
- help="start with pretrained version of specified network (will download if not found locally)",
295
- )
296
- group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
297
- group.add_argument(
298
- "--non-strict-weights",
299
- default=False,
300
- action="store_true",
301
- help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
302
- )
303
- group.add_argument(
304
- "--load-states",
305
- default=False,
306
- action="store_true",
307
- help="load optimizer, scheduler and scaler states when resuming",
308
- )
309
- group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
310
-
311
-
312
- def add_ema_args(
313
- parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
314
- ) -> None:
315
- group = parser.add_argument_group("Exponential moving average parameters")
316
- group.add_argument(
317
- "--model-ema",
318
- default=False,
319
- action="store_true",
320
- help="enable tracking exponential moving average of model parameters",
321
- )
322
- group.add_argument(
323
- "--model-ema-steps",
324
- type=int,
325
- default=default_ema_steps,
326
- metavar="N",
327
- help="number of optimizer steps between EMA updates",
328
- )
329
- group.add_argument(
330
- "--model-ema-decay",
331
- type=float,
332
- default=default_ema_decay,
333
- help="decay factor for exponential moving average of model parameters",
334
- )
335
- group.add_argument(
336
- "--model-ema-warmup",
337
- type=int,
338
- metavar="N",
339
- help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
340
- )
341
-
342
-
343
286
  def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
344
287
  group = parser.add_argument_group("Dataloader parameters")
345
288
  group.add_argument(
@@ -405,6 +348,62 @@ def add_precision_args(parser: argparse.ArgumentParser) -> None:
405
348
  )
406
349
 
407
350
 
351
+ def add_compile_args(parser: argparse.ArgumentParser) -> None:
352
+ group = parser.add_argument_group("Compilation parameters")
353
+ group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
354
+ group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
355
+ group.add_argument(
356
+ "--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
357
+ )
358
+ group.add_argument(
359
+ "--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
360
+ )
361
+ group.add_argument(
362
+ "--compile-recompile-limit",
363
+ type=int,
364
+ default=torch.compiler.config.recompile_limit,
365
+ metavar="N",
366
+ help="maximum recompilations per compiled function before eager fallback",
367
+ )
368
+ group.add_argument(
369
+ "--compile-accumulated-recompile-limit",
370
+ type=int,
371
+ default=torch.compiler.config.accumulated_recompile_limit,
372
+ metavar="N",
373
+ help="maximum total recompilations across compiled functions",
374
+ )
375
+
376
+
def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
    """
    Register the checkpoint saving and resuming options on a parser

    The options are placed in a "Checkpoint parameters" argument group. The
    parser is modified in place.

    Parameters
    ----------
    parser
        Parser to extend.
    default_save_frequency
        Default value for ``--save-frequency``.
    """

    group = parser.add_argument_group("Checkpoint parameters")

    # Saving
    group.add_argument(
        "--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
    )
    group.add_argument(
        "--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
    )

    # Initial weights / resuming
    group.add_argument(
        "--pretrained",
        default=False,
        action="store_true",
        help="start with pretrained version of specified network (will download if not found locally)",
    )
    group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
    group.add_argument(
        "--non-strict-weights",
        default=False,
        action="store_true",
        help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
    )
    group.add_argument(
        "--load-states",
        default=False,
        action="store_true",
        help="load optimizer, scheduler and scaler states when resuming",
    )
    group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
+
406
+
408
407
  def add_distributed_args(parser: argparse.ArgumentParser) -> None:
409
408
  group = parser.add_argument_group("Distributed training parameters")
410
409
  group.add_argument("--world-size", type=int, default=1, metavar="N", help="number of distributed processes")
@@ -558,5 +557,7 @@ def common_args_validation(args: argparse.Namespace) -> None:
558
557
  raise cli.ValidationError("--embed-dim must be positive")
559
558
  if args.context_length is not None and args.context_length <= 0:
560
559
  raise cli.ValidationError("--context-length must be positive")
560
+ if args.grad_accum_steps < 1:
561
+ raise cli.ValidationError("--grad-accum-steps must be >= 1")
561
562
  if args.model_ema_steps < 1:
562
563
  raise cli.ValidationError("--model-ema-steps must be >= 1")
@@ -0,0 +1,61 @@
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Any
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from birder.common import training_utils as birder_training_utils
10
+
11
+ from birder_clip.common import fs_ops
12
+ from birder_clip.conf import settings
13
+
14
+
def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
    """
    Attach a file handler to the "birder" and "birder_clip" loggers

    The handler writes the bare log message (no timestamp or level prefix) and
    uses the level configured in ``settings.LOG_LEVEL``.

    Parameters
    ----------
    log_file_path
        Path of the log file to write to.

    Returns
    -------
    The created handler, so the caller can remove or close it later.
    """

    file_handler = logging.FileHandler(log_file_path)
    formatter = logging.Formatter(
        fmt="{message}",
        style="{",
    )
    file_handler.setFormatter(formatter)
    file_handler.setLevel(settings.LOG_LEVEL)

    # The same handler instance is shared by both package loggers
    logging.getLogger("birder").addHandler(file_handler)
    logging.getLogger("birder_clip").addHandler(file_handler)

    return file_handler
28
+
29
+
def save_training_checkpoint(
    args: argparse.Namespace,
    network_name: str,
    epoch: int,
    net: torch.nn.Module,
    signature: Any,
    rgb_stats: Any,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    scaler: Optional[torch.amp.grad_scaler.GradScaler],
    model_base: Optional[torch.nn.Module],
    *,
    external_config: Optional[dict[str, Any]] = None,
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Write a training checkpoint from the global primary process

    Only the process for which ``is_global_primary(args)`` holds calls
    ``fs_ops.checkpoint_model``; all arguments other than ``args`` are
    forwarded to it unchanged. When a distributed process group is available
    and initialized, every process then waits on a barrier.

    Parameters
    ----------
    args
        Parsed training arguments, used only for the global primary check.
    network_name, epoch, net, signature, rgb_stats, optimizer, scheduler, scaler, model_base
        Forwarded positionally to ``fs_ops.checkpoint_model``.
    external_config
        Forwarded to ``fs_ops.checkpoint_model`` as a keyword argument.
    **extra_states
        Additional named states forwarded to ``fs_ops.checkpoint_model``.
    """

    if birder_training_utils.is_global_primary(args) is True:
        fs_ops.checkpoint_model(
            network_name,
            epoch,
            net,
            signature,
            rgb_stats,
            optimizer,
            scheduler,
            scaler,
            model_base,
            external_config=external_config,
            **extra_states,
        )

    # Deliberately outside the primary-only branch: every rank must reach the barrier
    if birder_training_utils.is_dist_available_and_initialized() is True:
        dist.barrier()
@@ -0,0 +1,106 @@
1
+ import logging
2
+ from collections.abc import Callable
3
+ from functools import partial
4
+ from typing import Any
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import webdataset as wds
9
+ from birder.conf import settings
10
+ from birder.data.datasets import webdataset as birder_wds
11
+
12
+ from birder_clip.tokenizers import Tokenizer
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
def decode_caption(caption: Any, caption_json_key: str = "caption") -> str:
    """
    Normalize a raw WebDataset caption field into a string

    A dict (decoded JSON sample) is reduced to the value stored under
    ``caption_json_key``, and bytes are decoded as UTF-8. The result must be a
    string.

    Parameters
    ----------
    caption
        Raw caption value: a string, bytes, or a dict holding the caption.
    caption_json_key
        Key to read the caption from when ``caption`` is a dict.

    Returns
    -------
    The caption text.

    Raises
    ------
    ValueError
        If ``caption`` is a dict without ``caption_json_key``.
    TypeError
        If the resolved caption is not a string.
    """

    if isinstance(caption, dict):
        if caption_json_key not in caption:
            raise ValueError(f"WebDataset JSON sample missing '{caption_json_key}' key")

        caption = caption[caption_json_key]

    # Checked after the dict lookup, so a bytes value stored inside the JSON is decoded too
    if isinstance(caption, bytes):
        caption = caption.decode("utf-8")

    if not isinstance(caption, str):
        raise TypeError(f"WebDataset caption must be a string, got {type(caption).__name__}")

    return caption
31
+
32
+
def tokenize_caption(caption: str, tokenizer: Tokenizer) -> torch.Tensor:
    """
    Tokenize a single caption

    The tokenizer is called with a one-element list and the first element of
    its result is returned.
    """

    return tokenizer([caption])[0]
35
+
36
+
def make_wds_dataset(
    wds_path: str | list[str],
    dataset_size: int,
    shuffle: bool,
    samples_names: bool,
    transform: Callable[..., torch.Tensor],
    image_decoder: birder_wds.WDSImageDecoderSpec = "tv",
    channels: int = settings.DEFAULT_NUM_CHANNELS,
    tokenizer: Optional[Tokenizer] = None,
    *,
    caption_key: str = "txt;json",  # WebDataset picks the first present key, so txt takes precedence over json
    caption_json_key: str = "caption",
    cache_dir: Optional[str] = None,
    shuffle_buffer_size: Optional[int] = None,
    shuffle_initial_size: Optional[int] = None,
) -> torch.utils.data.IterableDataset:
    """
    Build an image-caption WebDataset pipeline

    Samples are yielded as ``(image, caption)``, or ``(name, image, caption)``
    when ``samples_names`` is set. The caption is a string, or the tokenizer
    output for that caption when a tokenizer is given.

    Parameters
    ----------
    wds_path
        Shard path(s) passed to ``wds.WebDataset``.
    dataset_size
        Length reported by the dataset (``with_length``).
    shuffle
        Enable shard shuffling and sample-level buffer shuffling.
    samples_names
        Also return a sample name as the first tuple element.
    transform
        Transform applied to the decoded image.
    image_decoder
        Either a decoder name resolved through ``birder_wds.get_wds_image_decoder``
        or a ready decoder that is used as is.
    channels
        Number of image channels, used only when the decoder is given by name.
    tokenizer
        Optional tokenizer applied to each caption.
    caption_key
        WebDataset key spec for the caption field.
    caption_json_key
        Key holding the caption when the caption field is a JSON object.
    cache_dir
        Passed to ``wds.WebDataset``.
    shuffle_buffer_size
        Shuffle buffer size, defaults to ``birder_wds.WDS_SHUFFLE_SIZE``. Ignored when ``shuffle`` is False.
    shuffle_initial_size
        Initial shuffle buffer fill, defaults to ``birder_wds.WDS_INITIAL_SIZE``. Ignored when ``shuffle`` is False.
    """

    # Stage 1: shard source, with shard-level shuffling when requested
    if shuffle is True:
        shardshuffle = 500
    else:
        shardshuffle = False

    dataset = wds.WebDataset(
        wds_path, shardshuffle=shardshuffle, nodesplitter=wds.split_by_node, cache_dir=cache_dir, empty_check=False
    )

    # Stage 2: sample-level shuffling
    if shuffle is True:
        if shuffle_buffer_size is None:
            shuffle_buffer_size = birder_wds.WDS_SHUFFLE_SIZE
        if shuffle_initial_size is None:
            shuffle_initial_size = birder_wds.WDS_INITIAL_SIZE

        logger.debug(f"Using buffer size of {shuffle_buffer_size} for shuffle with {shuffle_initial_size} initial size")
        dataset = dataset.shuffle(shuffle_buffer_size, initial=shuffle_initial_size)

    # Stage 3: decode and select the fields of interest
    return_keys = ["jpeg;jpg;png;webp"]
    return_keys = return_keys + [caption_key]
    if samples_names is True:
        return_keys = ["__url__", "__key__"] + return_keys

    if isinstance(image_decoder, str):
        decoder = birder_wds.get_wds_image_decoder(image_decoder, channels)
    else:
        decoder = image_decoder

    dataset = dataset.with_length(dataset_size, silent=True).decode(decoder).to_tuple(*return_keys)

    # Stage 4: per-field transforms
    caption_decoder = partial(decode_caption, caption_json_key=caption_json_key)
    if samples_names is True:
        # NOTE(review): decode_sample_name presumably folds (__url__, __key__) into a single name element,
        # since the map_tuple below handles three elements - confirm against birder
        dataset = dataset.map(birder_wds.decode_sample_name)
        dataset = dataset.map_tuple(birder_wds.identity, transform, caption_decoder)
    else:
        dataset = dataset.map_tuple(transform, caption_decoder)

    # Stage 5: optional tokenization, the caption is always the last tuple element
    if tokenizer is not None:
        text_transform = partial(tokenize_caption, tokenizer=tokenizer)
        if samples_names is True:
            dataset = dataset.map_tuple(birder_wds.identity, birder_wds.identity, text_transform)
        else:
            dataset = dataset.map_tuple(birder_wds.identity, text_transform)

    return dataset
97
+
98
+
def wds_size(wds_path: str, device: torch.device, select_suffix: str | tuple[str, ...] = ("txt", "json")) -> int:
    """
    Delegate to ``birder_wds.wds_size`` with caption suffixes as the default selection

    All arguments are forwarded as is; this wrapper only supplies the
    ``("txt", "json")`` default for ``select_suffix``.
    """

    return birder_wds.wds_size(wds_path, device, select_suffix=select_suffix)
101
+
102
+
def prepare_wds_args(
    data_path: str, size: Optional[int], device: torch.device, select_suffix: str | tuple[str, ...] = ("txt", "json")
) -> tuple[str, int]:
    """
    Delegate to ``birder_wds.prepare_wds_args`` with caption suffixes as the default selection

    All arguments are forwarded as is; this wrapper only supplies the
    ``("txt", "json")`` default for ``select_suffix``.
    """

    return birder_wds.prepare_wds_args(data_path, size, device, select_suffix=select_suffix)
@@ -1,6 +1,9 @@
1
1
  """
2
2
  CLIP loss, adapted from
3
3
  https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py
4
+
5
+ Paper "Learning Transferable Visual Models From Natural Language Supervision",
6
+ https://arxiv.org/abs/2103.00020
4
7
  """
5
8
 
6
9
  # Reference license: MIT
@@ -22,6 +25,14 @@ def gather_features(features: torch.Tensor) -> torch.Tensor:
22
25
 
23
26
 
24
27
  class CLIPLoss(torch.nn.Module):
28
+ """
29
+ CLIP symmetric contrastive loss
30
+
31
+ Implements the bidirectional InfoNCE objective from CLIP: image features
32
+ classify their matching text features, and text features classify their
33
+ matching image features, using the batch as negatives.
34
+ """
35
+
25
36
  def forward(
26
37
  self,
27
38
  image_features: torch.Tensor,
@@ -246,6 +246,10 @@ registry.register_model_config(
246
246
  },
247
247
  )
248
248
 
249
+
250
+ # Weights
251
+ ####################
252
+
249
253
  registry.register_weights(
250
254
  "openai_clip_vit_l14",
251
255
  {
@@ -261,3 +265,18 @@ registry.register_weights(
261
265
  "net": {"network": "openai_clip_vit_l14"},
262
266
  },
263
267
  )
268
+ registry.register_weights(
269
+ "pe_core_b16",
270
+ {
271
+ "description": "RoPEi ViT b16 image encoder pretrained by Meta FAIR using CLIP",
272
+ "resolution": (224, 224),
273
+ "context_length": 32,
274
+ "formats": {
275
+ "pt": {
276
+ "file_size": 1707.8,
277
+ "sha256": "11453d4a36fad6dbd802ec9fa35375ce0ae8b7b156a5ca45c0e87587df05290f",
278
+ }
279
+ },
280
+ "net": {"network": "pe_core_b16"},
281
+ },
282
+ )
@@ -1,4 +1,5 @@
1
1
  import copy
2
+ import logging
2
3
  from typing import Any
3
4
  from typing import Optional
4
5
 
@@ -7,6 +8,8 @@ from torch import nn
7
8
 
8
9
  from birder_clip.model_registry import Task
9
10
 
11
+ logger = logging.getLogger(__name__)
12
+
10
13
 
11
14
  class TextBaseNet(nn.Module):
12
15
  default_context_length = 77
@@ -38,4 +41,5 @@ class TextBaseNet(nn.Module):
38
41
  if new_context_length == self.context_length:
39
42
  return
40
43
 
44
+ logger.info(f"Adjusting model context length from {self.context_length} to {new_context_length}")
41
45
  self.context_length = new_context_length