birder-clip 0.0.2.dev4__tar.gz → 0.0.2.dev5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/PKG-INFO +3 -3
  2. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/fs_ops.py +100 -3
  3. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/lib.py +1 -1
  4. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/training_cli.py +68 -2
  5. birder_clip-0.0.2.dev5/birder_clip/common/training_utils.py +99 -0
  6. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/webdataset.py +2 -2
  7. birder_clip-0.0.2.dev5/birder_clip/inference/zero_shot.py +128 -0
  8. birder_clip-0.0.2.dev5/birder_clip/net/clip.py +610 -0
  9. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/text/base.py +1 -0
  10. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/text/transformer.py +2 -0
  11. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/scripts/train.py +63 -12
  12. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/scripts/zero_shot.py +249 -180
  13. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/__init__.py +2 -0
  14. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/base.py +3 -0
  15. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/hf.py +13 -0
  16. birder_clip-0.0.2.dev5/birder_clip/tokenizers/openvision.py +64 -0
  17. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/simple_tokenizer.py +1 -0
  18. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/__main__.py +13 -0
  19. birder_clip-0.0.2.dev5/birder_clip/tools/convert_model.py +268 -0
  20. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/download_tokenizer.py +5 -4
  21. birder_clip-0.0.2.dev5/birder_clip/tools/list_models.py +102 -0
  22. birder_clip-0.0.2.dev5/birder_clip/tools/model_info.py +145 -0
  23. birder_clip-0.0.2.dev5/birder_clip/tools/stats.py +210 -0
  24. birder_clip-0.0.2.dev5/birder_clip/version.py +1 -0
  25. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/PKG-INFO +3 -3
  26. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/SOURCES.txt +5 -0
  27. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/requires.txt +2 -2
  28. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/requirements/_requirements-dev.txt +1 -1
  29. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/requirements/requirements.txt +1 -1
  30. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_net.py +1 -1
  31. birder_clip-0.0.2.dev4/birder_clip/common/training_utils.py +0 -61
  32. birder_clip-0.0.2.dev4/birder_clip/inference/zero_shot.py +0 -54
  33. birder_clip-0.0.2.dev4/birder_clip/net/clip.py +0 -282
  34. birder_clip-0.0.2.dev4/birder_clip/version.py +0 -1
  35. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/LICENSE +0 -0
  36. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/README.md +0 -0
  37. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/__init__.py +0 -0
  38. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/__init__.py +0 -0
  39. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/conf/__init__.py +0 -0
  40. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/conf/settings.py +0 -0
  41. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/__init__.py +0 -0
  42. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/__init__.py +0 -0
  43. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/csv.py +0 -0
  44. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/fake.py +0 -0
  45. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/inference/__init__.py +0 -0
  46. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/inference/zero_shot_templates.py +0 -0
  47. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/loss/__init__.py +0 -0
  48. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/loss/contrastive.py +0 -0
  49. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/__init__.py +0 -0
  50. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/manifest.py +0 -0
  51. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/model_registry.py +0 -0
  52. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/__init__.py +0 -0
  53. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/base.py +0 -0
  54. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/text/__init__.py +0 -0
  55. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/py.typed +0 -0
  56. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/scripts/__init__.py +0 -0
  57. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  58. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/registry.py +0 -0
  59. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/__init__.py +0 -0
  60. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/show_iterator.py +0 -0
  61. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/dependency_links.txt +0 -0
  62. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/top_level.txt +0 -0
  63. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/pyproject.toml +0 -0
  64. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/setup.cfg +0 -0
  65. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_common.py +0 -0
  66. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_datasets.py +0 -0
  67. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_loss.py +0 -0
  68. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_model_registry.py +0 -0
  69. {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_tokenizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder_clip
3
- Version: 0.0.2.dev4
3
+ Version: 0.0.2.dev5
4
4
  Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
24
24
  Requires-Python: >=3.11
25
25
  Description-Content-Type: text/markdown
26
26
  License-File: LICENSE
27
- Requires-Dist: birder>=0.5.4
27
+ Requires-Dist: birder>=0.5.6
28
28
  Requires-Dist: ftfy>=6.3.1
29
29
  Requires-Dist: regex>=2025.7.29
30
30
  Requires-Dist: tqdm>=4.67.0
@@ -37,7 +37,7 @@ Provides-Extra: dev
37
37
  Requires-Dist: bandit~=1.9.4; extra == "dev"
38
38
  Requires-Dist: black~=26.5.0; extra == "dev"
39
39
  Requires-Dist: build~=1.5.0; extra == "dev"
40
- Requires-Dist: bumpver~=2025.1131; extra == "dev"
40
+ Requires-Dist: bumpver~=2026.1132; extra == "dev"
41
41
  Requires-Dist: coverage~=7.14.1; extra == "dev"
42
42
  Requires-Dist: debugpy; extra == "dev"
43
43
  Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
@@ -67,7 +67,7 @@ def model_path(
67
67
  network_name: str,
68
68
  *,
69
69
  epoch: Optional[int | str] = None,
70
- file_format: FileFormatType = "pt",
70
+ st: bool = False,
71
71
  states: bool = False,
72
72
  ) -> Path:
73
73
  if epoch is not None:
@@ -77,8 +77,10 @@ def model_path(
77
77
 
78
78
  if states is True:
79
79
  file_name = f"{file_name}_states.pt"
80
+ elif st is True:
81
+ file_name = f"{file_name}.safetensors"
80
82
  else:
81
- file_name = f"{file_name}.{file_format}"
83
+ file_name = f"{file_name}.pt"
82
84
 
83
85
  return settings.MODELS_DIR.joinpath(file_name)
84
86
 
@@ -109,6 +111,30 @@ def _checkpoint_states(
109
111
  torch.save(kwargs, states_path)
110
112
 
111
113
 
114
+ def _checkpoint_states_from_state_dicts(
115
+ states_path: Path,
116
+ optimizer_state: Optional[dict[str, Any]],
117
+ scheduler_state: Optional[dict[str, Any]],
118
+ scaler_state: Optional[dict[str, Any]],
119
+ model_base_state: Optional[dict[str, Any]],
120
+ **extra_states: Optional[dict[str, Any]],
121
+ ) -> None:
122
+ if optimizer_state is None or scheduler_state is None:
123
+ return
124
+
125
+ logger.info(f"Saving checkpoint states {states_path}...")
126
+ torch.save(
127
+ {
128
+ "optimizer_state": optimizer_state,
129
+ "scheduler_state": scheduler_state,
130
+ "scaler_state": scaler_state,
131
+ "model_base_state": model_base_state,
132
+ **extra_states,
133
+ },
134
+ states_path,
135
+ )
136
+
137
+
112
138
  class TrainingStates(NamedTuple):
113
139
  optimizer_state: Optional[dict[str, Any]]
114
140
  scheduler_state: Optional[dict[str, Any]]
@@ -182,6 +208,50 @@ def checkpoint_model(
182
208
  _checkpoint_states(states_path, optimizer, scheduler, scaler, model_base, **extra_states)
183
209
 
184
210
 
211
def checkpoint_model_from_state_dicts(
    network_name: str,
    epoch: int,
    model_state: dict[str, Any],
    task: Any,
    signature: SignatureType,
    rgb_stats: RGBType,
    optimizer_state: Optional[dict[str, Any]],
    scheduler_state: Optional[dict[str, Any]],
    scaler_state: Optional[dict[str, Any]],
    model_base_state: Optional[dict[str, Any]],
    *,
    external_config: Optional[dict[str, Any]] = None,
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Write a model checkpoint (plus optional training states) from pre-collected state dicts.

    The model file at model_path(network_name, epoch=epoch) stores the weights together with the
    birder_clip version, task, signature, RGB stats and, when provided, the external config.
    Training states go to the matching "_states.pt" file via _checkpoint_states_from_state_dicts,
    which skips writing when the optimizer or scheduler state is missing.
    """

    checkpoint_file = model_path(network_name, epoch=epoch)
    states_file = model_path(network_name, epoch=epoch, states=True)

    checkpoint: dict[str, Any] = {
        "state": model_state,
        "birder_clip_version": __version__,
        "task": task,
        "signature": signature,
        "rgb_stats": rgb_stats,
    }
    if external_config is not None:
        checkpoint["config"] = external_config

    logger.info(f"Saving model checkpoint {checkpoint_file}...")
    torch.save(checkpoint, checkpoint_file)

    _checkpoint_states_from_state_dicts(
        states_file,
        optimizer_state,
        scheduler_state,
        scaler_state,
        model_base_state,
        **extra_states,
    )
253
+
254
+
185
255
  def clean_checkpoints(network_name: str, keep_last: int) -> None:
186
256
  epoch = "*[0-9]"
187
257
  models_glob = str(model_path(network_name, epoch=epoch))
@@ -314,7 +384,7 @@ def load_model(
314
384
  embed_dim=embed_dim,
315
385
  tokenizer=tokenizer,
316
386
  )
317
- path = model_path(_network_name, epoch=epoch, file_format="safetensors" if st is True else "pt")
387
+ path = model_path(_network_name, epoch=epoch, st=st)
318
388
 
319
389
  logger.info(f"Loading model from {path} on device {device}...")
320
390
 
@@ -589,6 +659,33 @@ def load_pretrained_tokenizer(weights: str, *, download: bool = True, **kwargs:
589
659
  return get_tokenizer(tokenizer_name, **tokenizer_kwargs)
590
660
 
591
661
 
662
def save_st(
    net: torch.nn.Module,
    dst: str | Path,
    task: str,
    signature: SignatureType,
    rgb_stats: RGBType,
    *,
    external_config: Optional[dict[str, Any]] = None,
) -> None:
    """
    Save a model's weights to a .safetensors file with birder_clip metadata attached.

    safetensors metadata is a string-to-string mapping, so the structured fields (signature,
    RGB stats and the optional external config) are JSON-encoded before saving.

    Parameters
    ----------
    net
        Module whose weights are written (via safetensors.torch.save_model).
    dst
        Destination file path; accepts str or pathlib.Path (converted with str()).
    task
        Task identifier stored verbatim in the metadata.
    signature
        Model signature, stored JSON-encoded.
    rgb_stats
        RGB normalization stats, stored JSON-encoded.
    external_config
        Optional config dict, stored JSON-encoded under the "config" key when given.

    Raises
    ------
    AssertionError
        If the safetensors package is not available.
    """

    assert _HAS_SAFETENSORS, "'pip install safetensors' to use .safetensors"

    metadata = {
        "birder_clip_version": __version__,
        "task": task,
        "signature": json.dumps(signature),
        "rgb_stats": json.dumps(rgb_stats),
    }
    if external_config is not None:
        metadata["config"] = json.dumps(external_config)

    safetensors.torch.save_model(net, str(dst), metadata)
687
+
688
+
592
689
  def download_model_by_weights(
593
690
  weights: str, *, dst: Optional[str | Path] = None, file_format: FileFormatType = "pt", progress_bar: bool = True
594
691
  ) -> None:
@@ -86,7 +86,7 @@ def get_image_text_model_config(
86
86
 
87
87
  if config is not None:
88
88
  for key, value in config.items():
89
- if key in {"image", "text"} and isinstance(value, dict) is True:
89
+ if key in {"image", "text"} and isinstance(value, dict):
90
90
  model_config[key] = {**model_config.get(key, {}), **value}
91
91
  else:
92
92
  model_config[key] = value
@@ -49,7 +49,9 @@ def add_loss_args(parser: argparse.ArgumentParser) -> None:
49
49
  def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: int = 32) -> None:
50
50
  group = parser.add_argument_group("Optimization parameters")
51
51
  group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
52
- group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
52
+ group.add_argument(
53
+ "--opt", type=str, choices=list(get_args(OptimizerType)), default="adamw", help="optimizer to use"
54
+ )
53
55
  group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
54
56
  group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
55
57
  group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
@@ -318,7 +320,13 @@ def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
318
320
  action="store_true",
319
321
  help="keep dataloader worker processes alive between epochs",
320
322
  )
321
- group.add_argument("--drop-last", default=False, action="store_true", help="drop the last incomplete batch")
323
+ group.add_argument(
324
+ "--no-drop-last",
325
+ dest="drop_last",
326
+ default=True,
327
+ action="store_false",
328
+ help="do not drop the last incomplete batch",
329
+ )
322
330
 
323
331
 
324
332
  def add_precision_args(parser: argparse.ArgumentParser) -> None:
@@ -410,6 +418,44 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
410
418
  group.add_argument("--local-rank", type=int, metavar="N", help="local rank")
411
419
  group.add_argument("--dist-url", type=str, default="env://", help="URL used to initialize distributed training")
412
420
  group.add_argument("--dist-backend", type=str, default="nccl", help="distributed backend")
421
+ group.add_argument(
422
+ "--distributed-mode", type=str, choices=["ddp", "fsdp"], default="ddp", help="distributed training mode"
423
+ )
424
+ group.add_argument(
425
+ "--fsdp-sharding-strategy",
426
+ type=str,
427
+ choices=["shard-grad-op", "full-shard"],
428
+ default="shard-grad-op",
429
+ help="FSDP sharding strategy",
430
+ )
431
+ group.add_argument(
432
+ "--fsdp-param-dtype",
433
+ type=str,
434
+ choices=["float32", "float16", "bfloat16"],
435
+ help="FSDP mixed precision parameter dtype",
436
+ )
437
+ group.add_argument(
438
+ "--fsdp-reduce-dtype",
439
+ type=str,
440
+ choices=["float32", "float16", "bfloat16"],
441
+ help="FSDP mixed precision gradient reduction dtype",
442
+ )
443
+ group.add_argument(
444
+ "--fsdp-wrap-policy",
445
+ type=str,
446
+ choices=["block-group-regex", "min-num-params"],
447
+ default="block-group-regex",
448
+ help="FSDP module wrapping policy",
449
+ )
450
+ group.add_argument(
451
+ "--fsdp-wrap-min-num-params",
452
+ type=float,
453
+ metavar="M",
454
+ help="minimum module parameter count in millions for wrapping when using --fsdp-wrap-policy min-num-params",
455
+ )
456
+ group.add_argument(
457
+ "--fsdp-offload-policy", type=str, choices=["none", "cpu"], default="none", help="FSDP parameter offload policy"
458
+ )
413
459
  group.add_argument(
414
460
  "--find-unused-parameters",
415
461
  default=False,
@@ -561,3 +607,23 @@ def common_args_validation(args: argparse.Namespace) -> None:
561
607
  raise cli.ValidationError("--grad-accum-steps must be >= 1")
562
608
  if args.model_ema_steps < 1:
563
609
  raise cli.ValidationError("--model-ema-steps must be >= 1")
610
+
611
+ if args.distributed_mode == "fsdp":
612
+ if args.sync_bn is True:
613
+ raise cli.ValidationError("--sync-bn cannot be used with --distributed-mode fsdp")
614
+ if args.find_unused_parameters is True:
615
+ raise cli.ValidationError("--find-unused-parameters cannot be used with --distributed-mode fsdp")
616
+ if args.compile_opt is True:
617
+ raise cli.ValidationError("--compile-opt cannot be used with --distributed-mode fsdp")
618
+ if args.compile_fullgraph is True:
619
+ raise cli.ValidationError("--compile-fullgraph cannot be used with --distributed-mode fsdp")
620
+ if args.cpu is True:
621
+ raise cli.ValidationError("--cpu cannot be used with --distributed-mode fsdp")
622
+ if args.model_ema is True:
623
+ raise cli.ValidationError("--model-ema cannot be used with --distributed-mode fsdp")
624
+ if args.fsdp_wrap_policy == "min-num-params" and args.fsdp_wrap_min_num_params is None:
625
+ raise cli.ValidationError(
626
+ "--fsdp-wrap-min-num-params is required when --fsdp-wrap-policy is min-num-params"
627
+ )
628
+ if args.fsdp_wrap_min_num_params is not None and args.fsdp_wrap_min_num_params <= 0:
629
+ raise cli.ValidationError("--fsdp-wrap-min-num-params must be > 0")
@@ -0,0 +1,99 @@
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Any
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from birder.common import fsdp_utils
10
+ from birder.common import training_utils as birder_training_utils
11
+
12
+ from birder_clip.common import fs_ops
13
+ from birder_clip.conf import settings
14
+
15
+
16
def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
    """
    Attach a file handler to both the "birder" and "birder_clip" loggers.

    Records are written as the bare message text ("{message}" format) at settings.LOG_LEVEL.
    The file is opened with an explicit UTF-8 encoding so that non-ASCII log content
    (e.g. captions or file paths) does not depend on the platform's locale encoding.

    Returns the created handler (e.g. so the caller can remove/close it later).
    """

    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setFormatter(logging.Formatter(fmt="{message}", style="{"))
    file_handler.setLevel(settings.LOG_LEVEL)

    for logger_name in ("birder", "birder_clip"):
        logging.getLogger(logger_name).addHandler(file_handler)

    return file_handler
29
+
30
+
31
def save_training_checkpoint(
    args: argparse.Namespace,
    network_name: str,
    epoch: int,
    net: torch.nn.Module,
    signature: Any,
    rgb_stats: Any,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    scaler: Optional[torch.amp.grad_scaler.GradScaler],
    model_base: Optional[torch.nn.Module],
    *,
    fsdp_mode: bool = False,
    fsdp_model_state: Optional[dict[str, Any]] = None,
    external_config: Optional[dict[str, Any]] = None,
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Save a training checkpoint for either a regular or an FSDP-wrapped model.

    Non-FSDP: only the global primary process writes, using the live net/optimizer/scheduler
    objects through fs_ops.checkpoint_model.

    FSDP: every rank calls the fsdp_utils gather helpers (model state first, unless
    fsdp_model_state is supplied, then the optimizer state). Only the global primary
    writes the resulting state dicts through fs_ops.checkpoint_model_from_state_dicts.
    All ranks then hit a barrier when a process group is initialized.
    """

    if fsdp_mode is not True:
        if birder_training_utils.is_global_primary(args) is True:
            fs_ops.checkpoint_model(
                network_name,
                epoch,
                net,
                signature,
                rgb_stats,
                optimizer,
                scheduler,
                scaler,
                model_base,
                external_config=external_config,
                **extra_states,
            )

        return

    # NOTE: the gather helpers run on every rank in a fixed order (model, then optimizer);
    # presumably they are collective ops, so this ordering must stay identical across ranks
    model_state = (
        fsdp_model_state if fsdp_model_state is not None else fsdp_utils.gather_full_model_state_dict(net)
    )

    optimizer_state: Optional[dict[str, Any]] = None
    scheduler_state: Optional[dict[str, Any]] = None
    scaler_state: Optional[dict[str, Any]] = None
    model_base_state: Optional[dict[str, Any]] = None
    if optimizer is not None and scheduler is not None:
        optimizer_state = fsdp_utils.gather_full_optimizer_state_dict(net, optimizer)
        scheduler_state = scheduler.state_dict()
        scaler_state = None if scaler is None else scaler.state_dict()
        model_base_state = None if model_base is None else model_base.state_dict()

    if birder_training_utils.is_global_primary(args) is True:
        fs_ops.checkpoint_model_from_state_dicts(
            network_name,
            epoch,
            model_state,
            net.task,
            signature,
            rgb_stats,
            optimizer_state,
            scheduler_state,
            scaler_state,
            model_base_state,
            external_config=external_config,
            **extra_states,
        )

    # Keep the other ranks from racing ahead while the primary is still writing
    if birder_training_utils.is_dist_available_and_initialized() is True:
        dist.barrier()
@@ -24,10 +24,10 @@ def decode_caption(caption: Any, caption_json_key: str = "caption") -> str:
24
24
  if isinstance(caption, bytes):
25
25
  caption = caption.decode("utf-8")
26
26
 
27
- if isinstance(caption, str) is False:
27
+ if not isinstance(caption, str):
28
28
  raise TypeError(f"WebDataset caption must be a string, got {type(caption).__name__}")
29
29
 
30
- return caption # type: ignore[no-any-return]
30
+ return caption
31
31
 
32
32
 
33
33
  def tokenize_caption(caption: str, tokenizer: Tokenizer) -> torch.Tensor:
@@ -0,0 +1,128 @@
1
+ """
2
+ Zero-shot text embedding helpers
3
+
4
+ Zero-shot classification compares image features against one text feature per
5
+ candidate class. When multiple prompt templates are used, this module follows
6
+ the OpenCLIP/OpenAI CLIP convention: encode every class/template prompt,
7
+ normalize prompt embeddings, average them per class and normalize the averaged
8
+ class embedding again.
9
+ """
10
+
11
+ import sys
12
+ from collections.abc import Callable
13
+ from collections.abc import Iterator
14
+ from collections.abc import Sequence
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import numpy.typing as npt
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch.utils.data import DataLoader
22
+ from tqdm import tqdm
23
+
24
+ from birder_clip.net.base import BaseNet
25
+ from birder_clip.tokenizers.base import Tokenizer
26
+
27
+
28
def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
    """
    Fill every template with every class name, in class-major order.

    All templates for the first class come first, then all templates for the second class,
    and so on. build_class_text_embeddings depends on this order when it reshapes encoded
    prompts to (num_classes_in_batch, num_templates, -1).
    """

    prompts: list[str] = []
    for class_name in class_names:
        prompts.extend(template.format(class_name) for template in templates)

    return prompts
30
+
31
+
32
def build_class_text_embeddings(
    net: BaseNet,
    tokenizer: Tokenizer,
    class_names: Sequence[str],
    templates: Sequence[str],
    *,
    device: torch.device,
    context_length: Optional[int] = None,
    batch_size: Optional[int] = None,
    amp: bool = False,
    amp_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Compute one normalized text embedding per class for zero-shot classification.

    Every class is rendered with every template. The prompts are encoded by net.encode_text
    (normalize=True), averaged over templates per class, and the average is normalized again.

    Parameters
    ----------
    net
        Model providing encode_text.
    tokenizer
        Tokenizer called as tokenizer(prompts, context_length=context_length).
    class_names
        Class names, one output row per name (same order).
    templates
        Prompt templates with a single "{}" placeholder for the class name.
    device
        Device the tokens are moved to; also selects the autocast device type.
    context_length
        Forwarded to the tokenizer.
    batch_size
        Number of classes (not prompts) encoded per forward pass; each pass encodes
        batch_size * len(templates) prompts. Defaults to all classes in a single pass.
    amp, amp_dtype
        Enable autocast for the text encoder, optionally with a specific dtype.

    Returns
    -------
    Tensor with one row per class, in class_names order; the last dimension is the text
    embedding size.

    Raises
    ------
    ValueError
        If class_names or templates is empty, or batch_size is smaller than 1.
    """

    # Validate up front: otherwise these inputs fail deep inside range()/reshape()/concat()
    if len(class_names) == 0:
        raise ValueError("class_names must not be empty")
    if len(templates) == 0:
        raise ValueError("templates must not be empty")
    if batch_size is None:
        batch_size = len(class_names)
    elif batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    num_templates = len(templates)
    class_text_embeddings = []
    with torch.inference_mode():
        for start in range(0, len(class_names), batch_size):
            batch_class_names = class_names[start : start + batch_size]
            prompts = render_prompts(batch_class_names, templates)
            tokens = tokenizer(prompts, context_length=context_length).to(device)
            with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
                class_embeddings = net.encode_text(tokens, normalize=True)

            # Prompts are class-major, so each class's templates are contiguous rows
            class_embeddings = class_embeddings.reshape(len(batch_class_names), num_templates, -1).mean(dim=1)
            class_embeddings = F.normalize(class_embeddings, dim=-1)
            class_text_embeddings.append(class_embeddings)

    return torch.concat(class_text_embeddings, dim=0)
62
+
63
+
64
# (sample paths, per-sample outputs, labels) for one yielded chunk of infer_dataloader_iter
DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]]


def infer_dataloader_iter(
    device: torch.device,
    net: BaseNet | torch.ScriptModule,
    dataloader: DataLoader,
    text_embeddings: torch.Tensor,
    return_logits: bool = False,
    model_dtype: torch.dtype = torch.float32,
    amp: bool = False,
    amp_dtype: Optional[torch.dtype] = None,
    num_samples: Optional[int] = None,
    batch_callback: Optional[Callable[[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]], None]] = None,
    chunk_size: Optional[float] = None,
) -> Iterator[DataloaderInferenceResult]:
    """
    Run zero-shot classification over a dataloader and yield results in chunks.

    Each batch (file_paths, inputs, targets) is encoded with net.encode_image(normalize=True),
    scored against text_embeddings via net.forward_logits, and returned as raw logits
    (return_logits=True) or softmax probabilities.

    Parameters
    ----------
    device
        Device for the model and inputs; also selects the autocast device type.
    net
        Model; moved in place to device/model_dtype.
    dataloader
        Yields (file_paths, inputs, targets) batches.
    text_embeddings
        Per-class text embeddings passed to net.forward_logits as-is; presumably already
        on the right device/dtype (TODO confirm with callers).
    num_samples
        Optional progress-bar total.
    batch_callback
        Called after each batch with (file_paths, outputs, labels).
    chunk_size
        Yield once at least this many samples have accumulated; None means a single
        final yield.

    Yields
    ------
    (sample_paths, outputs, labels) with outputs/labels concatenated across the chunk.
    Nothing is yielded for an empty dataloader.
    """

    if chunk_size is None:
        chunk_size = float("inf")

    net.to(device, dtype=model_dtype)
    out_list: list[npt.NDArray[np.float32]] = []
    labels_list: list[npt.NDArray[np.int64]] = []
    sample_paths: list[str] = []
    sample_count = 0
    with tqdm(total=num_samples, initial=0, unit="images", unit_scale=True, leave=False) as progress:
        for file_paths, inputs, targets in dataloader:
            batch_size = inputs.size(0)

            # Inference. Autograd is disabled per batch (not around the yields below, where the
            # mode would leak into the caller while this generator is suspended). Without it,
            # outputs that require grad cannot be converted with .numpy().
            with torch.inference_mode():
                inputs = inputs.to(device, dtype=model_dtype)
                with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
                    image_embeddings = net.encode_image(inputs, normalize=True)
                    logits = net.forward_logits(image_embeddings, text_embeddings)

                if return_logits is True:
                    out = logits.cpu().float().numpy()
                else:
                    out = F.softmax(logits, dim=-1).cpu().float().numpy()

            out_list.append(out)

            # Set labels and sample list
            batch_labels = targets.cpu().numpy()
            labels_list.append(batch_labels)
            sample_paths.extend(file_paths)

            if batch_callback is not None:
                batch_callback(file_paths, out, batch_labels)

            # Update progress bar
            progress.update(n=batch_size)

            # Yield results when we reach chunk_size
            sample_count += batch_size
            if sample_count >= chunk_size:
                with tqdm.external_write_mode(file=sys.stderr):
                    yield (sample_paths, np.concatenate(out_list, axis=0), np.concatenate(labels_list))

                # Reset for next chunk
                out_list = []
                labels_list = []
                sample_paths = []
                sample_count = 0

    # Flush the final partial chunk (or everything, when chunk_size was None)
    if len(out_list) > 0:
        yield (sample_paths, np.concatenate(out_list, axis=0), np.concatenate(labels_list))