birder-clip 0.0.2.dev6__tar.gz → 0.0.2.dev8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/PKG-INFO +4 -4
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/README.md +1 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/fs_ops.py +31 -11
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/lib.py +25 -5
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/training_cli.py +16 -2
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/training_utils.py +14 -0
- birder_clip-0.0.2.dev8/birder_clip/inference/data_parallel.py +118 -0
- birder_clip-0.0.2.dev8/birder_clip/inference/image_embeddings.py +63 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot.py +40 -33
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/loss/__init__.py +2 -0
- birder_clip-0.0.2.dev8/birder_clip/loss/caption.py +72 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/loss/coca.py +1 -3
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/manifest.py +22 -5
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/model_registry.py +5 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/__init__.py +2 -0
- birder_clip-0.0.2.dev8/birder_clip/net/base.py +201 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/clip.py +50 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/coca.py +133 -3
- birder_clip-0.0.2.dev8/birder_clip/net/openvision_v2.py +219 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/__init__.py +3 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/base.py +40 -19
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/conditioned_decoder.py +4 -3
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/encoder.py +22 -3
- birder_clip-0.0.2.dev8/birder_clip/net/text/visual_causal_decoder.py +195 -0
- birder_clip-0.0.2.dev8/birder_clip/scripts/__main__.py +25 -0
- birder_clip-0.0.2.dev8/birder_clip/scripts/embed_images.py +447 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/scripts/train.py +58 -9
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/scripts/zero_shot.py +42 -7
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/openvision.py +23 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/convert_model.py +186 -9
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/model_info.py +22 -2
- birder_clip-0.0.2.dev8/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/PKG-INFO +4 -4
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/SOURCES.txt +8 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/requires.txt +2 -2
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/requirements/_requirements-dev.txt +1 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/requirements/requirements.txt +1 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_common.py +25 -0
- birder_clip-0.0.2.dev8/tests/test_inference.py +143 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_loss.py +60 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_model_registry.py +25 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_net.py +342 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_net_text.py +44 -0
- birder_clip-0.0.2.dev6/birder_clip/net/base.py +0 -77
- birder_clip-0.0.2.dev6/birder_clip/net/text/prefix_decoder.py +0 -1
- birder_clip-0.0.2.dev6/birder_clip/version.py +0 -1
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/LICENSE +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/conf/settings.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/csv.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/fake.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/data/datasets/webdataset.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/loss/contrastive.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/net/text/hf.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/py.typed +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/base.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/hf.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/registry.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tokenizers/simple_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/__init__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/__main__.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/download_tokenizer.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/list_models.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/show_iterator.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip/tools/stats.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_datasets.py +0 -0
- {birder_clip-0.0.2.dev6 → birder_clip-0.0.2.dev8}/tests/test_tokenizers.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder_clip
|
|
3
|
-
Version: 0.0.2.
|
|
3
|
+
Version: 0.0.2.dev8
|
|
4
4
|
Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
|
|
|
24
24
|
Requires-Python: >=3.11
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
License-File: LICENSE
|
|
27
|
-
Requires-Dist: birder>=0.6.
|
|
27
|
+
Requires-Dist: birder>=0.6.2
|
|
28
28
|
Requires-Dist: ftfy>=6.3.1
|
|
29
29
|
Requires-Dist: regex>=2025.7.29
|
|
30
30
|
Requires-Dist: tqdm>=4.67.0
|
|
@@ -38,7 +38,7 @@ Requires-Dist: bandit~=1.9.4; extra == "dev"
|
|
|
38
38
|
Requires-Dist: black~=26.5.0; extra == "dev"
|
|
39
39
|
Requires-Dist: build~=1.5.0; extra == "dev"
|
|
40
40
|
Requires-Dist: bumpver~=2026.1132; extra == "dev"
|
|
41
|
-
Requires-Dist: coverage~=7.14.
|
|
41
|
+
Requires-Dist: coverage~=7.14.3; extra == "dev"
|
|
42
42
|
Requires-Dist: debugpy; extra == "dev"
|
|
43
43
|
Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
|
|
44
44
|
Requires-Dist: flake8~=7.3.0; extra == "dev"
|
|
@@ -85,7 +85,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
|
|
|
85
85
|
1. Ensure your environment meets the minimum requirements:
|
|
86
86
|
- Python 3.11 or newer
|
|
87
87
|
- PyTorch 2.10 or newer (installed for your hardware/driver stack)
|
|
88
|
-
- Birder 0.6.
|
|
88
|
+
- Birder 0.6.2 or newer
|
|
89
89
|
|
|
90
90
|
1. Install the latest Birder CLIP version:
|
|
91
91
|
|
|
@@ -23,7 +23,7 @@ Full training is supported, but for large-scale CLIP pretraining you are probabl
|
|
|
23
23
|
1. Ensure your environment meets the minimum requirements:
|
|
24
24
|
- Python 3.11 or newer
|
|
25
25
|
- PyTorch 2.10 or newer (installed for your hardware/driver stack)
|
|
26
|
-
- Birder 0.6.
|
|
26
|
+
- Birder 0.6.2 or newer
|
|
27
27
|
|
|
28
28
|
1. Install the latest Birder CLIP version:
|
|
29
29
|
|
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from typing import Any
|
|
7
7
|
from typing import NamedTuple
|
|
8
8
|
from typing import Optional
|
|
9
|
+
from typing import TypeAlias
|
|
9
10
|
|
|
10
11
|
import torch
|
|
11
12
|
from birder.common import cli
|
|
@@ -16,8 +17,10 @@ from birder.data.transforms.classification import inference_preset
|
|
|
16
17
|
from birder_clip.common import lib
|
|
17
18
|
from birder_clip.model_registry import Task
|
|
18
19
|
from birder_clip.model_registry import registry
|
|
19
|
-
from birder_clip.model_registry.manifest import EncoderMetadataType
|
|
20
20
|
from birder_clip.model_registry.manifest import FileFormatType
|
|
21
|
+
from birder_clip.model_registry.manifest import ImageEncoderMetadataType
|
|
22
|
+
from birder_clip.model_registry.manifest import TextDecoderMetadataType
|
|
23
|
+
from birder_clip.model_registry.manifest import TextEncoderMetadataType
|
|
21
24
|
from birder_clip.net.base import BaseNet
|
|
22
25
|
from birder_clip.net.base import SignatureType
|
|
23
26
|
from birder_clip.tokenizers import Tokenizer
|
|
@@ -36,6 +39,8 @@ except ImportError:
|
|
|
36
39
|
|
|
37
40
|
logger = logging.getLogger(__name__)
|
|
38
41
|
|
|
42
|
+
ComponentMetadataType: TypeAlias = ImageEncoderMetadataType | TextEncoderMetadataType | TextDecoderMetadataType
|
|
43
|
+
|
|
39
44
|
|
|
40
45
|
class ModelInfo(NamedTuple):
|
|
41
46
|
signature: SignatureType
|
|
@@ -51,16 +56,18 @@ def write_config(network_name: str, net: BaseNet, signature: SignatureType, rgb_
|
|
|
51
56
|
json.dump(model_config, handle, indent=2)
|
|
52
57
|
|
|
53
58
|
|
|
54
|
-
def
|
|
55
|
-
|
|
59
|
+
def _split_component_metadata(
|
|
60
|
+
component: Optional[ComponentMetadataType],
|
|
61
|
+
) -> tuple[Optional[str], Optional[dict[str, Any]]]:
|
|
62
|
+
if component is None:
|
|
56
63
|
return (None, None)
|
|
57
|
-
if isinstance(
|
|
58
|
-
return (
|
|
64
|
+
if isinstance(component, str):
|
|
65
|
+
return (component, None)
|
|
59
66
|
|
|
60
|
-
if "network" not in
|
|
61
|
-
raise ValueError("
|
|
67
|
+
if "network" not in component:
|
|
68
|
+
raise ValueError("Component metadata must include a 'network' field")
|
|
62
69
|
|
|
63
|
-
return (None,
|
|
70
|
+
return (None, component) # type: ignore[return-value]
|
|
64
71
|
|
|
65
72
|
|
|
66
73
|
def model_path(
|
|
@@ -286,10 +293,12 @@ def load_checkpoint(
|
|
|
286
293
|
tag: Optional[str] = None,
|
|
287
294
|
image_encoder: Optional[str] = None,
|
|
288
295
|
text_encoder: Optional[str] = None,
|
|
296
|
+
text_decoder: Optional[str] = None,
|
|
289
297
|
embed_dim: Optional[int] = None,
|
|
290
298
|
tokenizer: Optional[str] = None,
|
|
291
299
|
image_encoder_config: Optional[dict[str, Any]] = None,
|
|
292
300
|
text_encoder_config: Optional[dict[str, Any]] = None,
|
|
301
|
+
text_decoder_config: Optional[dict[str, Any]] = None,
|
|
293
302
|
epoch: Optional[int] = None,
|
|
294
303
|
new_size: Optional[tuple[int, int]] = None,
|
|
295
304
|
new_context_length: Optional[int] = None,
|
|
@@ -300,6 +309,7 @@ def load_checkpoint(
|
|
|
300
309
|
tag=tag,
|
|
301
310
|
image_encoder=image_encoder,
|
|
302
311
|
text_encoder=text_encoder,
|
|
312
|
+
text_decoder=text_decoder,
|
|
303
313
|
embed_dim=embed_dim,
|
|
304
314
|
tokenizer=tokenizer,
|
|
305
315
|
)
|
|
@@ -329,10 +339,12 @@ def load_checkpoint(
|
|
|
329
339
|
checkpoint_config,
|
|
330
340
|
image_encoder=image_encoder,
|
|
331
341
|
text_encoder=text_encoder,
|
|
342
|
+
text_decoder=text_decoder,
|
|
332
343
|
embed_dim=embed_dim,
|
|
333
344
|
tokenizer=tokenizer,
|
|
334
345
|
image_encoder_config=image_encoder_config,
|
|
335
346
|
text_encoder_config=text_encoder_config,
|
|
347
|
+
text_decoder_config=text_decoder_config,
|
|
336
348
|
input_channels=input_channels,
|
|
337
349
|
image_size=size,
|
|
338
350
|
context_length=context_length,
|
|
@@ -364,10 +376,12 @@ def load_model(
|
|
|
364
376
|
tag: Optional[str] = None,
|
|
365
377
|
image_encoder: Optional[str] = None,
|
|
366
378
|
text_encoder: Optional[str] = None,
|
|
379
|
+
text_decoder: Optional[str] = None,
|
|
367
380
|
embed_dim: Optional[int] = None,
|
|
368
381
|
tokenizer: Optional[str] = None,
|
|
369
382
|
image_encoder_config: Optional[dict[str, Any]] = None,
|
|
370
383
|
text_encoder_config: Optional[dict[str, Any]] = None,
|
|
384
|
+
text_decoder_config: Optional[dict[str, Any]] = None,
|
|
371
385
|
epoch: Optional[int] = None,
|
|
372
386
|
new_size: Optional[tuple[int, int]] = None,
|
|
373
387
|
new_context_length: Optional[int] = None,
|
|
@@ -381,6 +395,7 @@ def load_model(
|
|
|
381
395
|
tag=tag,
|
|
382
396
|
image_encoder=image_encoder,
|
|
383
397
|
text_encoder=text_encoder,
|
|
398
|
+
text_decoder=text_decoder,
|
|
384
399
|
embed_dim=embed_dim,
|
|
385
400
|
tokenizer=tokenizer,
|
|
386
401
|
)
|
|
@@ -420,10 +435,12 @@ def load_model(
|
|
|
420
435
|
checkpoint_config,
|
|
421
436
|
image_encoder=image_encoder,
|
|
422
437
|
text_encoder=text_encoder,
|
|
438
|
+
text_decoder=text_decoder,
|
|
423
439
|
embed_dim=embed_dim,
|
|
424
440
|
tokenizer=tokenizer,
|
|
425
441
|
image_encoder_config=image_encoder_config,
|
|
426
442
|
text_encoder_config=text_encoder_config,
|
|
443
|
+
text_decoder_config=text_decoder_config,
|
|
427
444
|
input_channels=input_channels,
|
|
428
445
|
image_size=size,
|
|
429
446
|
context_length=context_length,
|
|
@@ -525,15 +542,17 @@ def load_pretrained_model(
|
|
|
525
542
|
if model_metadata["task"] != Task.IMAGE_TEXT:
|
|
526
543
|
raise ValueError(f"Unknown model type: {model_metadata['task']}")
|
|
527
544
|
|
|
528
|
-
image_encoder, image_config =
|
|
529
|
-
text_encoder, text_config =
|
|
545
|
+
image_encoder, image_config = _split_component_metadata(model_metadata["net"].get("image_encoder", None))
|
|
546
|
+
text_encoder, text_config = _split_component_metadata(model_metadata["net"].get("text_encoder", None))
|
|
547
|
+
text_decoder, decoder_config = _split_component_metadata(model_metadata["net"].get("text_decoder", None))
|
|
530
548
|
|
|
531
549
|
pretrained_config: dict[str, Any] = {}
|
|
532
550
|
if image_config is not None:
|
|
533
551
|
pretrained_config["image"] = image_config
|
|
534
|
-
|
|
535
552
|
if text_config is not None:
|
|
536
553
|
pretrained_config["text"] = text_config
|
|
554
|
+
if decoder_config is not None:
|
|
555
|
+
pretrained_config["decoder"] = decoder_config
|
|
537
556
|
|
|
538
557
|
if custom_config is not None:
|
|
539
558
|
pretrained_config.update(custom_config)
|
|
@@ -550,6 +569,7 @@ def load_pretrained_model(
|
|
|
550
569
|
tag=model_metadata["net"].get("tag", None),
|
|
551
570
|
image_encoder=image_encoder,
|
|
552
571
|
text_encoder=text_encoder,
|
|
572
|
+
text_decoder=text_decoder,
|
|
553
573
|
embed_dim=model_metadata["net"].get("embed_dim", None),
|
|
554
574
|
tokenizer=model_metadata["net"].get("tokenizer", None),
|
|
555
575
|
inference=inference,
|
|
@@ -3,14 +3,17 @@ from typing import Any
|
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
5
|
from birder.data.transforms.classification import RGBType
|
|
6
|
+
from birder.version import __version__ as birder_version
|
|
6
7
|
|
|
7
8
|
from birder_clip.conf import settings
|
|
8
9
|
from birder_clip.model_registry import registry
|
|
9
10
|
from birder_clip.net.base import BaseNet
|
|
10
11
|
from birder_clip.net.base import SignatureType
|
|
11
|
-
from birder_clip.version import __version__
|
|
12
|
+
from birder_clip.version import __version__ as birder_clip_version
|
|
12
13
|
|
|
13
|
-
MODEL_CONFIG_RESERVED_KEYS = frozenset(
|
|
14
|
+
MODEL_CONFIG_RESERVED_KEYS = frozenset(
|
|
15
|
+
{"image", "text", "decoder", "tokenizer", "embed_dim", "embed-dim", "keep_ratio"}
|
|
16
|
+
)
|
|
14
17
|
|
|
15
18
|
|
|
16
19
|
def get_size_from_signature(signature: SignatureType) -> tuple[int, int]:
|
|
@@ -37,6 +40,7 @@ def get_image_text_network_name(
|
|
|
37
40
|
tag: Optional[str] = None,
|
|
38
41
|
image_encoder: Optional[str] = None,
|
|
39
42
|
text_encoder: Optional[str] = None,
|
|
43
|
+
text_decoder: Optional[str] = None,
|
|
40
44
|
embed_dim: Optional[int] = None,
|
|
41
45
|
tokenizer: Optional[str] = None,
|
|
42
46
|
) -> str:
|
|
@@ -45,6 +49,8 @@ def get_image_text_network_name(
|
|
|
45
49
|
parts.append(image_encoder)
|
|
46
50
|
if text_encoder is not None and text_encoder != "transformer_encoder":
|
|
47
51
|
parts.append(text_encoder)
|
|
52
|
+
if text_decoder is not None and text_decoder != "conditioned_decoder":
|
|
53
|
+
parts.append(text_decoder)
|
|
48
54
|
|
|
49
55
|
if registry.exists(network) is True:
|
|
50
56
|
default_tokenizer = registry.get_default_tokenizer(network)
|
|
@@ -71,10 +77,12 @@ def get_image_text_model_config(
|
|
|
71
77
|
*,
|
|
72
78
|
image_encoder: Optional[str] = None,
|
|
73
79
|
text_encoder: Optional[str] = None,
|
|
80
|
+
text_decoder: Optional[str] = None,
|
|
74
81
|
embed_dim: Optional[int] = None,
|
|
75
82
|
tokenizer: Optional[str] = None,
|
|
76
83
|
image_encoder_config: Optional[dict[str, Any]] = None,
|
|
77
84
|
text_encoder_config: Optional[dict[str, Any]] = None,
|
|
85
|
+
text_decoder_config: Optional[dict[str, Any]] = None,
|
|
78
86
|
input_channels: Optional[int] = None,
|
|
79
87
|
image_size: Optional[tuple[int, int]] = None,
|
|
80
88
|
context_length: Optional[int] = None,
|
|
@@ -86,7 +94,7 @@ def get_image_text_model_config(
|
|
|
86
94
|
|
|
87
95
|
if config is not None:
|
|
88
96
|
for key, value in config.items():
|
|
89
|
-
if key in {"image", "text"} and isinstance(value, dict):
|
|
97
|
+
if key in {"image", "text", "decoder"} and isinstance(value, dict):
|
|
90
98
|
model_config[key] = {**model_config.get(key, {}), **value}
|
|
91
99
|
else:
|
|
92
100
|
model_config[key] = value
|
|
@@ -111,7 +119,7 @@ def get_image_text_model_config(
|
|
|
111
119
|
|
|
112
120
|
model_config["image"] = image_config
|
|
113
121
|
|
|
114
|
-
if text_encoder is not None or text_encoder_config is not None or
|
|
122
|
+
if text_encoder is not None or text_encoder_config is not None or "text" in model_config:
|
|
115
123
|
text_config = model_config.get("text", {}).copy()
|
|
116
124
|
if text_encoder is not None:
|
|
117
125
|
# String encoder metadata replaces only the encoder name.
|
|
@@ -124,6 +132,17 @@ def get_image_text_model_config(
|
|
|
124
132
|
|
|
125
133
|
model_config["text"] = text_config
|
|
126
134
|
|
|
135
|
+
if text_decoder is not None or text_decoder_config is not None or "decoder" in model_config:
|
|
136
|
+
decoder_config = model_config.get("decoder", {}).copy()
|
|
137
|
+
if text_decoder is not None:
|
|
138
|
+
decoder_config["network"] = text_decoder
|
|
139
|
+
if text_decoder_config is not None:
|
|
140
|
+
decoder_config["config"] = {**decoder_config.get("config", {}), **text_decoder_config}
|
|
141
|
+
if context_length is not None:
|
|
142
|
+
decoder_config["context_length"] = context_length
|
|
143
|
+
|
|
144
|
+
model_config["decoder"] = decoder_config
|
|
145
|
+
|
|
127
146
|
if embed_dim is not None:
|
|
128
147
|
model_config["embed_dim"] = embed_dim
|
|
129
148
|
if tokenizer is not None:
|
|
@@ -143,7 +162,8 @@ def get_image_text_network_config(net: BaseNet, signature: SignatureType, rgb_st
|
|
|
143
162
|
model_config = net.config
|
|
144
163
|
|
|
145
164
|
return {
|
|
146
|
-
"birder_clip_version":
|
|
165
|
+
"birder_clip_version": birder_clip_version,
|
|
166
|
+
"birder_version": birder_version,
|
|
147
167
|
"name": model_name,
|
|
148
168
|
"registered_name": registered_name,
|
|
149
169
|
"task": net.task,
|
|
@@ -27,6 +27,7 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
|
27
27
|
help="pretrained Birder image model weights path to load into the image encoder",
|
|
28
28
|
)
|
|
29
29
|
parser.add_argument("--text-encoder", type=str, help="the text encoder to use")
|
|
30
|
+
parser.add_argument("--text-decoder", type=str, help="the text decoder to use")
|
|
30
31
|
parser.add_argument("--embed-dim", type=int, metavar="N", help="shared image-text embedding dimension")
|
|
31
32
|
parser.add_argument("--tokenizer", type=str, help="the tokenizer to use")
|
|
32
33
|
parser.add_argument(
|
|
@@ -44,11 +45,21 @@ def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
|
44
45
|
action=cli.FlexibleDictAction,
|
|
45
46
|
help="override the text encoder configuration, accepts key-value pairs or JSON",
|
|
46
47
|
)
|
|
48
|
+
parser.add_argument(
|
|
49
|
+
"--text-decoder-config",
|
|
50
|
+
action=cli.FlexibleDictAction,
|
|
51
|
+
help="override the text decoder configuration, accepts key-value pairs or JSON",
|
|
52
|
+
)
|
|
53
|
+
parser.add_argument(
|
|
54
|
+
"--openvision-v2-keep-ratio", type=float, help="OpenVision v2 image token keep ratio for caption decoding"
|
|
55
|
+
)
|
|
47
56
|
|
|
48
57
|
|
|
49
58
|
def add_loss_args(parser: argparse.ArgumentParser) -> None:
|
|
50
59
|
group = parser.add_argument_group("Loss parameters")
|
|
51
|
-
group.add_argument(
|
|
60
|
+
group.add_argument(
|
|
61
|
+
"--loss", type=str, choices=["clip", "coca", "caption"], default="clip", help="loss function to use"
|
|
62
|
+
)
|
|
52
63
|
group.add_argument(
|
|
53
64
|
"--coca-caption-loss-weight", type=float, default=1.0, help="weight assigned to CoCa caption loss"
|
|
54
65
|
)
|
|
@@ -662,13 +673,16 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
662
673
|
raise cli.ValidationError("--grad-accum-steps must be >= 1")
|
|
663
674
|
if args.grad_accum_cache_negatives is True and args.grad_accum_steps == 1:
|
|
664
675
|
raise cli.ValidationError("--grad-accum-cache-negatives requires --grad-accum-steps greater than 1")
|
|
665
|
-
if args.grad_accum_cache_negatives is True and args.loss
|
|
676
|
+
if args.grad_accum_cache_negatives is True and args.loss != "clip":
|
|
666
677
|
raise cli.ValidationError("--grad-accum-cache-negatives is only supported with --loss clip")
|
|
667
678
|
|
|
668
679
|
if args.coca_caption_loss_weight < 0.0:
|
|
669
680
|
raise cli.ValidationError("--coca-caption-loss-weight must be non-negative")
|
|
670
681
|
if args.coca_contrastive_loss_weight < 0.0:
|
|
671
682
|
raise cli.ValidationError("--coca-contrastive-loss-weight must be non-negative")
|
|
683
|
+
if args.openvision_v2_keep_ratio is not None:
|
|
684
|
+
if args.openvision_v2_keep_ratio <= 0.0 or args.openvision_v2_keep_ratio > 1.0:
|
|
685
|
+
raise cli.ValidationError("--openvision-v2-keep-ratio must be in range of (0, 1]")
|
|
672
686
|
|
|
673
687
|
# EMA
|
|
674
688
|
if args.model_ema_steps < 1:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
import json
|
|
2
3
|
import logging
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import Any
|
|
@@ -11,6 +12,7 @@ from birder.common import training_utils as birder_training_utils
|
|
|
11
12
|
|
|
12
13
|
from birder_clip.common import fs_ops
|
|
13
14
|
from birder_clip.conf import settings
|
|
15
|
+
from birder_clip.version import __version__ as birder_clip_version
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
|
|
@@ -28,6 +30,18 @@ def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
|
|
|
28
30
|
return file_handler
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
def make_training_args_payload(args: argparse.Namespace) -> dict[str, Any]:
|
|
34
|
+
return {
|
|
35
|
+
"birder_clip_version": birder_clip_version,
|
|
36
|
+
**birder_training_utils.make_training_args_payload(args),
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def write_training_args_json(path: Path, args: argparse.Namespace) -> None:
|
|
41
|
+
with open(path.joinpath("training_args.json"), "w", encoding="utf-8") as handle:
|
|
42
|
+
json.dump(make_training_args_payload(args), handle, indent=2)
|
|
43
|
+
|
|
44
|
+
|
|
31
45
|
def save_training_checkpoint(
|
|
32
46
|
args: argparse.Namespace,
|
|
33
47
|
network_name: str,
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Inference-optimized multi-GPU parallelization for image-text models
|
|
3
|
+
|
|
4
|
+
This module provides ZeroShotInferenceDataParallel, a CLIP-style zero-shot
|
|
5
|
+
specialization of Birder's InferenceDataParallel.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from birder.inference.data_parallel import InferenceDataParallel
|
|
12
|
+
|
|
13
|
+
from birder_clip.inference.zero_shot import ZeroShotInference
|
|
14
|
+
from birder_clip.net.base import BaseNet
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ZeroShotInferenceDataParallel(InferenceDataParallel):
|
|
18
|
+
"""
|
|
19
|
+
Distributes zero-shot image inference batches across multiple GPUs
|
|
20
|
+
|
|
21
|
+
This wrapper scatters the image batch across devices and keeps a replicated
|
|
22
|
+
copy of the zero-shot text embeddings on each device. Each replica computes
|
|
23
|
+
image embeddings and zero-shot logits locally before outputs are gathered.
|
|
24
|
+
|
|
25
|
+
Important
|
|
26
|
+
---------
|
|
27
|
+
This class assumes the model is already configured for inference mode
|
|
28
|
+
(i.e., loaded with inference=True in load_model or manually set to eval mode
|
|
29
|
+
with requires_grad=False on all parameters).
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
module: BaseNet,
|
|
35
|
+
text_embeddings: torch.Tensor,
|
|
36
|
+
device_ids: Optional[list[int]] = None,
|
|
37
|
+
output_device: Optional[int | str | torch.device] = None,
|
|
38
|
+
compile_replicas: bool = False,
|
|
39
|
+
compile_methods: Optional[list[str]] = None,
|
|
40
|
+
compile_mode: Optional[str] = None,
|
|
41
|
+
) -> None:
|
|
42
|
+
if compile_methods is None:
|
|
43
|
+
compile_methods = ["encode_image", "forward_logits"]
|
|
44
|
+
|
|
45
|
+
super().__init__(
|
|
46
|
+
module,
|
|
47
|
+
device_ids=device_ids,
|
|
48
|
+
output_device=output_device,
|
|
49
|
+
compile_replicas=compile_replicas,
|
|
50
|
+
compile_methods=compile_methods,
|
|
51
|
+
compile_mode=compile_mode,
|
|
52
|
+
)
|
|
53
|
+
self.set_text_embeddings(text_embeddings)
|
|
54
|
+
|
|
55
|
+
def set_text_embeddings(self, text_embeddings: torch.Tensor) -> None:
|
|
56
|
+
if text_embeddings.ndim != 2:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"text_embeddings must be a 2D tensor of shape (num_classes, embedding_size), "
|
|
59
|
+
f"got shape {text_embeddings.size()}"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
self.text_embeddings = [
|
|
63
|
+
text_embeddings.to(f"cuda:{device_id}", non_blocking=True) for device_id in self.device_ids
|
|
64
|
+
]
|
|
65
|
+
self.inference_modules = [
|
|
66
|
+
ZeroShotInference(replica, embeddings) for replica, embeddings in zip(self.replicas, self.text_embeddings)
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
def forward( # type: ignore[override] # pylint: disable=arguments-differ
|
|
70
|
+
self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False
|
|
71
|
+
) -> torch.Tensor:
|
|
72
|
+
"""
|
|
73
|
+
Run zero-shot inference distributed across GPUs
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
inputs
|
|
78
|
+
Input image batch to process.
|
|
79
|
+
tta
|
|
80
|
+
Run inference with oversampling.
|
|
81
|
+
return_logits
|
|
82
|
+
If True, return raw logits instead of probabilities after softmax.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
if len(self.device_ids) == 1:
|
|
86
|
+
output = self.inference_modules[0](
|
|
87
|
+
inputs,
|
|
88
|
+
tta=tta,
|
|
89
|
+
return_logits=return_logits,
|
|
90
|
+
)
|
|
91
|
+
return self._gather([output])
|
|
92
|
+
|
|
93
|
+
scattered = self._scatter(inputs, {})
|
|
94
|
+
|
|
95
|
+
outputs = []
|
|
96
|
+
for inference, (input_chunk, _), device_id in zip(self.inference_modules, scattered, self.device_ids):
|
|
97
|
+
if input_chunk is not None and input_chunk.size(0) > 0:
|
|
98
|
+
with torch.cuda.device(device_id):
|
|
99
|
+
output = inference(
|
|
100
|
+
input_chunk,
|
|
101
|
+
tta=tta,
|
|
102
|
+
return_logits=return_logits,
|
|
103
|
+
)
|
|
104
|
+
outputs.append(output)
|
|
105
|
+
else:
|
|
106
|
+
outputs.append(None)
|
|
107
|
+
|
|
108
|
+
return self._gather(outputs)
|
|
109
|
+
|
|
110
|
+
def __repr__(self) -> str:
|
|
111
|
+
return (
|
|
112
|
+
f"ZeroShotInferenceDataParallel(\n"
|
|
113
|
+
f" devices={self.device_ids},\n"
|
|
114
|
+
f" output_device={self.output_device},\n"
|
|
115
|
+
f" src_device={self.src_device},\n"
|
|
116
|
+
f" text_embeddings_shape={tuple(self.text_embeddings[0].shape)}\n"
|
|
117
|
+
f")"
|
|
118
|
+
)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import numpy.typing as npt
|
|
7
|
+
import torch
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from birder_clip.net.base import BaseNet
|
|
12
|
+
|
|
13
|
+
DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def infer_dataloader_iter(
|
|
17
|
+
device: torch.device,
|
|
18
|
+
net: BaseNet,
|
|
19
|
+
dataloader: DataLoader,
|
|
20
|
+
model_dtype: torch.dtype = torch.float32,
|
|
21
|
+
amp: bool = False,
|
|
22
|
+
amp_dtype: Optional[torch.dtype] = None,
|
|
23
|
+
num_samples: Optional[int] = None,
|
|
24
|
+
chunk_size: Optional[float] = None,
|
|
25
|
+
) -> Iterator[DataloaderInferenceResult]:
|
|
26
|
+
if chunk_size is None:
|
|
27
|
+
chunk_size = float("inf")
|
|
28
|
+
|
|
29
|
+
net.to(device, dtype=model_dtype)
|
|
30
|
+
embeddings_list: list[npt.NDArray[np.float32]] = []
|
|
31
|
+
sample_paths: list[str] = []
|
|
32
|
+
sample_count = 0
|
|
33
|
+
with tqdm(total=num_samples, initial=0, unit="images", unit_scale=True, leave=False) as progress:
|
|
34
|
+
for file_paths, inputs, _targets in dataloader:
|
|
35
|
+
batch_size = inputs.size(0)
|
|
36
|
+
|
|
37
|
+
# Inference
|
|
38
|
+
inputs = inputs.to(device, dtype=model_dtype)
|
|
39
|
+
with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
|
|
40
|
+
embeddings = net.encode_image(inputs, normalize=True)
|
|
41
|
+
embeddings = embeddings.cpu().float().numpy()
|
|
42
|
+
|
|
43
|
+
embeddings_list.append(embeddings)
|
|
44
|
+
|
|
45
|
+
# Set sample list
|
|
46
|
+
sample_paths.extend(file_paths)
|
|
47
|
+
|
|
48
|
+
# Update progress bar
|
|
49
|
+
progress.update(n=batch_size)
|
|
50
|
+
|
|
51
|
+
# Yield results when we reach chunk_size
|
|
52
|
+
sample_count += batch_size
|
|
53
|
+
if sample_count >= chunk_size:
|
|
54
|
+
with tqdm.external_write_mode(file=sys.stderr):
|
|
55
|
+
yield (sample_paths, np.concatenate(embeddings_list, axis=0))
|
|
56
|
+
|
|
57
|
+
# Reset for next chunk
|
|
58
|
+
embeddings_list = []
|
|
59
|
+
sample_paths = []
|
|
60
|
+
sample_count = 0
|
|
61
|
+
|
|
62
|
+
if len(embeddings_list) > 0:
|
|
63
|
+
yield (sample_paths, np.concatenate(embeddings_list, axis=0))
|
|
@@ -13,6 +13,7 @@ from collections.abc import Callable
|
|
|
13
13
|
from collections.abc import Iterator
|
|
14
14
|
from collections.abc import Sequence
|
|
15
15
|
from typing import Optional
|
|
16
|
+
from typing import Protocol
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
import numpy.typing as npt
|
|
@@ -27,6 +28,42 @@ from birder_clip.net.base import BaseNet
|
|
|
27
28
|
from birder_clip.tokenizers.base import Tokenizer
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
class ZeroShotInferenceModule(Protocol):
    """Structural type for any callable that maps an image batch to per-class scores."""

    def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor:
        """Score ``inputs``; the keyword flags select test-time augmentation and raw logits (see ZeroShotInference)."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ZeroShotInference:
    """
    Callable that scores image batches against a fixed set of text embeddings.

    Its ``__call__`` signature matches ZeroShotInferenceModule, so an instance can be passed wherever that
    protocol is expected.
    """

    def __init__(self, net: BaseNet, text_embeddings: torch.Tensor) -> None:
        self.net = net
        self.text_embeddings = text_embeddings

    def __call__(self, inputs: torch.Tensor, *, tta: bool = False, return_logits: bool = False) -> torch.Tensor:
        """
        Score a batch of images against the stored text embeddings.

        Parameters
        ----------
        inputs
            Image batch, moved to the text embeddings' device. The TTA path unpacks its size as (_, _, H, W).
        tta
            When True, average predictions over five crops of 80% height/width, each resized back to (H, W)
            with bicubic, antialiased interpolation.
        return_logits
            When True, return the raw ``forward_logits`` output instead of a softmax over the last dimension.

        Returns
        -------
        Logits or probabilities; with tta, the mean over the five crops.
        """

        def score(batch: torch.Tensor) -> torch.Tensor:
            image_embeddings = self.net.encode_image(batch, normalize=True)
            logits = self.net.forward_logits(image_embeddings, self.text_embeddings)
            return logits if return_logits is True else F.softmax(logits, dim=-1)

        # The images follow the text embeddings onto their device
        batch = inputs.to(self.text_embeddings.device, non_blocking=True)
        if tta is not True:
            return score(batch)

        # Test-time augmentation: five 80% crops, each resized back to the original resolution
        _, _, height, width = batch.size()
        crops = five_crop(batch, size=[int(height * 0.8), int(width * 0.8)])
        resize = v2.Resize((height, width), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
        return torch.stack([score(resize(crop)) for crop in crops]).mean(dim=0)
|
|
65
|
+
|
|
66
|
+
|
|
30
67
|
def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
    """
    Fill every template with every class name.

    Output is grouped by class: all templates for class_names[0] (in template order), then all templates
    for class_names[1], and so on.
    """

    prompts: list[str] = []
    for name in class_names:
        prompts.extend(template.format(name) for template in templates)

    return prompts
|
|
32
69
|
|
|
@@ -66,39 +103,10 @@ def build_class_text_embeddings(
|
|
|
66
103
|
# Per-chunk zero-shot result: sample paths, float32 model outputs (logits or probabilities, per return_logits),
# and int64 labels. NOTE(review): field order assumed to match the yield in infer_dataloader_iter — confirm.
DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]]
|
|
67
104
|
|
|
68
105
|
|
|
69
|
-
def infer_batch(
|
|
70
|
-
net: BaseNet, inputs: torch.Tensor, text_embeddings: torch.Tensor, tta: bool = False, return_logits: bool = False
|
|
71
|
-
) -> torch.Tensor:
|
|
72
|
-
if tta is True:
|
|
73
|
-
_, _, H, W = inputs.size()
|
|
74
|
-
crop_h = int(H * 0.8)
|
|
75
|
-
crop_w = int(W * 0.8)
|
|
76
|
-
tta_inputs = five_crop(inputs, size=[crop_h, crop_w])
|
|
77
|
-
t = v2.Resize((H, W), interpolation=v2.InterpolationMode.BICUBIC, antialias=True)
|
|
78
|
-
outs = []
|
|
79
|
-
for tta_input in tta_inputs:
|
|
80
|
-
image_embeddings = net.encode_image(t(tta_input), normalize=True)
|
|
81
|
-
logits = net.forward_logits(image_embeddings, text_embeddings)
|
|
82
|
-
if return_logits is True:
|
|
83
|
-
outs.append(logits)
|
|
84
|
-
else:
|
|
85
|
-
outs.append(F.softmax(logits, dim=-1))
|
|
86
|
-
|
|
87
|
-
return torch.stack(outs).mean(dim=0)
|
|
88
|
-
|
|
89
|
-
image_embeddings = net.encode_image(inputs, normalize=True)
|
|
90
|
-
logits = net.forward_logits(image_embeddings, text_embeddings)
|
|
91
|
-
if return_logits is True:
|
|
92
|
-
return logits
|
|
93
|
-
|
|
94
|
-
return F.softmax(logits, dim=-1)
|
|
95
|
-
|
|
96
|
-
|
|
97
106
|
def infer_dataloader_iter(
|
|
98
107
|
device: torch.device,
|
|
99
|
-
net:
|
|
108
|
+
net: ZeroShotInferenceModule,
|
|
100
109
|
dataloader: DataLoader,
|
|
101
|
-
text_embeddings: torch.Tensor,
|
|
102
110
|
tta: bool = False,
|
|
103
111
|
return_logits: bool = False,
|
|
104
112
|
model_dtype: torch.dtype = torch.float32,
|
|
@@ -111,7 +119,6 @@ def infer_dataloader_iter(
|
|
|
111
119
|
if chunk_size is None:
|
|
112
120
|
chunk_size = float("inf")
|
|
113
121
|
|
|
114
|
-
net.to(device, dtype=model_dtype)
|
|
115
122
|
out_list: list[npt.NDArray[np.float32]] = []
|
|
116
123
|
labels_list: list[npt.NDArray[np.int64]] = []
|
|
117
124
|
sample_paths: list[str] = []
|
|
@@ -121,9 +128,9 @@ def infer_dataloader_iter(
|
|
|
121
128
|
batch_size = inputs.size(0)
|
|
122
129
|
|
|
123
130
|
# Inference
|
|
124
|
-
inputs = inputs.to(
|
|
131
|
+
inputs = inputs.to(dtype=model_dtype)
|
|
125
132
|
with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
|
|
126
|
-
out =
|
|
133
|
+
out = net(inputs, return_logits=return_logits, tta=tta)
|
|
127
134
|
out = out.cpu().float().numpy()
|
|
128
135
|
|
|
129
136
|
out_list.append(out)
|