birder-clip 0.0.2.dev6__tar.gz → 0.0.2.dev8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/PKG-INFO +4 -4
  2. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/README.md +1 -1
  3. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/fs_ops.py +31 -11
  4. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/lib.py +25 -5
  5. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/training_cli.py +16 -2
  6. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/training_utils.py +14 -0
  7. birder_clip-0.0.2.dev8/birder_clip/inference/data_parallel.py +118 -0
  8. birder_clip-0.0.2.dev8/birder_clip/inference/image_embeddings.py +63 -0
  9. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot.py +40 -33
  10. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/loss/__init__.py +2 -0
  11. birder_clip-0.0.2.dev8/birder_clip/loss/caption.py +72 -0
  12. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/loss/coca.py +1 -3
  13. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/manifest.py +22 -5
  14. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/model_registry.py +5 -0
  15. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/__init__.py +2 -0
  16. birder_clip-0.0.2.dev8/birder_clip/net/base.py +201 -0
  17. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/clip.py +50 -1
  18. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/coca.py +133 -3
  19. birder_clip-0.0.2.dev8/birder_clip/net/openvision_v2.py +219 -0
  20. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/__init__.py +3 -1
  21. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/base.py +40 -19
  22. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/conditioned_decoder.py +4 -3
  23. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/encoder.py +22 -3
  24. birder_clip-0.0.2.dev8/birder_clip/net/text/visual_causal_decoder.py +195 -0
  25. birder_clip-0.0.2.dev8/birder_clip/scripts/__main__.py +25 -0
  26. birder_clip-0.0.2.dev8/birder_clip/scripts/embed_images.py +447 -0
  27. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/scripts/train.py +58 -9
  28. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/scripts/zero_shot.py +42 -7
  29. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/openvision.py +23 -0
  30. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/convert_model.py +186 -9
  31. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/model_info.py +22 -2
  32. birder_clip-0.0.2.dev8/birder_clip/version.py +1 -0
  33. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/PKG-INFO +4 -4
  34. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/SOURCES.txt +8 -1
  35. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/requires.txt +2 -2
  36. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/requirements/_requirements-dev.txt +1 -1
  37. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/requirements/requirements.txt +1 -1
  38. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_common.py +25 -0
  39. birder_clip-0.0.2.dev8/tests/test_inference.py +143 -0
  40. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_loss.py +60 -0
  41. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_model_registry.py +25 -0
  42. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_net.py +342 -0
  43. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_net_text.py +44 -0
  44. birder_clip-0.0.2.dev6/birder_clip/net/base.py +0 -77
  45. birder_clip-0.0.2.dev6/birder_clip/net/text/prefix_decoder.py +0 -1
  46. birder_clip-0.0.2.dev6/birder_clip/version.py +0 -1
  47. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/LICENSE +0 -0
  48. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/__init__.py +0 -0
  49. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/__init__.py +0 -0
  50. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/conf/__init__.py +0 -0
  51. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/conf/settings.py +0 -0
  52. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/__init__.py +0 -0
  53. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/__init__.py +0 -0
  54. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/csv.py +0 -0
  55. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/fake.py +0 -0
  56. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/webdataset.py +0 -0
  57. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/inference/__init__.py +0 -0
  58. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot_templates.py +0 -0
  59. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/loss/contrastive.py +0 -0
  60. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/__init__.py +0 -0
  61. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/hf.py +0 -0
  62. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/py.typed +0 -0
  63. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/scripts/__init__.py +0 -0
  64. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/__init__.py +0 -0
  65. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/base.py +0 -0
  66. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  67. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/hf.py +0 -0
  68. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/registry.py +0 -0
  69. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
  70. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/__init__.py +0 -0
  71. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/__main__.py +0 -0
  72. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/download_tokenizer.py +0 -0
  73. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/list_models.py +0 -0
  74. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/show_iterator.py +0 -0
  75. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/stats.py +0 -0
  76. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/dependency_links.txt +0 -0
  77. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/top_level.txt +0 -0
  78. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/pyproject.toml +0 -0
  79. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/setup.cfg +0 -0
  80. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_datasets.py +0 -0
  81. {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_tokenizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder_clip
3
- Version: 0.0.2.dev6
3
+ Version: 0.0.2.dev8
4
4
  Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
24
24
  Requires-Python: >=3.11
25
25
  Description-Content-Type: text/markdown
26
26
  License-File: LICENSE
27
- Requires-Dist: birder>=0.6.0
27
+ Requires-Dist: birder>=0.6.2
28
28
  Requires-Dist: ftfy>=6.3.1
29
29
  Requires-Dist: regex>=2025.7.29
30
30
  Requires-Dist: tqdm>=4.67.0
@@ -38,7 +38,7 @@ Requires-Dist: bandit~=1.9.4; extra == "dev"
38
38
  Requires-Dist: black~=26.5.0; extra == "dev"
39
39
  Requires-Dist: build~=1.5.0; extra == "dev"
40
40
  Requires-Dist: bumpver~=2026.1132; extra == "dev"
41
- Requires-Dist: coverage~=7.14.1; extra == "dev"
41
+ Requires-Dist: coverage~=7.14.3; extra == "dev"
42
42
  Requires-Dist: debugpy; extra == "dev"
43
43
  Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
44
44
  Requires-Dist: flake8~=7.3.0; extra == "dev"
@@ -85,7 +85,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
85
85
  1. Ensure your environment meets the minimum requirements:
86
86
  - Python 3.11 or newer
87
87
  - PyTorch 2.10 or newer (installed for your hardware/driver stack)
88
- - Birder 0.6.0 or newer
88
+ - Birder 0.6.2 or newer
89
89
 
90
90
  1. Install the latest Birder CLIP version:
91
91
 
@@ -23,7 +23,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
23
23
  1. Ensure your environment meets the minimum requirements:
24
24
  - Python 3.11 or newer
25
25
  - PyTorch 2.10 or newer (installed for your hardware/driver stack)
26
- - Birder 0.6.0 or newer
26
+ - Birder 0.6.2 or newer
27
27
 
28
28
  1. Install the latest Birder CLIP version:
29
29
 
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from typing import Any
7
7
  from typing import NamedTuple
8
8
  from typing import Optional
9
+ from typing import TypeAlias
9
10
 
10
11
  import torch
11
12
  from birder.common import cli
@@ -16,8 +17,10 @@ from birder.data.transforms.classification import inference_preset
16
17
  from birder_clip.common import lib
17
18
  from birder_clip.model_registry import Task
18
19
  from birder_clip.model_registry import registry
19
- from birder_clip.model_registry.manifest import EncoderMetadataType
20
20
  from birder_clip.model_registry.manifest import FileFormatType
21
+ from birder_clip.model_registry.manifest import ImageEncoderMetadataType
22
+ from birder_clip.model_registry.manifest import TextDecoderMetadataType
23
+ from birder_clip.model_registry.manifest import TextEncoderMetadataType
21
24
  from birder_clip.net.base import BaseNet
22
25
  from birder_clip.net.base import SignatureType
23
26
  from birder_clip.tokenizers import Tokenizer
@@ -36,6 +39,8 @@ except ImportError:
36
39
 
37
40
  logger = logging.getLogger(__name__)
38
41
 
42
+ ComponentMetadataType: TypeAlias = ImageEncoderMetadataType | TextEncoderMetadataType | TextDecoderMetadataType
43
+
39
44
 
40
45
  class ModelInfo(NamedTuple):
41
46
  signature: SignatureType
@@ -51,16 +56,18 @@ def write_config(network_name: str, net: BaseNet, signature: SignatureType, rgb_
51
56
  json.dump(model_config, handle, indent=2)
52
57
 
53
58
 
54
- def _split_encoder_metadata(encoder: Optional[EncoderMetadataType]) -> tuple[Optional[str], Optional[dict[str, Any]]]:
55
- if encoder is None:
59
+ def _split_component_metadata(
60
+ component: Optional[ComponentMetadataType],
61
+ ) -> tuple[Optional[str], Optional[dict[str, Any]]]:
62
+ if component is None:
56
63
  return (None, None)
57
- if isinstance(encoder, str):
58
- return (encoder, None)
64
+ if isinstance(component, str):
65
+ return (component, None)
59
66
 
60
- if "network" not in encoder:
61
- raise ValueError("Encoder metadata must include a 'network' field")
67
+ if "network" not in component:
68
+ raise ValueError("Component metadata must include a 'network' field")
62
69
 
63
- return (None, encoder) # type: ignore[return-value]
70
+ return (None, component) # type: ignore[return-value]
64
71
 
65
72
 
66
73
  def model_path(
@@ -286,10 +293,12 @@ def load_checkpoint(
286
293
  tag: Optional[str] = None,
287
294
  image_encoder: Optional[str] = None,
288
295
  text_encoder: Optional[str] = None,
296
+ text_decoder: Optional[str] = None,
289
297
  embed_dim: Optional[int] = None,
290
298
  tokenizer: Optional[str] = None,
291
299
  image_encoder_config: Optional[dict[str, Any]] = None,
292
300
  text_encoder_config: Optional[dict[str, Any]] = None,
301
+ text_decoder_config: Optional[dict[str, Any]] = None,
293
302
  epoch: Optional[int] = None,
294
303
  new_size: Optional[tuple[int, int]] = None,
295
304
  new_context_length: Optional[int] = None,
@@ -300,6 +309,7 @@ def load_checkpoint(
300
309
  tag=tag,
301
310
  image_encoder=image_encoder,
302
311
  text_encoder=text_encoder,
312
+ text_decoder=text_decoder,
303
313
  embed_dim=embed_dim,
304
314
  tokenizer=tokenizer,
305
315
  )
@@ -329,10 +339,12 @@ def load_checkpoint(
329
339
  checkpoint_config,
330
340
  image_encoder=image_encoder,
331
341
  text_encoder=text_encoder,
342
+ text_decoder=text_decoder,
332
343
  embed_dim=embed_dim,
333
344
  tokenizer=tokenizer,
334
345
  image_encoder_config=image_encoder_config,
335
346
  text_encoder_config=text_encoder_config,
347
+ text_decoder_config=text_decoder_config,
336
348
  input_channels=input_channels,
337
349
  image_size=size,
338
350
  context_length=context_length,
@@ -364,10 +376,12 @@ def load_model(
364
376
  tag: Optional[str] = None,
365
377
  image_encoder: Optional[str] = None,
366
378
  text_encoder: Optional[str] = None,
379
+ text_decoder: Optional[str] = None,
367
380
  embed_dim: Optional[int] = None,
368
381
  tokenizer: Optional[str] = None,
369
382
  image_encoder_config: Optional[dict[str, Any]] = None,
370
383
  text_encoder_config: Optional[dict[str, Any]] = None,
384
+ text_decoder_config: Optional[dict[str, Any]] = None,
371
385
  epoch: Optional[int] = None,
372
386
  new_size: Optional[tuple[int, int]] = None,
373
387
  new_context_length: Optional[int] = None,
@@ -381,6 +395,7 @@ def load_model(
381
395
  tag=tag,
382
396
  image_encoder=image_encoder,
383
397
  text_encoder=text_encoder,
398
+ text_decoder=text_decoder,
384
399
  embed_dim=embed_dim,
385
400
  tokenizer=tokenizer,
386
401
  )
@@ -420,10 +435,12 @@ def load_model(
420
435
  checkpoint_config,
421
436
  image_encoder=image_encoder,
422
437
  text_encoder=text_encoder,
438
+ text_decoder=text_decoder,
423
439
  embed_dim=embed_dim,
424
440
  tokenizer=tokenizer,
425
441
  image_encoder_config=image_encoder_config,
426
442
  text_encoder_config=text_encoder_config,
443
+ text_decoder_config=text_decoder_config,
427
444
  input_channels=input_channels,
428
445
  image_size=size,
429
446
  context_length=context_length,
@@ -525,15 +542,17 @@ def load_pretrained_model(
525
542
  if model_metadata["task"] != Task.IMAGE_TEXT:
526
543
  raise ValueError(f"Unknown model type: {model_metadata['task']}")
527
544
 
528
- image_encoder, image_config = _split_encoder_metadata(model_metadata["net"].get("image_encoder", None))
529
- text_encoder, text_config = _split_encoder_metadata(model_metadata["net"].get("text_encoder", None))
545
+ image_encoder, image_config = _split_component_metadata(model_metadata["net"].get("image_encoder", None))
546
+ text_encoder, text_config = _split_component_metadata(model_metadata["net"].get("text_encoder", None))
547
+ text_decoder, decoder_config = _split_component_metadata(model_metadata["net"].get("text_decoder", None))
530
548
 
531
549
  pretrained_config: dict[str, Any] = {}
532
550
  if image_config is not None:
533
551
  pretrained_config["image"] = image_config
534
-
535
552
  if text_config is not None:
536
553
  pretrained_config["text"] = text_config
554
+ if decoder_config is not None:
555
+ pretrained_config["decoder"] = decoder_config
537
556
 
538
557
  if custom_config is not None:
539
558
  pretrained_config.update(custom_config)
@@ -550,6 +569,7 @@ def load_pretrained_model(
550
569
  tag=model_metadata["net"].get("tag", None),
551
570
  image_encoder=image_encoder,
552
571
  text_encoder=text_encoder,
572
+ text_decoder=text_decoder,
553
573
  embed_dim=model_metadata["net"].get("embed_dim", None),
554
574
  tokenizer=model_metadata["net"].get("tokenizer", None),
555
575
  inference=inference,
@@ -3,14 +3,17 @@ from typing import Any
3
3
  from typing import Optional
4
4
 
5
5
  from birder.data.transforms.classification import RGBType
6
+ from birder.version import __version__ as birder_version
6
7
 
7
8
  from birder_clip.conf import settings
8
9
  from birder_clip.model_registry import registry
9
10
  from birder_clip.net.base import BaseNet
10
11
  from birder_clip.net.base import SignatureType
11
- from birder_clip.version import __version__
12
+ from birder_clip.version import __version__ as birder_clip_version
12
13
 
13
- MODEL_CONFIG_RESERVED_KEYS = frozenset({"image", "text", "tokenizer", "embed_dim", "embed-dim"})
14
+ MODEL_CONFIG_RESERVED_KEYS = frozenset(
15
+ {"image", "text", "decoder", "tokenizer", "embed_dim", "embed-dim", "keep_ratio"}
16
+ )
14
17
 
15
18
 
16
19
  def get_size_from_signature(signature: SignatureType) -> tuple[int, int]:
@@ -37,6 +40,7 @@ def get_image_text_network_name(
37
40
  tag: Optional[str] = None,
38
41
  image_encoder: Optional[str] = None,
39
42
  text_encoder: Optional[str] = None,
43
+ text_decoder: Optional[str] = None,
40
44
  embed_dim: Optional[int] = None,
41
45
  tokenizer: Optional[str] = None,
42
46
  ) -> str:
@@ -45,6 +49,8 @@ def get_image_text_network_name(
45
49
  parts.append(image_encoder)
46
50
  if text_encoder is not None and text_encoder != "transformer_encoder":
47
51
  parts.append(text_encoder)
52
+ if text_decoder is not None and text_decoder != "conditioned_decoder":
53
+ parts.append(text_decoder)
48
54
 
49
55
  if registry.exists(network) is True:
50
56
  default_tokenizer = registry.get_default_tokenizer(network)
@@ -71,10 +77,12 @@ def get_image_text_model_config(
71
77
  *,
72
78
  image_encoder: Optional[str] = None,
73
79
  text_encoder: Optional[str] = None,
80
+ text_decoder: Optional[str] = None,
74
81
  embed_dim: Optional[int] = None,
75
82
  tokenizer: Optional[str] = None,
76
83
  image_encoder_config: Optional[dict[str, Any]] = None,
77
84
  text_encoder_config: Optional[dict[str, Any]] = None,
85
+ text_decoder_config: Optional[dict[str, Any]] = None,
78
86
  input_channels: Optional[int] = None,
79
87
  image_size: Optional[tuple[int, int]] = None,
80
88
  context_length: Optional[int] = None,
@@ -86,7 +94,7 @@ def get_image_text_model_config(
86
94
 
87
95
  if config is not None:
88
96
  for key, value in config.items():
89
- if key in {"image", "text"} and isinstance(value, dict):
97
+ if key in {"image", "text", "decoder"} and isinstance(value, dict):
90
98
  model_config[key] = {**model_config.get(key, {}), **value}
91
99
  else:
92
100
  model_config[key] = value
@@ -111,7 +119,7 @@ def get_image_text_model_config(
111
119
 
112
120
  model_config["image"] = image_config
113
121
 
114
- if text_encoder is not None or text_encoder_config is not None or context_length is not None:
122
+ if text_encoder is not None or text_encoder_config is not None or "text" in model_config:
115
123
  text_config = model_config.get("text", {}).copy()
116
124
  if text_encoder is not None:
117
125
  # String encoder metadata replaces only the encoder name.
@@ -124,6 +132,17 @@ def get_image_text_model_config(
124
132
 
125
133
  model_config["text"] = text_config
126
134
 
135
+ if text_decoder is not None or text_decoder_config is not None or "decoder" in model_config:
136
+ decoder_config = model_config.get("decoder", {}).copy()
137
+ if text_decoder is not None:
138
+ decoder_config["network"] = text_decoder
139
+ if text_decoder_config is not None:
140
+ decoder_config["config"] = {**decoder_config.get("config", {}), **text_decoder_config}
141
+ if context_length is not None:
142
+ decoder_config["context_length"] = context_length
143
+
144
+ model_config["decoder"] = decoder_config
145
+
127
146
  if embed_dim is not None:
128
147
  model_config["embed_dim"] = embed_dim
129
148
  if tokenizer is not None:
@@ -143,7 +162,8 @@ def get_image_text_network_config(net: BaseNet, signature: SignatureType, rgb_st
143
162
  model_config = net.config
144
163
 
145
164
  return {
146
- "birder_clip_version": __version__,
165
+ "birder_clip_version": birder_clip_version,
166
+ "birder_version": birder_version,
147
167
  "name": model_name,
148
168
  "registered_name": registered_name,
149
169
  "task": net.task,
@@ -27,6 +27,7 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
27
27
  help="pretrained Birder image model weights path to load into the image encoder",
28
28
  )
29
29
  parser.add_argument("--text-encoder", type=str, help="the text encoder to use")
30
+ parser.add_argument("--text-decoder", type=str, help="the text decoder to use")
30
31
  parser.add_argument("--embed-dim", type=int, metavar="N", help="shared image-text embedding dimension")
31
32
  parser.add_argument("--tokenizer", type=str, help="the tokenizer to use")
32
33
  parser.add_argument(
@@ -44,11 +45,21 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
44
45
  action=cli.FlexibleDictAction,
45
46
  help="override the text encoder configuration, accepts key-value pairs or JSON",
46
47
  )
48
+ parser.add_argument(
49
+ "--text-decoder-config",
50
+ action=cli.FlexibleDictAction,
51
+ help="override the text decoder configuration, accepts key-value pairs or JSON",
52
+ )
53
+ parser.add_argument(
54
+ "--openvision-v2-keep-ratio", type=float, help="OpenVision v2 image token keep ratio for caption decoding"
55
+ )
47
56
 
48
57
 
49
58
  def add_loss_args(parser: argparse.ArgumentParser) -> None:
50
59
  group = parser.add_argument_group("Loss parameters")
51
- group.add_argument("--loss", type=str, choices=["clip", "coca"], default="clip", help="loss function to use")
60
+ group.add_argument(
61
+ "--loss", type=str, choices=["clip", "coca", "caption"], default="clip", help="loss function to use"
62
+ )
52
63
  group.add_argument(
53
64
  "--coca-caption-loss-weight", type=float, default=1.0, help="weight assigned to CoCa caption loss"
54
65
  )
@@ -662,13 +673,16 @@ def common_args_validation(args: argparse.Namespace) -> None:
662
673
  raise cli.ValidationError("--grad-accum-steps must be >= 1")
663
674
  if args.grad_accum_cache_negatives is True and args.grad_accum_steps == 1:
664
675
  raise cli.ValidationError("--grad-accum-cache-negatives requires --grad-accum-steps greater than 1")
665
- if args.grad_accum_cache_negatives is True and args.loss == "coca":
676
+ if args.grad_accum_cache_negatives is True and args.loss != "clip":
666
677
  raise cli.ValidationError("--grad-accum-cache-negatives is only supported with --loss clip")
667
678
 
668
679
  if args.coca_caption_loss_weight < 0.0:
669
680
  raise cli.ValidationError("--coca-caption-loss-weight must be non-negative")
670
681
  if args.coca_contrastive_loss_weight < 0.0:
671
682
  raise cli.ValidationError("--coca-contrastive-loss-weight must be non-negative")
683
+ if args.openvision_v2_keep_ratio is not None:
684
+ if args.openvision_v2_keep_ratio <= 0.0 or args.openvision_v2_keep_ratio > 1.0:
685
+ raise cli.ValidationError("--openvision-v2-keep-ratio must be in range of (0, 1]")
672
686
 
673
687
  # EMA
674
688
  if args.model_ema_steps < 1:
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import json
2
3
  import logging
3
4
  from pathlib import Path
4
5
  from typing import Any
@@ -11,6 +12,7 @@ from birder.common import training_utils as birder_training_utils
11
12
 
12
13
  from birder_clip.common import fs_ops
13
14
  from birder_clip.conf import settings
15
+ from birder_clip.version import __version__ as birder_clip_version
14
16
 
15
17
 
16
18
  def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
@@ -28,6 +30,18 @@ def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
28
30
  return file_handler
29
31
 
30
32
 
33
+ def make_training_args_payload(args: argparse.Namespace) -> dict[str, Any]:
34
+ return {
35
+ "birder_clip_version": birder_clip_version,
36
+ **birder_training_utils.make_training_args_payload(args),
37
+ }
38
+
39
+
40
+ def write_training_args_json(path: Path, args: argparse.Namespace) -> None:
41
+ with open(path.joinpath("training_args.json"), "w", encoding="utf-8") as handle:
42
+ json.dump(make_training_args_payload(args), handle, indent=2)
43
+
44
+
31
45
  def save_training_checkpoint(
32
46
  args: argparse.Namespace,
33
47
  network_name: str,
@@ -0,0 +1,118 @@
1
+ """
2
+ Inference-optimized multi-GPU parallelization for image-text models
3
+
4
+ This module provides ZeroShotInferenceDataParallel, a CLIP-style zero-shot
5
+ specialization of Birder's InferenceDataParallel.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from birder.inference.data_parallel import InferenceDataParallel
12
+
13
+ from birder_clip.inference.zero_shot import ZeroShotInference
14
+ from birder_clip.net.base import BaseNet
15
+
16
+
17
+ class ZeroShotInferenceDataParallel(InferenceDataParallel):
18
+ """
19
+ Distributes zero-shot image inference batches across multiple GPUs
20
+
21
+ This wrapper scatters the image batch across devices and keeps a replicated
22
+ copy of the zero-shot text embeddings on each device. Each replica computes
23
+ image embeddings and zero-shot logits locally before outputs are gathered.
24
+
25
+ Important
26
+ ---------
27
+ This class assumes the model is already configured for inference mode
28
+ (i.e., loaded with inference=True in load_model or manually set to eval mode
29
+ with requires_grad=False on all parameters).
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ module: BaseNet,
35
+ text_embeddings: torch.Tensor,
36
+ device_ids: Optional[list[int]] = None,
37
+ output_device: Optional[int | str | torch.device] = None,
38
+ compile_replicas: bool = False,
39
+ compile_methods: Optional[list[str]] = None,
40
+ compile_mode: Optional[str] = None,
41
+ ) -> None:
42
+ if compile_methods is None:
43
+ compile_methods = ["encode_image", "forward_logits"]
44
+
45
+ super().__init__(
46
+ module,
47
+ device_ids=device_ids,
48
+ output_device=output_device,
49
+ compile_replicas=compile_replicas,
50
+ compile_methods=compile_methods,
51
+ compile_mode=compile_mode,
52
+ )
53
+ self.set_text_embeddings(text_embeddings)
54
+
55
+ def set_text_embeddings(self, text_embeddings: torch.Tensor) -> None:
56
+ if text_embeddings.ndim != 2:
57
+ raise ValueError(
58
+ f"text_embeddings must be a 2D tensor of shape (num_classes, embedding_size), "
59
+ f"got shape {text_embeddings.size()}"
60
+ )
61
+
62
+ self.text_embeddings = [
63
+ text_embeddings.to(f"cuda:{device_id}", non_blocking=True) for device_id in self.device_ids
64
+ ]
65
+ self.inference_modules = [
66
+ ZeroShotInference(replica, embeddings) for replica, embeddings in zip(self.replicas, self.text_embeddings)
67
+ ]
68
+
69
+ def forward( # type: ignore[override] # pylint: disable=arguments-differ
70
+ self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False
71
+ ) -> torch.Tensor:
72
+ """
73
+ Run zero-shot inference distributed across GPUs
74
+
75
+ Parameters
76
+ ----------
77
+ inputs
78
+ Input image batch to process.
79
+ tta
80
+ Run inference with oversampling.
81
+ return_logits
82
+ If True, return raw logits instead of probabilities after softmax.
83
+ """
84
+
85
+ if len(self.device_ids) == 1:
86
+ output = self.inference_modules[0](
87
+ inputs,
88
+ tta=tta,
89
+ return_logits=return_logits,
90
+ )
91
+ return self._gather([output])
92
+
93
+ scattered = self._scatter(inputs, {})
94
+
95
+ outputs = []
96
+ for inference, (input_chunk, _), device_id in zip(self.inference_modules, scattered, self.device_ids):
97
+ if input_chunk is not None and input_chunk.size(0) > 0:
98
+ with torch.cuda.device(device_id):
99
+ output = inference(
100
+ input_chunk,
101
+ tta=tta,
102
+ return_logits=return_logits,
103
+ )
104
+ outputs.append(output)
105
+ else:
106
+ outputs.append(None)
107
+
108
+ return self._gather(outputs)
109
+
110
+ def __repr__(self) -> str:
111
+ return (
112
+ f"ZeroShotInferenceDataParallel(\n"
113
+ f" devices={self.device_ids},\n"
114
+ f" output_device={self.output_device},\n"
115
+ f" src_device={self.src_device},\n"
116
+ f" text_embeddings_shape={tuple(self.text_embeddings[0].shape)}\n"
117
+ f")"
118
+ )
@@ -0,0 +1,63 @@
1
+ import sys
2
+ from collections.abc import Iterator
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from birder_clip.net.base import BaseNet
12
+
13
+ DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32]]
14
+
15
+
16
+ def infer_dataloader_iter(
17
+ device: torch.device,
18
+ net: BaseNet,
19
+ dataloader: DataLoader,
20
+ model_dtype: torch.dtype = torch.float32,
21
+ amp: bool = False,
22
+ amp_dtype: Optional[torch.dtype] = None,
23
+ num_samples: Optional[int] = None,
24
+ chunk_size: Optional[float] = None,
25
+ ) -> Iterator[DataloaderInferenceResult]:
26
+ if chunk_size is None:
27
+ chunk_size = float("inf")
28
+
29
+ net.to(device, dtype=model_dtype)
30
+ embeddings_list: list[npt.NDArray[np.float32]] = []
31
+ sample_paths: list[str] = []
32
+ sample_count = 0
33
+ with tqdm(total=num_samples, initial=0, unit="images", unit_scale=True, leave=False) as progress:
34
+ for file_paths, inputs, _targets in dataloader:
35
+ batch_size = inputs.size(0)
36
+
37
+ # Inference
38
+ inputs = inputs.to(device, dtype=model_dtype)
39
+ with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
40
+ embeddings = net.encode_image(inputs, normalize=True)
41
+ embeddings = embeddings.cpu().float().numpy()
42
+
43
+ embeddings_list.append(embeddings)
44
+
45
+ # Set sample list
46
+ sample_paths.extend(file_paths)
47
+
48
+ # Update progress bar
49
+ progress.update(n=batch_size)
50
+
51
+ # Yield results when we reach chunk_size
52
+ sample_count += batch_size
53
+ if sample_count >= chunk_size:
54
+ with tqdm.external_write_mode(file=sys.stderr):
55
+ yield (sample_paths, np.concatenate(embeddings_list, axis=0))
56
+
57
+ # Reset for next chunk
58
+ embeddings_list = []
59
+ sample_paths = []
60
+ sample_count = 0
61
+
62
+ if len(embeddings_list) > 0:
63
+ yield (sample_paths, np.concatenate(embeddings_list, axis=0))
@@ -13,6 +13,7 @@ from collections.abc import Callable
13
13
  from collections.abc import Iterator
14
14
  from collections.abc import Sequence
15
15
  from typing import Optional
16
+ from typing import Protocol
16
17
 
17
18
  import numpy as np
18
19
  import numpy.typing as npt
@@ -27,6 +28,42 @@ from birder_clip.net.base import BaseNet
27
28
  from birder_clip.tokenizers.base import Tokenizer
28
29
 
29
30
 
31
+ class ZeroShotInferenceModule(Protocol):
32
+ def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor: ...
33
+
34
+
35
+ class ZeroShotInference:
36
+ def __init__(self, net: BaseNet, text_embeddings: torch.Tensor) -> None:
37
+ self.net = net
38
+ self.text_embeddings = text_embeddings
39
+
40
+ def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor:
41
+ inputs = inputs.to(self.text_embeddings.device, non_blocking=True)
42
+ if tta is True:
43
+ _, _, H, W = inputs.size()
44
+ crop_h = int(H * 0.8)
45
+ crop_w = int(W * 0.8)
46
+ tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
47
+ t = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
48
+ outs = []
49
+ for tta_input in tta_inputs:
50
+ image_embeddings = self.net.encode_image(t(tta_input), normalize=True)
51
+ logits = self.net.forward_logits(image_embeddings, self.text_embeddings)
52
+ if return_logits is True:
53
+ outs.append(logits)
54
+ else:
55
+ outs.append(F.softmax(logits, dim=-1))
56
+
57
+ return torch.stack(outs).mean(dim=0)
58
+
59
+ image_embeddings = self.net.encode_image(inputs, normalize=True)
60
+ logits = self.net.forward_logits(image_embeddings, self.text_embeddings)
61
+ if return_logits is True:
62
+ return logits
63
+
64
+ return F.softmax(logits, dim=-1)
65
+
66
+
30
67
  def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
31
68
  return [template.format(class_name) for class_name in class_names for template in templates]
32
69
 
@@ -66,39 +103,10 @@ def build_class_text_embeddings(
66
103
  DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]]
67
104
 
68
105
 
69
- def infer_batch(
70
- net: BaseNet, inputs: torch.Tensor, text_embeddings: torch.Tensor, tta: bool = False, return_logits: bool = False
71
- ) -> torch.Tensor:
72
- if tta is True:
73
- _, _, H, W = inputs.size()
74
- crop_h = int(H * 0.8)
75
- crop_w = int(W * 0.8)
76
- tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
77
- t = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
78
- outs = []
79
- for tta_input in tta_inputs:
80
- image_embeddings = net.encode_image(t(tta_input), normalize=True)
81
- logits = net.forward_logits(image_embeddings, text_embeddings)
82
- if return_logits is True:
83
- outs.append(logits)
84
- else:
85
- outs.append(F.softmax(logits, dim=-1))
86
-
87
- return torch.stack(outs).mean(dim=0)
88
-
89
- image_embeddings = net.encode_image(inputs, normalize=True)
90
- logits = net.forward_logits(image_embeddings, text_embeddings)
91
- if return_logits is True:
92
- return logits
93
-
94
- return F.softmax(logits, dim=-1)
95
-
96
-
97
106
  def infer_dataloader_iter(
98
107
  device: torch.device,
99
- net: BaseNet,
108
+ net: ZeroShotInferenceModule,
100
109
  dataloader: DataLoader,
101
- text_embeddings: torch.Tensor,
102
110
  tta: bool = False,
103
111
  return_logits: bool = False,
104
112
  model_dtype: torch.dtype = torch.float32,
@@ -111,7 +119,6 @@ def infer_dataloader_iter(
111
119
  if chunk_size is None:
112
120
  chunk_size = float("inf")
113
121
 
114
- net.to(device, dtype=model_dtype)
115
122
  out_list: list[npt.NDArray[np.float32]] = []
116
123
  labels_list: list[npt.NDArray[np.int64]] = []
117
124
  sample_paths: list[str] = []
@@ -121,9 +128,9 @@ def infer_dataloader_iter(
121
128
  batch_size = inputs.size(0)
122
129
 
123
130
  # Inference
124
- inputs = inputs.to(device, dtype=model_dtype)
131
+ inputs = inputs.to(dtype=model_dtype)
125
132
  with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
126
- out = infer_batch(net, inputs, text_embeddings, return_logits=return_logits, tta=tta)
133
+ out = net(inputs, return_logits=return_logits, tta=tta)
127
134
  out = out.cpu().float().numpy()
128
135
 
129
136
  out_list.append(out)
@@ -1,7 +1,9 @@
1
+ from birder_clip.loss.caption import CaptionLoss
1
2
  from birder_clip.loss.coca import CoCaLoss
2
3
  from birder_clip.loss.contrastive import CLIPLoss
3
4
 
4
5
  __all__ = [
6
+ "CaptionLoss",
5
7
  "CoCaLoss",
6
8
  "CLIPLoss",
7
9
  ]