birder-clip 0.0.2.dev7__tar.gz → 0.0.2.dev8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/PKG-INFO +4 -4
  2. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/README.md +1 -1
  3. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/fs_ops.py +31 -11
  4. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/lib.py +25 -5
  5. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/training_cli.py +16 -2
  6. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/training_utils.py +14 -0
  7. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/loss/__init__.py +2 -0
  8. birder_clip-0.0.2.dev8/birder_clip/loss/caption.py +72 -0
  9. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/loss/coca.py +1 -3
  10. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/manifest.py +22 -5
  11. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/model_registry.py +5 -0
  12. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/__init__.py +2 -0
  13. birder_clip-0.0.2.dev8/birder_clip/net/base.py +201 -0
  14. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/clip.py +27 -1
  15. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/coca.py +67 -3
  16. birder_clip-0.0.2.dev8/birder_clip/net/openvision_v2.py +219 -0
  17. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/__init__.py +3 -1
  18. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/base.py +40 -19
  19. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/conditioned_decoder.py +4 -3
  20. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/encoder.py +22 -3
  21. birder_clip-0.0.2.dev8/birder_clip/net/text/visual_causal_decoder.py +195 -0
  22. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/embed_images.py +16 -1
  23. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/train.py +58 -9
  24. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/zero_shot.py +14 -0
  25. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/openvision.py +23 -0
  26. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/convert_model.py +21 -4
  27. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/model_info.py +22 -2
  28. birder_clip-0.0.2.dev8/birder_clip/version.py +1 -0
  29. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/PKG-INFO +4 -4
  30. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/SOURCES.txt +3 -1
  31. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/requires.txt +2 -2
  32. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/requirements/_requirements-dev.txt +1 -1
  33. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/requirements/requirements.txt +1 -1
  34. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_common.py +25 -0
  35. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_loss.py +60 -0
  36. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_model_registry.py +25 -0
  37. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_net.py +342 -0
  38. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_net_text.py +44 -0
  39. birder_clip-0.0.2.dev7/birder_clip/net/base.py +0 -77
  40. birder_clip-0.0.2.dev7/birder_clip/net/text/prefix_decoder.py +0 -1
  41. birder_clip-0.0.2.dev7/birder_clip/version.py +0 -1
  42. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/LICENSE +0 -0
  43. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/__init__.py +0 -0
  44. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/__init__.py +0 -0
  45. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/conf/__init__.py +0 -0
  46. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/conf/settings.py +0 -0
  47. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/__init__.py +0 -0
  48. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/__init__.py +0 -0
  49. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/csv.py +0 -0
  50. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/fake.py +0 -0
  51. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/webdataset.py +0 -0
  52. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/__init__.py +0 -0
  53. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/data_parallel.py +0 -0
  54. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/image_embeddings.py +0 -0
  55. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot.py +0 -0
  56. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot_templates.py +0 -0
  57. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/loss/contrastive.py +0 -0
  58. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/__init__.py +0 -0
  59. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/hf.py +0 -0
  60. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/py.typed +0 -0
  61. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/__init__.py +0 -0
  62. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/__main__.py +0 -0
  63. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/__init__.py +0 -0
  64. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/base.py +0 -0
  65. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  66. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/hf.py +0 -0
  67. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/registry.py +0 -0
  68. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
  69. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/__init__.py +0 -0
  70. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/__main__.py +0 -0
  71. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/download_tokenizer.py +0 -0
  72. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/list_models.py +0 -0
  73. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/show_iterator.py +0 -0
  74. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/stats.py +0 -0
  75. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/dependency_links.txt +0 -0
  76. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/top_level.txt +0 -0
  77. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/pyproject.toml +0 -0
  78. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/setup.cfg +0 -0
  79. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_datasets.py +0 -0
  80. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_inference.py +0 -0
  81. {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_tokenizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder_clip
3
- Version: 0.0.2.dev7
3
+ Version: 0.0.2.dev8
4
4
  Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
24
24
  Requires-Python: >=3.11
25
25
  Description-Content-Type: text/markdown
26
26
  License-File: LICENSE
27
- Requires-Dist: birder>=0.6.0
27
+ Requires-Dist: birder>=0.6.2
28
28
  Requires-Dist: ftfy>=6.3.1
29
29
  Requires-Dist: regex>=2025.7.29
30
30
  Requires-Dist: tqdm>=4.67.0
@@ -38,7 +38,7 @@ Requires-Dist: bandit~=1.9.4; extra == "dev"
38
38
  Requires-Dist: black~=26.5.0; extra == "dev"
39
39
  Requires-Dist: build~=1.5.0; extra == "dev"
40
40
  Requires-Dist: bumpver~=2026.1132; extra == "dev"
41
- Requires-Dist: coverage~=7.14.2; extra == "dev"
41
+ Requires-Dist: coverage~=7.14.3; extra == "dev"
42
42
  Requires-Dist: debugpy; extra == "dev"
43
43
  Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
44
44
  Requires-Dist: flake8~=7.3.0; extra == "dev"
@@ -85,7 +85,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
85
85
  1. Ensure your environment meets the minimum requirements:
86
86
  - Python 3.11 or newer
87
87
  - PyTorch 2.10 or newer (installed for your hardware/driver stack)
88
- - Birder 0.6.0 or newer
88
+ - Birder 0.6.2 or newer
89
89
 
90
90
  1. Install the latest Birder CLIP version:
91
91
 
@@ -23,7 +23,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
23
23
  1. Ensure your environment meets the minimum requirements:
24
24
  - Python 3.11 or newer
25
25
  - PyTorch 2.10 or newer (installed for your hardware/driver stack)
26
- - Birder 0.6.0 or newer
26
+ - Birder 0.6.2 or newer
27
27
 
28
28
  1. Install the latest Birder CLIP version:
29
29
 
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from typing import Any
7
7
  from typing import NamedTuple
8
8
  from typing import Optional
9
+ from typing import TypeAlias
9
10
 
10
11
  import torch
11
12
  from birder.common import cli
@@ -16,8 +17,10 @@ from birder.data.transforms.classification import inference_preset
16
17
  from birder_clip.common import lib
17
18
  from birder_clip.model_registry import Task
18
19
  from birder_clip.model_registry import registry
19
- from birder_clip.model_registry.manifest import EncoderMetadataType
20
20
  from birder_clip.model_registry.manifest import FileFormatType
21
+ from birder_clip.model_registry.manifest import ImageEncoderMetadataType
22
+ from birder_clip.model_registry.manifest import TextDecoderMetadataType
23
+ from birder_clip.model_registry.manifest import TextEncoderMetadataType
21
24
  from birder_clip.net.base import BaseNet
22
25
  from birder_clip.net.base import SignatureType
23
26
  from birder_clip.tokenizers import Tokenizer
@@ -36,6 +39,8 @@ except ImportError:
36
39
 
37
40
  logger = logging.getLogger(__name__)
38
41
 
42
+ ComponentMetadataType: TypeAlias = ImageEncoderMetadataType | TextEncoderMetadataType | TextDecoderMetadataType
43
+
39
44
 
40
45
  class ModelInfo(NamedTuple):
41
46
  signature: SignatureType
@@ -51,16 +56,18 @@ def write_config(network_name: str, net: BaseNet, signature: SignatureType, rgb_
51
56
  json.dump(model_config, handle, indent=2)
52
57
 
53
58
 
54
- def _split_encoder_metadata(encoder: Optional[EncoderMetadataType]) -> tuple[Optional[str], Optional[dict[str, Any]]]:
55
- if encoder is None:
59
+ def _split_component_metadata(
60
+ component: Optional[ComponentMetadataType],
61
+ ) -> tuple[Optional[str], Optional[dict[str, Any]]]:
62
+ if component is None:
56
63
  return (None, None)
57
- if isinstance(encoder, str):
58
- return (encoder, None)
64
+ if isinstance(component, str):
65
+ return (component, None)
59
66
 
60
- if "network" not in encoder:
61
- raise ValueError("Encoder metadata must include a 'network' field")
67
+ if "network" not in component:
68
+ raise ValueError("Component metadata must include a 'network' field")
62
69
 
63
- return (None, encoder) # type: ignore[return-value]
70
+ return (None, component) # type: ignore[return-value]
64
71
 
65
72
 
66
73
  def model_path(
@@ -286,10 +293,12 @@ def load_checkpoint(
286
293
  tag: Optional[str] = None,
287
294
  image_encoder: Optional[str] = None,
288
295
  text_encoder: Optional[str] = None,
296
+ text_decoder: Optional[str] = None,
289
297
  embed_dim: Optional[int] = None,
290
298
  tokenizer: Optional[str] = None,
291
299
  image_encoder_config: Optional[dict[str, Any]] = None,
292
300
  text_encoder_config: Optional[dict[str, Any]] = None,
301
+ text_decoder_config: Optional[dict[str, Any]] = None,
293
302
  epoch: Optional[int] = None,
294
303
  new_size: Optional[tuple[int, int]] = None,
295
304
  new_context_length: Optional[int] = None,
@@ -300,6 +309,7 @@ def load_checkpoint(
300
309
  tag=tag,
301
310
  image_encoder=image_encoder,
302
311
  text_encoder=text_encoder,
312
+ text_decoder=text_decoder,
303
313
  embed_dim=embed_dim,
304
314
  tokenizer=tokenizer,
305
315
  )
@@ -329,10 +339,12 @@ def load_checkpoint(
329
339
  checkpoint_config,
330
340
  image_encoder=image_encoder,
331
341
  text_encoder=text_encoder,
342
+ text_decoder=text_decoder,
332
343
  embed_dim=embed_dim,
333
344
  tokenizer=tokenizer,
334
345
  image_encoder_config=image_encoder_config,
335
346
  text_encoder_config=text_encoder_config,
347
+ text_decoder_config=text_decoder_config,
336
348
  input_channels=input_channels,
337
349
  image_size=size,
338
350
  context_length=context_length,
@@ -364,10 +376,12 @@ def load_model(
364
376
  tag: Optional[str] = None,
365
377
  image_encoder: Optional[str] = None,
366
378
  text_encoder: Optional[str] = None,
379
+ text_decoder: Optional[str] = None,
367
380
  embed_dim: Optional[int] = None,
368
381
  tokenizer: Optional[str] = None,
369
382
  image_encoder_config: Optional[dict[str, Any]] = None,
370
383
  text_encoder_config: Optional[dict[str, Any]] = None,
384
+ text_decoder_config: Optional[dict[str, Any]] = None,
371
385
  epoch: Optional[int] = None,
372
386
  new_size: Optional[tuple[int, int]] = None,
373
387
  new_context_length: Optional[int] = None,
@@ -381,6 +395,7 @@ def load_model(
381
395
  tag=tag,
382
396
  image_encoder=image_encoder,
383
397
  text_encoder=text_encoder,
398
+ text_decoder=text_decoder,
384
399
  embed_dim=embed_dim,
385
400
  tokenizer=tokenizer,
386
401
  )
@@ -420,10 +435,12 @@ def load_model(
420
435
  checkpoint_config,
421
436
  image_encoder=image_encoder,
422
437
  text_encoder=text_encoder,
438
+ text_decoder=text_decoder,
423
439
  embed_dim=embed_dim,
424
440
  tokenizer=tokenizer,
425
441
  image_encoder_config=image_encoder_config,
426
442
  text_encoder_config=text_encoder_config,
443
+ text_decoder_config=text_decoder_config,
427
444
  input_channels=input_channels,
428
445
  image_size=size,
429
446
  context_length=context_length,
@@ -525,15 +542,17 @@ def load_pretrained_model(
525
542
  if model_metadata["task"] != Task.IMAGE_TEXT:
526
543
  raise ValueError(f"Unknown model type: {model_metadata['task']}")
527
544
 
528
- image_encoder, image_config = _split_encoder_metadata(model_metadata["net"].get("image_encoder", None))
529
- text_encoder, text_config = _split_encoder_metadata(model_metadata["net"].get("text_encoder", None))
545
+ image_encoder, image_config = _split_component_metadata(model_metadata["net"].get("image_encoder", None))
546
+ text_encoder, text_config = _split_component_metadata(model_metadata["net"].get("text_encoder", None))
547
+ text_decoder, decoder_config = _split_component_metadata(model_metadata["net"].get("text_decoder", None))
530
548
 
531
549
  pretrained_config: dict[str, Any] = {}
532
550
  if image_config is not None:
533
551
  pretrained_config["image"] = image_config
534
-
535
552
  if text_config is not None:
536
553
  pretrained_config["text"] = text_config
554
+ if decoder_config is not None:
555
+ pretrained_config["decoder"] = decoder_config
537
556
 
538
557
  if custom_config is not None:
539
558
  pretrained_config.update(custom_config)
@@ -550,6 +569,7 @@ def load_pretrained_model(
550
569
  tag=model_metadata["net"].get("tag", None),
551
570
  image_encoder=image_encoder,
552
571
  text_encoder=text_encoder,
572
+ text_decoder=text_decoder,
553
573
  embed_dim=model_metadata["net"].get("embed_dim", None),
554
574
  tokenizer=model_metadata["net"].get("tokenizer", None),
555
575
  inference=inference,
@@ -3,14 +3,17 @@ from typing import Any
3
3
  from typing import Optional
4
4
 
5
5
  from birder.data.transforms.classification import RGBType
6
+ from birder.version import __version__ as birder_version
6
7
 
7
8
  from birder_clip.conf import settings
8
9
  from birder_clip.model_registry import registry
9
10
  from birder_clip.net.base import BaseNet
10
11
  from birder_clip.net.base import SignatureType
11
- from birder_clip.version import __version__
12
+ from birder_clip.version import __version__ as birder_clip_version
12
13
 
13
- MODEL_CONFIG_RESERVED_KEYS = frozenset({"image", "text", "tokenizer", "embed_dim", "embed-dim"})
14
+ MODEL_CONFIG_RESERVED_KEYS = frozenset(
15
+ {"image", "text", "decoder", "tokenizer", "embed_dim", "embed-dim", "keep_ratio"}
16
+ )
14
17
 
15
18
 
16
19
  def get_size_from_signature(signature: SignatureType) -> tuple[int, int]:
@@ -37,6 +40,7 @@ def get_image_text_network_name(
37
40
  tag: Optional[str] = None,
38
41
  image_encoder: Optional[str] = None,
39
42
  text_encoder: Optional[str] = None,
43
+ text_decoder: Optional[str] = None,
40
44
  embed_dim: Optional[int] = None,
41
45
  tokenizer: Optional[str] = None,
42
46
  ) -> str:
@@ -45,6 +49,8 @@ def get_image_text_network_name(
45
49
  parts.append(image_encoder)
46
50
  if text_encoder is not None and text_encoder != "transformer_encoder":
47
51
  parts.append(text_encoder)
52
+ if text_decoder is not None and text_decoder != "conditioned_decoder":
53
+ parts.append(text_decoder)
48
54
 
49
55
  if registry.exists(network) is True:
50
56
  default_tokenizer = registry.get_default_tokenizer(network)
@@ -71,10 +77,12 @@ def get_image_text_model_config(
71
77
  *,
72
78
  image_encoder: Optional[str] = None,
73
79
  text_encoder: Optional[str] = None,
80
+ text_decoder: Optional[str] = None,
74
81
  embed_dim: Optional[int] = None,
75
82
  tokenizer: Optional[str] = None,
76
83
  image_encoder_config: Optional[dict[str, Any]] = None,
77
84
  text_encoder_config: Optional[dict[str, Any]] = None,
85
+ text_decoder_config: Optional[dict[str, Any]] = None,
78
86
  input_channels: Optional[int] = None,
79
87
  image_size: Optional[tuple[int, int]] = None,
80
88
  context_length: Optional[int] = None,
@@ -86,7 +94,7 @@ def get_image_text_model_config(
86
94
 
87
95
  if config is not None:
88
96
  for key, value in config.items():
89
- if key in {"image", "text"} and isinstance(value, dict):
97
+ if key in {"image", "text", "decoder"} and isinstance(value, dict):
90
98
  model_config[key] = {**model_config.get(key, {}), **value}
91
99
  else:
92
100
  model_config[key] = value
@@ -111,7 +119,7 @@ def get_image_text_model_config(
111
119
 
112
120
  model_config["image"] = image_config
113
121
 
114
- if text_encoder is not None or text_encoder_config is not None or context_length is not None:
122
+ if text_encoder is not None or text_encoder_config is not None or "text" in model_config:
115
123
  text_config = model_config.get("text", {}).copy()
116
124
  if text_encoder is not None:
117
125
  # String encoder metadata replaces only the encoder name.
@@ -124,6 +132,17 @@ def get_image_text_model_config(
124
132
 
125
133
  model_config["text"] = text_config
126
134
 
135
+ if text_decoder is not None or text_decoder_config is not None or "decoder" in model_config:
136
+ decoder_config = model_config.get("decoder", {}).copy()
137
+ if text_decoder is not None:
138
+ decoder_config["network"] = text_decoder
139
+ if text_decoder_config is not None:
140
+ decoder_config["config"] = {**decoder_config.get("config", {}), **text_decoder_config}
141
+ if context_length is not None:
142
+ decoder_config["context_length"] = context_length
143
+
144
+ model_config["decoder"] = decoder_config
145
+
127
146
  if embed_dim is not None:
128
147
  model_config["embed_dim"] = embed_dim
129
148
  if tokenizer is not None:
@@ -143,7 +162,8 @@ def get_image_text_network_config(net: BaseNet, signature: SignatureType, rgb_st
143
162
  model_config = net.config
144
163
 
145
164
  return {
146
- "birder_clip_version": __version__,
165
+ "birder_clip_version": birder_clip_version,
166
+ "birder_version": birder_version,
147
167
  "name": model_name,
148
168
  "registered_name": registered_name,
149
169
  "task": net.task,
@@ -27,6 +27,7 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
27
27
  help="pretrained Birder image model weights path to load into the image encoder",
28
28
  )
29
29
  parser.add_argument("--text-encoder", type=str, help="the text encoder to use")
30
+ parser.add_argument("--text-decoder", type=str, help="the text decoder to use")
30
31
  parser.add_argument("--embed-dim", type=int, metavar="N", help="shared image-text embedding dimension")
31
32
  parser.add_argument("--tokenizer", type=str, help="the tokenizer to use")
32
33
  parser.add_argument(
@@ -44,11 +45,21 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
44
45
  action=cli.FlexibleDictAction,
45
46
  help="override the text encoder configuration, accepts key-value pairs or JSON",
46
47
  )
48
+ parser.add_argument(
49
+ "--text-decoder-config",
50
+ action=cli.FlexibleDictAction,
51
+ help="override the text decoder configuration, accepts key-value pairs or JSON",
52
+ )
53
+ parser.add_argument(
54
+ "--openvision-v2-keep-ratio", type=float, help="OpenVision v2 image token keep ratio for caption decoding"
55
+ )
47
56
 
48
57
 
49
58
  def add_loss_args(parser: argparse.ArgumentParser) -> None:
50
59
  group = parser.add_argument_group("Loss parameters")
51
- group.add_argument("--loss", type=str, choices=["clip", "coca"], default="clip", help="loss function to use")
60
+ group.add_argument(
61
+ "--loss", type=str, choices=["clip", "coca", "caption"], default="clip", help="loss function to use"
62
+ )
52
63
  group.add_argument(
53
64
  "--coca-caption-loss-weight", type=float, default=1.0, help="weight assigned to CoCa caption loss"
54
65
  )
@@ -662,13 +673,16 @@ def common_args_validation(args: argparse.Namespace) -> None:
662
673
  raise cli.ValidationError("--grad-accum-steps must be >= 1")
663
674
  if args.grad_accum_cache_negatives is True and args.grad_accum_steps == 1:
664
675
  raise cli.ValidationError("--grad-accum-cache-negatives requires --grad-accum-steps greater than 1")
665
- if args.grad_accum_cache_negatives is True and args.loss == "coca":
676
+ if args.grad_accum_cache_negatives is True and args.loss != "clip":
666
677
  raise cli.ValidationError("--grad-accum-cache-negatives is only supported with --loss clip")
667
678
 
668
679
  if args.coca_caption_loss_weight < 0.0:
669
680
  raise cli.ValidationError("--coca-caption-loss-weight must be non-negative")
670
681
  if args.coca_contrastive_loss_weight < 0.0:
671
682
  raise cli.ValidationError("--coca-contrastive-loss-weight must be non-negative")
683
+ if args.openvision_v2_keep_ratio is not None:
684
+ if args.openvision_v2_keep_ratio <= 0.0 or args.openvision_v2_keep_ratio > 1.0:
685
+ raise cli.ValidationError("--openvision-v2-keep-ratio must be in range of (0, 1]")
672
686
 
673
687
  # EMA
674
688
  if args.model_ema_steps < 1:
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import json
2
3
  import logging
3
4
  from pathlib import Path
4
5
  from typing import Any
@@ -11,6 +12,7 @@ from birder.common import training_utils as birder_training_utils
11
12
 
12
13
  from birder_clip.common import fs_ops
13
14
  from birder_clip.conf import settings
15
+ from birder_clip.version import __version__ as birder_clip_version
14
16
 
15
17
 
16
18
  def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
@@ -28,6 +30,18 @@ def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
28
30
  return file_handler
29
31
 
30
32
 
33
+ def make_training_args_payload(args: argparse.Namespace) -> dict[str, Any]:
34
+ return {
35
+ "birder_clip_version": birder_clip_version,
36
+ **birder_training_utils.make_training_args_payload(args),
37
+ }
38
+
39
+
40
+ def write_training_args_json(path: Path, args: argparse.Namespace) -> None:
41
+ with open(path.joinpath("training_args.json"), "w", encoding="utf-8") as handle:
42
+ json.dump(make_training_args_payload(args), handle, indent=2)
43
+
44
+
31
45
  def save_training_checkpoint(
32
46
  args: argparse.Namespace,
33
47
  network_name: str,
@@ -1,7 +1,9 @@
1
+ from birder_clip.loss.caption import CaptionLoss
1
2
  from birder_clip.loss.coca import CoCaLoss
2
3
  from birder_clip.loss.contrastive import CLIPLoss
3
4
 
4
5
  __all__ = [
6
+ "CaptionLoss",
5
7
  "CoCaLoss",
6
8
  "CLIPLoss",
7
9
  ]
@@ -0,0 +1,72 @@
1
+ from collections.abc import Sequence
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class CaptionLoss(torch.nn.Module):
9
+ """
10
+ Autoregressive captioning cross entropy over decoder logits
11
+
12
+ The loss consumes unshifted tokenized captions and supports two decoder output conventions:
13
+ - logits length equals text length: the final logit is ignored.
14
+ - logits length equals text length - 1: logits are used as-is.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ caption_loss_weight: float = 1.0,
20
+ pad_token_id: int = 0,
21
+ ignore_token_ids: Optional[Sequence[int]] = None,
22
+ label_smoothing: float = 0.0,
23
+ ) -> None:
24
+ super().__init__()
25
+ self.caption_loss_weight = caption_loss_weight
26
+ self.ignore_token_ids = tuple(ignore_token_ids) if ignore_token_ids is not None else (pad_token_id,)
27
+ self.label_smoothing = label_smoothing
28
+
29
+ def _align_logits_and_targets(self, logits: torch.Tensor, texts: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
30
+ if logits.size(1) == texts.size(1):
31
+ return logits[:, :-1], texts[:, 1:]
32
+ if logits.size(1) == texts.size(1) - 1:
33
+ return logits, texts[:, 1:]
34
+
35
+ raise ValueError(
36
+ "Expected logits sequence length to equal text sequence length or text sequence length - 1, "
37
+ f"got logits={logits.size(1)}, texts={texts.size(1)}"
38
+ )
39
+
40
+ def _align_target_mask(self, target_mask: torch.Tensor, texts: torch.Tensor) -> torch.Tensor:
41
+ if target_mask.shape == texts.shape:
42
+ target_mask = target_mask[:, 1:]
43
+
44
+ return target_mask
45
+
46
+ def forward(
47
+ self, logits: torch.Tensor, texts: torch.Tensor, target_mask: Optional[torch.Tensor] = None
48
+ ) -> dict[str, torch.Tensor]:
49
+ if self.caption_loss_weight == 0.0:
50
+ return {"caption_loss": logits.new_zeros(())}
51
+
52
+ logits, targets = self._align_logits_and_targets(logits, texts)
53
+ token_loss = F.cross_entropy(
54
+ logits.permute(0, 2, 1),
55
+ targets,
56
+ reduction="none",
57
+ label_smoothing=self.label_smoothing,
58
+ )
59
+
60
+ loss_mask = torch.ones_like(targets, dtype=torch.bool)
61
+ for token_id in self.ignore_token_ids:
62
+ loss_mask = loss_mask & (targets != token_id)
63
+
64
+ if target_mask is not None:
65
+ target_mask = self._align_target_mask(target_mask, texts)
66
+ loss_mask = loss_mask & target_mask.to(dtype=torch.bool)
67
+
68
+ loss_mask_float = loss_mask.to(dtype=token_loss.dtype)
69
+ caption_loss = (token_loss * loss_mask_float).sum() / loss_mask_float.sum().clamp_min(1)
70
+ caption_loss = caption_loss * self.caption_loss_weight
71
+
72
+ return {"caption_loss": caption_loss}
@@ -26,9 +26,7 @@ class CoCaLoss(torch.nn.Module):
26
26
  captioning cross entropy over decoder logits.
27
27
  """
28
28
 
29
- def __init__(
30
- self, *, caption_loss_weight: float = 1.0, clip_loss_weight: float = 1.0, pad_token_id: int = 0
31
- ) -> None:
29
+ def __init__(self, caption_loss_weight: float = 1.0, clip_loss_weight: float = 1.0, pad_token_id: int = 0) -> None:
32
30
  super().__init__()
33
31
  self.caption_loss_weight = caption_loss_weight
34
32
  self.clip_loss_weight = clip_loss_weight
@@ -11,8 +11,8 @@ FormatInfoType = TypedDict(
11
11
  {"file_size": float, "sha256": str},
12
12
  )
13
13
 
14
- EncoderInfoType = TypedDict(
15
- "EncoderInfoType",
14
+ ImageEncoderInfoType = TypedDict(
15
+ "ImageEncoderInfoType",
16
16
  {
17
17
  "network": str,
18
18
  "config": NotRequired[dict[str, Any]],
@@ -21,16 +21,33 @@ EncoderInfoType = TypedDict(
21
21
  "size": NotRequired[tuple[int, int]],
22
22
  },
23
23
  )
24
+ TextEncoderInfoType = TypedDict(
25
+ "TextEncoderInfoType",
26
+ {
27
+ "network": str,
28
+ "config": NotRequired[dict[str, Any]],
29
+ },
30
+ )
31
+ TextDecoderInfoType = TypedDict(
32
+ "TextDecoderInfoType",
33
+ {
34
+ "network": str,
35
+ "config": NotRequired[dict[str, Any]],
36
+ },
37
+ )
24
38
 
25
- EncoderMetadataType: TypeAlias = str | EncoderInfoType
39
+ ImageEncoderMetadataType: TypeAlias = str | ImageEncoderInfoType
40
+ TextEncoderMetadataType: TypeAlias = str | TextEncoderInfoType
41
+ TextDecoderMetadataType: TypeAlias = str | TextDecoderInfoType
26
42
 
27
43
  NetworkInfoType = TypedDict(
28
44
  "NetworkInfoType",
29
45
  {
30
46
  "network": str,
31
47
  "tag": NotRequired[str],
32
- "image_encoder": NotRequired[EncoderMetadataType],
33
- "text_encoder": NotRequired[EncoderMetadataType],
48
+ "image_encoder": NotRequired[ImageEncoderMetadataType],
49
+ "text_encoder": NotRequired[TextEncoderMetadataType],
50
+ "text_decoder": NotRequired[TextDecoderMetadataType],
34
51
  "embed_dim": NotRequired[int],
35
52
  "tokenizer": NotRequired[str],
36
53
  },
@@ -170,6 +170,11 @@ class ModelRegistry:
170
170
  if context_length is not None:
171
171
  return context_length # type: ignore[no-any-return]
172
172
 
173
+ decoder_config = config.get("decoder", {})
174
+ context_length = decoder_config.get("context_length")
175
+ if context_length is not None:
176
+ return context_length # type: ignore[no-any-return]
177
+
173
178
  if tokenizer is None:
174
179
  tokenizer = config.get("tokenizer")
175
180
 
@@ -1,7 +1,9 @@
1
1
  from birder_clip.net.clip import CLIP
2
2
  from birder_clip.net.coca import CoCa
3
+ from birder_clip.net.openvision_v2 import OpenVision_v2
3
4
 
4
5
  __all__ = [
5
6
  "CLIP",
6
7
  "CoCa",
8
+ "OpenVision_v2",
7
9
  ]