birder-clip 0.0.2.dev7__tar.gz → 0.0.2.dev8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/PKG-INFO +4 -4
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/README.md +1 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/fs_ops.py +31 -11
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/lib.py +25 -5
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/training_cli.py +16 -2
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/training_utils.py +14 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/loss/__init__.py +2 -0
- birder_clip-0.0.2.dev8/birder_clip/loss/caption.py +72 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/loss/coca.py +1 -3
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/manifest.py +22 -5
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/model_registry.py +5 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/__init__.py +2 -0
- birder_clip-0.0.2.dev8/birder_clip/net/base.py +201 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/clip.py +27 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/coca.py +67 -3
- birder_clip-0.0.2.dev8/birder_clip/net/openvision_v2.py +219 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/__init__.py +3 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/base.py +40 -19
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/conditioned_decoder.py +4 -3
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/encoder.py +22 -3
- birder_clip-0.0.2.dev8/birder_clip/net/text/visual_causal_decoder.py +195 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/embed_images.py +16 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/train.py +58 -9
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/zero_shot.py +14 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/openvision.py +23 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/convert_model.py +21 -4
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/model_info.py +22 -2
- birder_clip-0.0.2.dev8/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/PKG-INFO +4 -4
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/SOURCES.txt +3 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/requires.txt +2 -2
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/requirements/_requirements-dev.txt +1 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/requirements/requirements.txt +1 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_common.py +25 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_loss.py +60 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_model_registry.py +25 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_net.py +342 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_net_text.py +44 -0
- birder_clip-0.0.2.dev7/birder_clip/net/base.py +0 -77
- birder_clip-0.0.2.dev7/birder_clip/net/text/prefix_decoder.py +0 -1
- birder_clip-0.0.2.dev7/birder_clip/version.py +0 -1
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/LICENSE +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/conf/settings.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/csv.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/fake.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/webdataset.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/data_parallel.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/image_embeddings.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/loss/contrastive.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/net/text/hf.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/py.typed +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/scripts/__main__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/base.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/hf.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/registry.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/__init__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/__main__.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/download_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/list_models.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/show_iterator.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/tools/stats.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_datasets.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_inference.py +0 -0
- {birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/tests/test_tokenizers.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder_clip
|
|
3
|
-
Version: 0.0.2.
|
|
3
|
+
Version: 0.0.2.dev8
|
|
4
4
|
Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
|
|
|
24
24
|
Requires-Python: >=3.11
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
License-File: LICENSE
|
|
27
|
-
Requires-Dist: birder>=0.6.
|
|
27
|
+
Requires-Dist: birder>=0.6.2
|
|
28
28
|
Requires-Dist: ftfy>=6.3.1
|
|
29
29
|
Requires-Dist: regex>=2025.7.29
|
|
30
30
|
Requires-Dist: tqdm>=4.67.0
|
|
@@ -38,7 +38,7 @@ Requires-Dist: bandit~=1.9.4; extra == "dev"
|
|
|
38
38
|
Requires-Dist: black~=26.5.0; extra == "dev"
|
|
39
39
|
Requires-Dist: build~=1.5.0; extra == "dev"
|
|
40
40
|
Requires-Dist: bumpver~=2026.1132; extra == "dev"
|
|
41
|
-
Requires-Dist: coverage~=7.14.
|
|
41
|
+
Requires-Dist: coverage~=7.14.3; extra == "dev"
|
|
42
42
|
Requires-Dist: debugpy; extra == "dev"
|
|
43
43
|
Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
|
|
44
44
|
Requires-Dist: flake8~=7.3.0; extra == "dev"
|
|
@@ -85,7 +85,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
|
|
|
85
85
|
1. Ensure your environment meets the minimum requirements:
|
|
86
86
|
- Python 3.11 or newer
|
|
87
87
|
- PyTorch 2.10 or newer (installed for your hardware/driver stack)
|
|
88
|
-
- Birder 0.6.
|
|
88
|
+
- Birder 0.6.2 or newer
|
|
89
89
|
|
|
90
90
|
1. Install the latest Birder CLIP version:
|
|
91
91
|
|
|
@@ -23,7 +23,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
|
|
|
23
23
|
1. Ensure your environment meets the minimum requirements:
|
|
24
24
|
- Python 3.11 or newer
|
|
25
25
|
- PyTorch 2.10 or newer (installed for your hardware/driver stack)
|
|
26
|
-
- Birder 0.6.
|
|
26
|
+
- Birder 0.6.2 or newer
|
|
27
27
|
|
|
28
28
|
1. Install the latest Birder CLIP version:
|
|
29
29
|
|
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from typing import Any
|
|
7
7
|
from typing import NamedTuple
|
|
8
8
|
from typing import Optional
|
|
9
|
+
from typing import TypeAlias
|
|
9
10
|
|
|
10
11
|
import torch
|
|
11
12
|
from birder.common import cli
|
|
@@ -16,8 +17,10 @@ from birder.data.transforms.classification import inference_preset
|
|
|
16
17
|
from birder_clip.common import lib
|
|
17
18
|
from birder_clip.model_registry import Task
|
|
18
19
|
from birder_clip.model_registry import registry
|
|
19
|
-
from birder_clip.model_registry.manifest import EncoderMetadataType
|
|
20
20
|
from birder_clip.model_registry.manifest import FileFormatType
|
|
21
|
+
from birder_clip.model_registry.manifest import ImageEncoderMetadataType
|
|
22
|
+
from birder_clip.model_registry.manifest import TextDecoderMetadataType
|
|
23
|
+
from birder_clip.model_registry.manifest import TextEncoderMetadataType
|
|
21
24
|
from birder_clip.net.base import BaseNet
|
|
22
25
|
from birder_clip.net.base import SignatureType
|
|
23
26
|
from birder_clip.tokenizers import Tokenizer
|
|
@@ -36,6 +39,8 @@ except ImportError:
|
|
|
36
39
|
|
|
37
40
|
logger = logging.getLogger(__name__)
|
|
38
41
|
|
|
42
|
+
ComponentMetadataType: TypeAlias = ImageEncoderMetadataType | TextEncoderMetadataType | TextDecoderMetadataType
|
|
43
|
+
|
|
39
44
|
|
|
40
45
|
class ModelInfo(NamedTuple):
|
|
41
46
|
signature: SignatureType
|
|
@@ -51,16 +56,18 @@ def write_config(network_name: str, net: BaseNet, signature: SignatureType, rgb_
|
|
|
51
56
|
json.dump(model_config, handle, indent=2)
|
|
52
57
|
|
|
53
58
|
|
|
54
|
-
def
|
|
55
|
-
|
|
59
|
+
def _split_component_metadata(
|
|
60
|
+
component: Optional[ComponentMetadataType],
|
|
61
|
+
) -> tuple[Optional[str], Optional[dict[str, Any]]]:
|
|
62
|
+
if component is None:
|
|
56
63
|
return (None, None)
|
|
57
|
-
if isinstance(
|
|
58
|
-
return (
|
|
64
|
+
if isinstance(component, str):
|
|
65
|
+
return (component, None)
|
|
59
66
|
|
|
60
|
-
if "network" not in
|
|
61
|
-
raise ValueError("
|
|
67
|
+
if "network" not in component:
|
|
68
|
+
raise ValueError("Component metadata must include a 'network' field")
|
|
62
69
|
|
|
63
|
-
return (None,
|
|
70
|
+
return (None, component) # type: ignore[return-value]
|
|
64
71
|
|
|
65
72
|
|
|
66
73
|
def model_path(
|
|
@@ -286,10 +293,12 @@ def load_checkpoint(
|
|
|
286
293
|
tag: Optional[str] = None,
|
|
287
294
|
image_encoder: Optional[str] = None,
|
|
288
295
|
text_encoder: Optional[str] = None,
|
|
296
|
+
text_decoder: Optional[str] = None,
|
|
289
297
|
embed_dim: Optional[int] = None,
|
|
290
298
|
tokenizer: Optional[str] = None,
|
|
291
299
|
image_encoder_config: Optional[dict[str, Any]] = None,
|
|
292
300
|
text_encoder_config: Optional[dict[str, Any]] = None,
|
|
301
|
+
text_decoder_config: Optional[dict[str, Any]] = None,
|
|
293
302
|
epoch: Optional[int] = None,
|
|
294
303
|
new_size: Optional[tuple[int, int]] = None,
|
|
295
304
|
new_context_length: Optional[int] = None,
|
|
@@ -300,6 +309,7 @@ def load_checkpoint(
|
|
|
300
309
|
tag=tag,
|
|
301
310
|
image_encoder=image_encoder,
|
|
302
311
|
text_encoder=text_encoder,
|
|
312
|
+
text_decoder=text_decoder,
|
|
303
313
|
embed_dim=embed_dim,
|
|
304
314
|
tokenizer=tokenizer,
|
|
305
315
|
)
|
|
@@ -329,10 +339,12 @@ def load_checkpoint(
|
|
|
329
339
|
checkpoint_config,
|
|
330
340
|
image_encoder=image_encoder,
|
|
331
341
|
text_encoder=text_encoder,
|
|
342
|
+
text_decoder=text_decoder,
|
|
332
343
|
embed_dim=embed_dim,
|
|
333
344
|
tokenizer=tokenizer,
|
|
334
345
|
image_encoder_config=image_encoder_config,
|
|
335
346
|
text_encoder_config=text_encoder_config,
|
|
347
|
+
text_decoder_config=text_decoder_config,
|
|
336
348
|
input_channels=input_channels,
|
|
337
349
|
image_size=size,
|
|
338
350
|
context_length=context_length,
|
|
@@ -364,10 +376,12 @@ def load_model(
|
|
|
364
376
|
tag: Optional[str] = None,
|
|
365
377
|
image_encoder: Optional[str] = None,
|
|
366
378
|
text_encoder: Optional[str] = None,
|
|
379
|
+
text_decoder: Optional[str] = None,
|
|
367
380
|
embed_dim: Optional[int] = None,
|
|
368
381
|
tokenizer: Optional[str] = None,
|
|
369
382
|
image_encoder_config: Optional[dict[str, Any]] = None,
|
|
370
383
|
text_encoder_config: Optional[dict[str, Any]] = None,
|
|
384
|
+
text_decoder_config: Optional[dict[str, Any]] = None,
|
|
371
385
|
epoch: Optional[int] = None,
|
|
372
386
|
new_size: Optional[tuple[int, int]] = None,
|
|
373
387
|
new_context_length: Optional[int] = None,
|
|
@@ -381,6 +395,7 @@ def load_model(
|
|
|
381
395
|
tag=tag,
|
|
382
396
|
image_encoder=image_encoder,
|
|
383
397
|
text_encoder=text_encoder,
|
|
398
|
+
text_decoder=text_decoder,
|
|
384
399
|
embed_dim=embed_dim,
|
|
385
400
|
tokenizer=tokenizer,
|
|
386
401
|
)
|
|
@@ -420,10 +435,12 @@ def load_model(
|
|
|
420
435
|
checkpoint_config,
|
|
421
436
|
image_encoder=image_encoder,
|
|
422
437
|
text_encoder=text_encoder,
|
|
438
|
+
text_decoder=text_decoder,
|
|
423
439
|
embed_dim=embed_dim,
|
|
424
440
|
tokenizer=tokenizer,
|
|
425
441
|
image_encoder_config=image_encoder_config,
|
|
426
442
|
text_encoder_config=text_encoder_config,
|
|
443
|
+
text_decoder_config=text_decoder_config,
|
|
427
444
|
input_channels=input_channels,
|
|
428
445
|
image_size=size,
|
|
429
446
|
context_length=context_length,
|
|
@@ -525,15 +542,17 @@ def load_pretrained_model(
|
|
|
525
542
|
if model_metadata["task"] != Task.IMAGE_TEXT:
|
|
526
543
|
raise ValueError(f"Unknown model type: {model_metadata['task']}")
|
|
527
544
|
|
|
528
|
-
image_encoder, image_config =
|
|
529
|
-
text_encoder, text_config =
|
|
545
|
+
image_encoder, image_config = _split_component_metadata(model_metadata["net"].get("image_encoder", None))
|
|
546
|
+
text_encoder, text_config = _split_component_metadata(model_metadata["net"].get("text_encoder", None))
|
|
547
|
+
text_decoder, decoder_config = _split_component_metadata(model_metadata["net"].get("text_decoder", None))
|
|
530
548
|
|
|
531
549
|
pretrained_config: dict[str, Any] = {}
|
|
532
550
|
if image_config is not None:
|
|
533
551
|
pretrained_config["image"] = image_config
|
|
534
|
-
|
|
535
552
|
if text_config is not None:
|
|
536
553
|
pretrained_config["text"] = text_config
|
|
554
|
+
if decoder_config is not None:
|
|
555
|
+
pretrained_config["decoder"] = decoder_config
|
|
537
556
|
|
|
538
557
|
if custom_config is not None:
|
|
539
558
|
pretrained_config.update(custom_config)
|
|
@@ -550,6 +569,7 @@ def load_pretrained_model(
|
|
|
550
569
|
tag=model_metadata["net"].get("tag", None),
|
|
551
570
|
image_encoder=image_encoder,
|
|
552
571
|
text_encoder=text_encoder,
|
|
572
|
+
text_decoder=text_decoder,
|
|
553
573
|
embed_dim=model_metadata["net"].get("embed_dim", None),
|
|
554
574
|
tokenizer=model_metadata["net"].get("tokenizer", None),
|
|
555
575
|
inference=inference,
|
|
@@ -3,14 +3,17 @@ from typing import Any
|
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
5
|
from birder.data.transforms.classification import RGBType
|
|
6
|
+
from birder.version import __version__ as birder_version
|
|
6
7
|
|
|
7
8
|
from birder_clip.conf import settings
|
|
8
9
|
from birder_clip.model_registry import registry
|
|
9
10
|
from birder_clip.net.base import BaseNet
|
|
10
11
|
from birder_clip.net.base import SignatureType
|
|
11
|
-
from birder_clip.version import __version__
|
|
12
|
+
from birder_clip.version import __version__ as birder_clip_version
|
|
12
13
|
|
|
13
|
-
MODEL_CONFIG_RESERVED_KEYS = frozenset(
|
|
14
|
+
MODEL_CONFIG_RESERVED_KEYS = frozenset(
|
|
15
|
+
{"image", "text", "decoder", "tokenizer", "embed_dim", "embed-dim", "keep_ratio"}
|
|
16
|
+
)
|
|
14
17
|
|
|
15
18
|
|
|
16
19
|
def get_size_from_signature(signature: SignatureType) -> tuple[int, int]:
|
|
@@ -37,6 +40,7 @@ def get_image_text_network_name(
|
|
|
37
40
|
tag: Optional[str] = None,
|
|
38
41
|
image_encoder: Optional[str] = None,
|
|
39
42
|
text_encoder: Optional[str] = None,
|
|
43
|
+
text_decoder: Optional[str] = None,
|
|
40
44
|
embed_dim: Optional[int] = None,
|
|
41
45
|
tokenizer: Optional[str] = None,
|
|
42
46
|
) -> str:
|
|
@@ -45,6 +49,8 @@ def get_image_text_network_name(
|
|
|
45
49
|
parts.append(image_encoder)
|
|
46
50
|
if text_encoder is not None and text_encoder != "transformer_encoder":
|
|
47
51
|
parts.append(text_encoder)
|
|
52
|
+
if text_decoder is not None and text_decoder != "conditioned_decoder":
|
|
53
|
+
parts.append(text_decoder)
|
|
48
54
|
|
|
49
55
|
if registry.exists(network) is True:
|
|
50
56
|
default_tokenizer = registry.get_default_tokenizer(network)
|
|
@@ -71,10 +77,12 @@ def get_image_text_model_config(
|
|
|
71
77
|
*,
|
|
72
78
|
image_encoder: Optional[str] = None,
|
|
73
79
|
text_encoder: Optional[str] = None,
|
|
80
|
+
text_decoder: Optional[str] = None,
|
|
74
81
|
embed_dim: Optional[int] = None,
|
|
75
82
|
tokenizer: Optional[str] = None,
|
|
76
83
|
image_encoder_config: Optional[dict[str, Any]] = None,
|
|
77
84
|
text_encoder_config: Optional[dict[str, Any]] = None,
|
|
85
|
+
text_decoder_config: Optional[dict[str, Any]] = None,
|
|
78
86
|
input_channels: Optional[int] = None,
|
|
79
87
|
image_size: Optional[tuple[int, int]] = None,
|
|
80
88
|
context_length: Optional[int] = None,
|
|
@@ -86,7 +94,7 @@ def get_image_text_model_config(
|
|
|
86
94
|
|
|
87
95
|
if config is not None:
|
|
88
96
|
for key, value in config.items():
|
|
89
|
-
if key in {"image", "text"} and isinstance(value, dict):
|
|
97
|
+
if key in {"image", "text", "decoder"} and isinstance(value, dict):
|
|
90
98
|
model_config[key] = {**model_config.get(key, {}), **value}
|
|
91
99
|
else:
|
|
92
100
|
model_config[key] = value
|
|
@@ -111,7 +119,7 @@ def get_image_text_model_config(
|
|
|
111
119
|
|
|
112
120
|
model_config["image"] = image_config
|
|
113
121
|
|
|
114
|
-
if text_encoder is not None or text_encoder_config is not None or
|
|
122
|
+
if text_encoder is not None or text_encoder_config is not None or "text" in model_config:
|
|
115
123
|
text_config = model_config.get("text", {}).copy()
|
|
116
124
|
if text_encoder is not None:
|
|
117
125
|
# String encoder metadata replaces only the encoder name.
|
|
@@ -124,6 +132,17 @@ def get_image_text_model_config(
|
|
|
124
132
|
|
|
125
133
|
model_config["text"] = text_config
|
|
126
134
|
|
|
135
|
+
if text_decoder is not None or text_decoder_config is not None or "decoder" in model_config:
|
|
136
|
+
decoder_config = model_config.get("decoder", {}).copy()
|
|
137
|
+
if text_decoder is not None:
|
|
138
|
+
decoder_config["network"] = text_decoder
|
|
139
|
+
if text_decoder_config is not None:
|
|
140
|
+
decoder_config["config"] = {**decoder_config.get("config", {}), **text_decoder_config}
|
|
141
|
+
if context_length is not None:
|
|
142
|
+
decoder_config["context_length"] = context_length
|
|
143
|
+
|
|
144
|
+
model_config["decoder"] = decoder_config
|
|
145
|
+
|
|
127
146
|
if embed_dim is not None:
|
|
128
147
|
model_config["embed_dim"] = embed_dim
|
|
129
148
|
if tokenizer is not None:
|
|
@@ -143,7 +162,8 @@ def get_image_text_network_config(net: BaseNet, signature: SignatureType, rgb_st
|
|
|
143
162
|
model_config = net.config
|
|
144
163
|
|
|
145
164
|
return {
|
|
146
|
-
"birder_clip_version":
|
|
165
|
+
"birder_clip_version": birder_clip_version,
|
|
166
|
+
"birder_version": birder_version,
|
|
147
167
|
"name": model_name,
|
|
148
168
|
"registered_name": registered_name,
|
|
149
169
|
"task": net.task,
|
|
@@ -27,6 +27,7 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
|
27
27
|
help="pretrained Birder image model weights path to load into the image encoder",
|
|
28
28
|
)
|
|
29
29
|
parser.add_argument("--text-encoder", type=str, help="the text encoder to use")
|
|
30
|
+
parser.add_argument("--text-decoder", type=str, help="the text decoder to use")
|
|
30
31
|
parser.add_argument("--embed-dim", type=int, metavar="N", help="shared image-text embedding dimension")
|
|
31
32
|
parser.add_argument("--tokenizer", type=str, help="the tokenizer to use")
|
|
32
33
|
parser.add_argument(
|
|
@@ -44,11 +45,21 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
|
44
45
|
action=cli.FlexibleDictAction,
|
|
45
46
|
help="override the text encoder configuration, accepts key-value pairs or JSON",
|
|
46
47
|
)
|
|
48
|
+
parser.add_argument(
|
|
49
|
+
"--text-decoder-config",
|
|
50
|
+
action=cli.FlexibleDictAction,
|
|
51
|
+
help="override the text decoder configuration, accepts key-value pairs or JSON",
|
|
52
|
+
)
|
|
53
|
+
parser.add_argument(
|
|
54
|
+
"--openvision-v2-keep-ratio", type=float, help="OpenVision v2 image token keep ratio for caption decoding"
|
|
55
|
+
)
|
|
47
56
|
|
|
48
57
|
|
|
49
58
|
def add_loss_args(parser: argparse.ArgumentParser) -> None:
|
|
50
59
|
group = parser.add_argument_group("Loss parameters")
|
|
51
|
-
group.add_argument(
|
|
60
|
+
group.add_argument(
|
|
61
|
+
"--loss", type=str, choices=["clip", "coca", "caption"], default="clip", help="loss function to use"
|
|
62
|
+
)
|
|
52
63
|
group.add_argument(
|
|
53
64
|
"--coca-caption-loss-weight", type=float, default=1.0, help="weight assigned to CoCa caption loss"
|
|
54
65
|
)
|
|
@@ -662,13 +673,16 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
662
673
|
raise cli.ValidationError("--grad-accum-steps must be >= 1")
|
|
663
674
|
if args.grad_accum_cache_negatives is True and args.grad_accum_steps == 1:
|
|
664
675
|
raise cli.ValidationError("--grad-accum-cache-negatives requires --grad-accum-steps greater than 1")
|
|
665
|
-
if args.grad_accum_cache_negatives is True and args.loss
|
|
676
|
+
if args.grad_accum_cache_negatives is True and args.loss != "clip":
|
|
666
677
|
raise cli.ValidationError("--grad-accum-cache-negatives is only supported with --loss clip")
|
|
667
678
|
|
|
668
679
|
if args.coca_caption_loss_weight < 0.0:
|
|
669
680
|
raise cli.ValidationError("--coca-caption-loss-weight must be non-negative")
|
|
670
681
|
if args.coca_contrastive_loss_weight < 0.0:
|
|
671
682
|
raise cli.ValidationError("--coca-contrastive-loss-weight must be non-negative")
|
|
683
|
+
if args.openvision_v2_keep_ratio is not None:
|
|
684
|
+
if args.openvision_v2_keep_ratio <= 0.0 or args.openvision_v2_keep_ratio > 1.0:
|
|
685
|
+
raise cli.ValidationError("--openvision-v2-keep-ratio must be in range of (0, 1]")
|
|
672
686
|
|
|
673
687
|
# EMA
|
|
674
688
|
if args.model_ema_steps < 1:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
import json
|
|
2
3
|
import logging
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import Any
|
|
@@ -11,6 +12,7 @@ from birder.common import training_utils as birder_training_utils
|
|
|
11
12
|
|
|
12
13
|
from birder_clip.common import fs_ops
|
|
13
14
|
from birder_clip.conf import settings
|
|
15
|
+
from birder_clip.version import __version__ as birder_clip_version
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
|
|
@@ -28,6 +30,18 @@ def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
|
|
|
28
30
|
return file_handler
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
def make_training_args_payload(args: argparse.Namespace) -> dict[str, Any]:
|
|
34
|
+
return {
|
|
35
|
+
"birder_clip_version": birder_clip_version,
|
|
36
|
+
**birder_training_utils.make_training_args_payload(args),
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def write_training_args_json(path: Path, args: argparse.Namespace) -> None:
|
|
41
|
+
with open(path.joinpath("training_args.json"), "w", encoding="utf-8") as handle:
|
|
42
|
+
json.dump(make_training_args_payload(args), handle, indent=2)
|
|
43
|
+
|
|
44
|
+
|
|
31
45
|
def save_training_checkpoint(
|
|
32
46
|
args: argparse.Namespace,
|
|
33
47
|
network_name: str,
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CaptionLoss(torch.nn.Module):
|
|
9
|
+
"""
|
|
10
|
+
Autoregressive captioning cross entropy over decoder logits
|
|
11
|
+
|
|
12
|
+
The loss consumes unshifted tokenized captions and supports two decoder output conventions:
|
|
13
|
+
- logits length equals text length: the final logit is ignored.
|
|
14
|
+
- logits length equals text length - 1: logits are used as-is.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
def __init__(
|
|
18
|
+
self,
|
|
19
|
+
caption_loss_weight: float = 1.0,
|
|
20
|
+
pad_token_id: int = 0,
|
|
21
|
+
ignore_token_ids: Optional[Sequence[int]] = None,
|
|
22
|
+
label_smoothing: float = 0.0,
|
|
23
|
+
) -> None:
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.caption_loss_weight = caption_loss_weight
|
|
26
|
+
self.ignore_token_ids = tuple(ignore_token_ids) if ignore_token_ids is not None else (pad_token_id,)
|
|
27
|
+
self.label_smoothing = label_smoothing
|
|
28
|
+
|
|
29
|
+
def _align_logits_and_targets(self, logits: torch.Tensor, texts: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
30
|
+
if logits.size(1) == texts.size(1):
|
|
31
|
+
return logits[:, :-1], texts[:, 1:]
|
|
32
|
+
if logits.size(1) == texts.size(1) - 1:
|
|
33
|
+
return logits, texts[:, 1:]
|
|
34
|
+
|
|
35
|
+
raise ValueError(
|
|
36
|
+
"Expected logits sequence length to equal text sequence length or text sequence length - 1, "
|
|
37
|
+
f"got logits={logits.size(1)}, texts={texts.size(1)}"
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
def _align_target_mask(self, target_mask: torch.Tensor, texts: torch.Tensor) -> torch.Tensor:
|
|
41
|
+
if target_mask.shape == texts.shape:
|
|
42
|
+
target_mask = target_mask[:, 1:]
|
|
43
|
+
|
|
44
|
+
return target_mask
|
|
45
|
+
|
|
46
|
+
def forward(
|
|
47
|
+
self, logits: torch.Tensor, texts: torch.Tensor, target_mask: Optional[torch.Tensor] = None
|
|
48
|
+
) -> dict[str, torch.Tensor]:
|
|
49
|
+
if self.caption_loss_weight == 0.0:
|
|
50
|
+
return {"caption_loss": logits.new_zeros(())}
|
|
51
|
+
|
|
52
|
+
logits, targets = self._align_logits_and_targets(logits, texts)
|
|
53
|
+
token_loss = F.cross_entropy(
|
|
54
|
+
logits.permute(0, 2, 1),
|
|
55
|
+
targets,
|
|
56
|
+
reduction="none",
|
|
57
|
+
label_smoothing=self.label_smoothing,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
loss_mask = torch.ones_like(targets, dtype=torch.bool)
|
|
61
|
+
for token_id in self.ignore_token_ids:
|
|
62
|
+
loss_mask = loss_mask & (targets != token_id)
|
|
63
|
+
|
|
64
|
+
if target_mask is not None:
|
|
65
|
+
target_mask = self._align_target_mask(target_mask, texts)
|
|
66
|
+
loss_mask = loss_mask & target_mask.to(dtype=torch.bool)
|
|
67
|
+
|
|
68
|
+
loss_mask_float = loss_mask.to(dtype=token_loss.dtype)
|
|
69
|
+
caption_loss = (token_loss * loss_mask_float).sum() / loss_mask_float.sum().clamp_min(1)
|
|
70
|
+
caption_loss = caption_loss * self.caption_loss_weight
|
|
71
|
+
|
|
72
|
+
return {"caption_loss": caption_loss}
|
|
@@ -26,9 +26,7 @@ class CoCaLoss(torch.nn.Module):
|
|
|
26
26
|
captioning cross entropy over decoder logits.
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
|
-
def __init__(
|
|
30
|
-
self, *, caption_loss_weight: float = 1.0, clip_loss_weight: float = 1.0, pad_token_id: int = 0
|
|
31
|
-
) -> None:
|
|
29
|
+
def __init__(self, caption_loss_weight: float = 1.0, clip_loss_weight: float = 1.0, pad_token_id: int = 0) -> None:
|
|
32
30
|
super().__init__()
|
|
33
31
|
self.caption_loss_weight = caption_loss_weight
|
|
34
32
|
self.clip_loss_weight = clip_loss_weight
|
|
@@ -11,8 +11,8 @@ FormatInfoType = TypedDict(
|
|
|
11
11
|
{"file_size": float, "sha256": str},
|
|
12
12
|
)
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
"
|
|
14
|
+
ImageEncoderInfoType = TypedDict(
|
|
15
|
+
"ImageEncoderInfoType",
|
|
16
16
|
{
|
|
17
17
|
"network": str,
|
|
18
18
|
"config": NotRequired[dict[str, Any]],
|
|
@@ -21,16 +21,33 @@ EncoderInfoType = TypedDict(
|
|
|
21
21
|
"size": NotRequired[tuple[int, int]],
|
|
22
22
|
},
|
|
23
23
|
)
|
|
24
|
+
TextEncoderInfoType = TypedDict(
|
|
25
|
+
"TextEncoderInfoType",
|
|
26
|
+
{
|
|
27
|
+
"network": str,
|
|
28
|
+
"config": NotRequired[dict[str, Any]],
|
|
29
|
+
},
|
|
30
|
+
)
|
|
31
|
+
TextDecoderInfoType = TypedDict(
|
|
32
|
+
"TextDecoderInfoType",
|
|
33
|
+
{
|
|
34
|
+
"network": str,
|
|
35
|
+
"config": NotRequired[dict[str, Any]],
|
|
36
|
+
},
|
|
37
|
+
)
|
|
24
38
|
|
|
25
|
-
|
|
39
|
+
ImageEncoderMetadataType: TypeAlias = str | ImageEncoderInfoType
|
|
40
|
+
TextEncoderMetadataType: TypeAlias = str | TextEncoderInfoType
|
|
41
|
+
TextDecoderMetadataType: TypeAlias = str | TextDecoderInfoType
|
|
26
42
|
|
|
27
43
|
NetworkInfoType = TypedDict(
|
|
28
44
|
"NetworkInfoType",
|
|
29
45
|
{
|
|
30
46
|
"network": str,
|
|
31
47
|
"tag": NotRequired[str],
|
|
32
|
-
"image_encoder": NotRequired[
|
|
33
|
-
"text_encoder": NotRequired[
|
|
48
|
+
"image_encoder": NotRequired[ImageEncoderMetadataType],
|
|
49
|
+
"text_encoder": NotRequired[TextEncoderMetadataType],
|
|
50
|
+
"text_decoder": NotRequired[TextDecoderMetadataType],
|
|
34
51
|
"embed_dim": NotRequired[int],
|
|
35
52
|
"tokenizer": NotRequired[str],
|
|
36
53
|
},
|
{birder_clip-0.0.2.dev7 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/model_registry.py
RENAMED
|
@@ -170,6 +170,11 @@ class ModelRegistry:
|
|
|
170
170
|
if context_length is not None:
|
|
171
171
|
return context_length # type: ignore[no-any-return]
|
|
172
172
|
|
|
173
|
+
decoder_config = config.get("decoder", {})
|
|
174
|
+
context_length = decoder_config.get("context_length")
|
|
175
|
+
if context_length is not None:
|
|
176
|
+
return context_length # type: ignore[no-any-return]
|
|
177
|
+
|
|
173
178
|
if tokenizer is None:
|
|
174
179
|
tokenizer = config.get("tokenizer")
|
|
175
180
|
|