birder-clip 0.0.2.dev4__tar.gz → 0.0.2.dev5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/PKG-INFO +3 -3
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/fs_ops.py +100 -3
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/lib.py +1 -1
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/training_cli.py +68 -2
- birder_clip-0.0.2.dev5/birder_clip/common/training_utils.py +99 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/webdataset.py +2 -2
- birder_clip-0.0.2.dev5/birder_clip/inference/zero_shot.py +128 -0
- birder_clip-0.0.2.dev5/birder_clip/net/clip.py +610 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/text/base.py +1 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/text/transformer.py +2 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/scripts/train.py +63 -12
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/scripts/zero_shot.py +249 -180
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/__init__.py +2 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/base.py +3 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/hf.py +13 -0
- birder_clip-0.0.2.dev5/birder_clip/tokenizers/openvision.py +64 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/simple_tokenizer.py +1 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/__main__.py +13 -0
- birder_clip-0.0.2.dev5/birder_clip/tools/convert_model.py +268 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/download_tokenizer.py +5 -4
- birder_clip-0.0.2.dev5/birder_clip/tools/list_models.py +102 -0
- birder_clip-0.0.2.dev5/birder_clip/tools/model_info.py +145 -0
- birder_clip-0.0.2.dev5/birder_clip/tools/stats.py +210 -0
- birder_clip-0.0.2.dev5/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/PKG-INFO +3 -3
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/SOURCES.txt +5 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/requires.txt +2 -2
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/requirements/_requirements-dev.txt +1 -1
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/requirements/requirements.txt +1 -1
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_net.py +1 -1
- birder_clip-0.0.2.dev4/birder_clip/common/training_utils.py +0 -61
- birder_clip-0.0.2.dev4/birder_clip/inference/zero_shot.py +0 -54
- birder_clip-0.0.2.dev4/birder_clip/net/clip.py +0 -282
- birder_clip-0.0.2.dev4/birder_clip/version.py +0 -1
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/LICENSE +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/README.md +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/conf/settings.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/csv.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/fake.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/loss/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/loss/contrastive.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/manifest.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/model_registry.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/base.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/net/text/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/py.typed +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/registry.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/__init__.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip/tools/show_iterator.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_common.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_datasets.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_loss.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_model_registry.py +0 -0
- {birder_clip-0.0.2.dev4 → birder_clip-0.0.2.dev5}/tests/test_tokenizers.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder_clip
|
|
3
|
-
Version: 0.0.2.dev4
|
|
3
|
+
Version: 0.0.2.dev5
|
|
4
4
|
Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
|
|
|
24
24
|
Requires-Python: >=3.11
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
License-File: LICENSE
|
|
27
|
-
Requires-Dist: birder>=0.5.
|
|
27
|
+
Requires-Dist: birder>=0.5.6
|
|
28
28
|
Requires-Dist: ftfy>=6.3.1
|
|
29
29
|
Requires-Dist: regex>=2025.7.29
|
|
30
30
|
Requires-Dist: tqdm>=4.67.0
|
|
@@ -37,7 +37,7 @@ Provides-Extra: dev
|
|
|
37
37
|
Requires-Dist: bandit~=1.9.4; extra == "dev"
|
|
38
38
|
Requires-Dist: black~=26.5.0; extra == "dev"
|
|
39
39
|
Requires-Dist: build~=1.5.0; extra == "dev"
|
|
40
|
-
Requires-Dist: bumpver~=
|
|
40
|
+
Requires-Dist: bumpver~=2026.1132; extra == "dev"
|
|
41
41
|
Requires-Dist: coverage~=7.14.1; extra == "dev"
|
|
42
42
|
Requires-Dist: debugpy; extra == "dev"
|
|
43
43
|
Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
|
|
@@ -67,7 +67,7 @@ def model_path(
|
|
|
67
67
|
network_name: str,
|
|
68
68
|
*,
|
|
69
69
|
epoch: Optional[int | str] = None,
|
|
70
|
-
|
|
70
|
+
st: bool = False,
|
|
71
71
|
states: bool = False,
|
|
72
72
|
) -> Path:
|
|
73
73
|
if epoch is not None:
|
|
@@ -77,8 +77,10 @@ def model_path(
|
|
|
77
77
|
|
|
78
78
|
if states is True:
|
|
79
79
|
file_name = f"{file_name}_states.pt"
|
|
80
|
+
elif st is True:
|
|
81
|
+
file_name = f"{file_name}.safetensors"
|
|
80
82
|
else:
|
|
81
|
-
file_name = f"{file_name}.
|
|
83
|
+
file_name = f"{file_name}.pt"
|
|
82
84
|
|
|
83
85
|
return settings.MODELS_DIR.joinpath(file_name)
|
|
84
86
|
|
|
@@ -109,6 +111,30 @@ def _checkpoint_states(
|
|
|
109
111
|
torch.save(kwargs, states_path)
|
|
110
112
|
|
|
111
113
|
|
|
114
|
+
def _checkpoint_states_from_state_dicts(
    states_path: Path,
    optimizer_state: Optional[dict[str, Any]],
    scheduler_state: Optional[dict[str, Any]],
    scaler_state: Optional[dict[str, Any]],
    model_base_state: Optional[dict[str, Any]],
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Write already-materialized training states to *states_path* with torch.save.

    The saved mapping holds "optimizer_state", "scheduler_state", "scaler_state" and
    "model_base_state", followed by any extra keyword states.

    Nothing is written when either the optimizer state or the scheduler state is missing.
    Without them the run cannot be resumed, so the other states are not saved on their own.
    """

    if optimizer_state is not None and scheduler_state is not None:
        logger.info(f"Saving checkpoint states {states_path}...")

        states: dict[str, Any] = {
            "optimizer_state": optimizer_state,
            "scheduler_state": scheduler_state,
            "scaler_state": scaler_state,
            "model_base_state": model_base_state,
        }
        states.update(extra_states)
        torch.save(states, states_path)
|
|
136
|
+
|
|
137
|
+
|
|
112
138
|
class TrainingStates(NamedTuple):
|
|
113
139
|
optimizer_state: Optional[dict[str, Any]]
|
|
114
140
|
scheduler_state: Optional[dict[str, Any]]
|
|
@@ -182,6 +208,50 @@ def checkpoint_model(
|
|
|
182
208
|
_checkpoint_states(states_path, optimizer, scheduler, scaler, model_base, **extra_states)
|
|
183
209
|
|
|
184
210
|
|
|
211
|
+
def checkpoint_model_from_state_dicts(
    network_name: str,
    epoch: int,
    model_state: dict[str, Any],
    task: Any,
    signature: SignatureType,
    rgb_stats: RGBType,
    optimizer_state: Optional[dict[str, Any]],
    scheduler_state: Optional[dict[str, Any]],
    scaler_state: Optional[dict[str, Any]],
    model_base_state: Optional[dict[str, Any]],
    *,
    external_config: Optional[dict[str, Any]] = None,
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Save a model checkpoint and its training states from plain state dicts.

    This is the state-dict variant of checkpoint_model. The caller passes state dicts it has
    already materialized rather than live modules/optimizers.

    The model file (model_path(network_name, epoch=epoch)) stores the model state together
    with the package version, task, signature and RGB stats. It also stores "config" when an
    external config is given.

    The training states go to the matching "_states" file. They are written only when both
    the optimizer and scheduler states are present.
    """

    config_entry = {} if external_config is None else {"config": external_config}
    checkpoint: dict[str, Any] = {
        "state": model_state,
        "birder_clip_version": __version__,
        "task": task,
        "signature": signature,
        "rgb_stats": rgb_stats,
        **config_entry,
    }

    model_file = model_path(network_name, epoch=epoch)
    states_file = model_path(network_name, epoch=epoch, states=True)

    logger.info(f"Saving model checkpoint {model_file}...")
    torch.save(checkpoint, model_file)

    _checkpoint_states_from_state_dicts(
        states_file, optimizer_state, scheduler_state, scaler_state, model_base_state, **extra_states
    )
|
|
253
|
+
|
|
254
|
+
|
|
185
255
|
def clean_checkpoints(network_name: str, keep_last: int) -> None:
|
|
186
256
|
epoch = "*[0-9]"
|
|
187
257
|
models_glob = str(model_path(network_name, epoch=epoch))
|
|
@@ -314,7 +384,7 @@ def load_model(
|
|
|
314
384
|
embed_dim=embed_dim,
|
|
315
385
|
tokenizer=tokenizer,
|
|
316
386
|
)
|
|
317
|
-
path = model_path(_network_name, epoch=epoch,
|
|
387
|
+
path = model_path(_network_name, epoch=epoch, st=st)
|
|
318
388
|
|
|
319
389
|
logger.info(f"Loading model from {path} on device {device}...")
|
|
320
390
|
|
|
@@ -589,6 +659,33 @@ def load_pretrained_tokenizer(weights: str, *, download: bool = True, **kwargs:
|
|
|
589
659
|
return get_tokenizer(tokenizer_name, **tokenizer_kwargs)
|
|
590
660
|
|
|
591
661
|
|
|
662
|
+
def save_st(
    net: torch.nn.Module,
    dst: str,
    task: str,
    signature: SignatureType,
    rgb_stats: RGBType,
    *,
    external_config: Optional[dict[str, Any]] = None,
) -> None:
    """
    Save the weights of *net* to a .safetensors file at *dst*.

    Safetensors metadata values are strings, so the structured fields are stored as JSON
    strings: signature, RGB stats and the optional external config ("config").
    The task and package version are stored as-is.

    Raises AssertionError when the safetensors package is not available.
    NOTE(review): asserts are stripped under ``python -O``. The check follows the existing
    dependency-check style but would then be skipped.
    """

    assert _HAS_SAFETENSORS, "'pip install safetensors' to use .safetensors"

    # Encode the optional config first, then the fixed fields
    config_entry = {} if external_config is None else {"config": json.dumps(external_config)}
    metadata = {
        "birder_clip_version": __version__,
        "task": task,
        "signature": json.dumps(signature),
        "rgb_stats": json.dumps(rgb_stats),
        **config_entry,
    }

    safetensors.torch.save_model(net, str(dst), metadata)
|
|
687
|
+
|
|
688
|
+
|
|
592
689
|
def download_model_by_weights(
|
|
593
690
|
weights: str, *, dst: Optional[str | Path] = None, file_format: FileFormatType = "pt", progress_bar: bool = True
|
|
594
691
|
) -> None:
|
|
@@ -86,7 +86,7 @@ def get_image_text_model_config(
|
|
|
86
86
|
|
|
87
87
|
if config is not None:
|
|
88
88
|
for key, value in config.items():
|
|
89
|
-
if key in {"image", "text"} and isinstance(value, dict)
|
|
89
|
+
if key in {"image", "text"} and isinstance(value, dict):
|
|
90
90
|
model_config[key] = {**model_config.get(key, {}), **value}
|
|
91
91
|
else:
|
|
92
92
|
model_config[key] = value
|
|
@@ -49,7 +49,9 @@ def add_loss_args(parser: argparse.ArgumentParser) -> None:
|
|
|
49
49
|
def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: int = 32) -> None:
|
|
50
50
|
group = parser.add_argument_group("Optimization parameters")
|
|
51
51
|
group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
|
|
52
|
-
group.add_argument(
|
|
52
|
+
group.add_argument(
|
|
53
|
+
"--opt", type=str, choices=list(get_args(OptimizerType)), default="adamw", help="optimizer to use"
|
|
54
|
+
)
|
|
53
55
|
group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
|
|
54
56
|
group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
|
|
55
57
|
group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
|
|
@@ -318,7 +320,13 @@ def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
|
|
|
318
320
|
action="store_true",
|
|
319
321
|
help="keep dataloader worker processes alive between epochs",
|
|
320
322
|
)
|
|
321
|
-
group.add_argument(
|
|
323
|
+
group.add_argument(
|
|
324
|
+
"--no-drop-last",
|
|
325
|
+
dest="drop_last",
|
|
326
|
+
default=True,
|
|
327
|
+
action="store_false",
|
|
328
|
+
help="do not drop the last incomplete batch",
|
|
329
|
+
)
|
|
322
330
|
|
|
323
331
|
|
|
324
332
|
def add_precision_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -410,6 +418,44 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
|
|
410
418
|
group.add_argument("--local-rank", type=int, metavar="N", help="local rank")
|
|
411
419
|
group.add_argument("--dist-url", type=str, default="env://", help="URL used to initialize distributed training")
|
|
412
420
|
group.add_argument("--dist-backend", type=str, default="nccl", help="distributed backend")
|
|
421
|
+
group.add_argument(
|
|
422
|
+
"--distributed-mode", type=str, choices=["ddp", "fsdp"], default="ddp", help="distributed training mode"
|
|
423
|
+
)
|
|
424
|
+
group.add_argument(
|
|
425
|
+
"--fsdp-sharding-strategy",
|
|
426
|
+
type=str,
|
|
427
|
+
choices=["shard-grad-op", "full-shard"],
|
|
428
|
+
default="shard-grad-op",
|
|
429
|
+
help="FSDP sharding strategy",
|
|
430
|
+
)
|
|
431
|
+
group.add_argument(
|
|
432
|
+
"--fsdp-param-dtype",
|
|
433
|
+
type=str,
|
|
434
|
+
choices=["float32", "float16", "bfloat16"],
|
|
435
|
+
help="FSDP mixed precision parameter dtype",
|
|
436
|
+
)
|
|
437
|
+
group.add_argument(
|
|
438
|
+
"--fsdp-reduce-dtype",
|
|
439
|
+
type=str,
|
|
440
|
+
choices=["float32", "float16", "bfloat16"],
|
|
441
|
+
help="FSDP mixed precision gradient reduction dtype",
|
|
442
|
+
)
|
|
443
|
+
group.add_argument(
|
|
444
|
+
"--fsdp-wrap-policy",
|
|
445
|
+
type=str,
|
|
446
|
+
choices=["block-group-regex", "min-num-params"],
|
|
447
|
+
default="block-group-regex",
|
|
448
|
+
help="FSDP module wrapping policy",
|
|
449
|
+
)
|
|
450
|
+
group.add_argument(
|
|
451
|
+
"--fsdp-wrap-min-num-params",
|
|
452
|
+
type=float,
|
|
453
|
+
metavar="M",
|
|
454
|
+
help="minimum module parameter count in millions for wrapping when using --fsdp-wrap-policy min-num-params",
|
|
455
|
+
)
|
|
456
|
+
group.add_argument(
|
|
457
|
+
"--fsdp-offload-policy", type=str, choices=["none", "cpu"], default="none", help="FSDP parameter offload policy"
|
|
458
|
+
)
|
|
413
459
|
group.add_argument(
|
|
414
460
|
"--find-unused-parameters",
|
|
415
461
|
default=False,
|
|
@@ -561,3 +607,23 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
561
607
|
raise cli.ValidationError("--grad-accum-steps must be >= 1")
|
|
562
608
|
if args.model_ema_steps < 1:
|
|
563
609
|
raise cli.ValidationError("--model-ema-steps must be >= 1")
|
|
610
|
+
|
|
611
|
+
if args.distributed_mode == "fsdp":
|
|
612
|
+
if args.sync_bn is True:
|
|
613
|
+
raise cli.ValidationError("--sync-bn cannot be used with --distributed-mode fsdp")
|
|
614
|
+
if args.find_unused_parameters is True:
|
|
615
|
+
raise cli.ValidationError("--find-unused-parameters cannot be used with --distributed-mode fsdp")
|
|
616
|
+
if args.compile_opt is True:
|
|
617
|
+
raise cli.ValidationError("--compile-opt cannot be used with --distributed-mode fsdp")
|
|
618
|
+
if args.compile_fullgraph is True:
|
|
619
|
+
raise cli.ValidationError("--compile-fullgraph cannot be used with --distributed-mode fsdp")
|
|
620
|
+
if args.cpu is True:
|
|
621
|
+
raise cli.ValidationError("--cpu cannot be used with --distributed-mode fsdp")
|
|
622
|
+
if args.model_ema is True:
|
|
623
|
+
raise cli.ValidationError("--model-ema cannot be used with --distributed-mode fsdp")
|
|
624
|
+
if args.fsdp_wrap_policy == "min-num-params" and args.fsdp_wrap_min_num_params is None:
|
|
625
|
+
raise cli.ValidationError(
|
|
626
|
+
"--fsdp-wrap-min-num-params is required when --fsdp-wrap-policy is min-num-params"
|
|
627
|
+
)
|
|
628
|
+
if args.fsdp_wrap_min_num_params is not None and args.fsdp_wrap_min_num_params <= 0:
|
|
629
|
+
raise cli.ValidationError("--fsdp-wrap-min-num-params must be > 0")
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.distributed as dist
|
|
9
|
+
from birder.common import fsdp_utils
|
|
10
|
+
from birder.common import training_utils as birder_training_utils
|
|
11
|
+
|
|
12
|
+
from birder_clip.common import fs_ops
|
|
13
|
+
from birder_clip.conf import settings
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
    """
    Attach a file handler to the "birder" and "birder_clip" loggers.

    Records are written as the bare message (format "{message}") at settings.LOG_LEVEL.

    The file is opened as UTF-8, so non-ASCII text (e.g. captions or class names) is written
    the same way on every platform. The locale's default encoding is not used.

    Returns the created handler, e.g. so the caller can remove and close it later.
    """

    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setFormatter(logging.Formatter(fmt="{message}", style="{"))
    file_handler.setLevel(settings.LOG_LEVEL)

    # The same handler instance is shared so both package loggers write to one file
    for logger_name in ("birder", "birder_clip"):
        logging.getLogger(logger_name).addHandler(file_handler)

    return file_handler
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def save_training_checkpoint(
    args: argparse.Namespace,
    network_name: str,
    epoch: int,
    net: torch.nn.Module,
    signature: Any,
    rgb_stats: Any,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    scaler: Optional[torch.amp.grad_scaler.GradScaler],
    model_base: Optional[torch.nn.Module],
    *,
    fsdp_mode: bool = False,
    fsdp_model_state: Optional[dict[str, Any]] = None,
    external_config: Optional[dict[str, Any]] = None,
    **extra_states: Optional[dict[str, Any]],
) -> None:
    """
    Save a training checkpoint from the global primary rank.

    Without FSDP, the primary rank hands the live objects to fs_ops.checkpoint_model.

    With FSDP, the full model state and optimizer state are gathered on every rank; a
    precomputed fsdp_model_state may be passed to skip the model gather. The primary rank then
    writes them with fs_ops.checkpoint_model_from_state_dicts. All ranks finally meet at a
    barrier when torch.distributed is initialized.
    """

    if fsdp_mode is not True:
        if birder_training_utils.is_global_primary(args) is True:
            fs_ops.checkpoint_model(
                network_name,
                epoch,
                net,
                signature,
                rgb_stats,
                optimizer,
                scheduler,
                scaler,
                model_base,
                external_config=external_config,
                **extra_states,
            )

        return

    # The gathers run on every rank, not only the primary.
    # Presumably they are collective operations that all ranks must enter.
    if fsdp_model_state is None:
        model_state = fsdp_utils.gather_full_model_state_dict(net)
    else:
        model_state = fsdp_model_state

    optimizer_state: Optional[dict[str, Any]] = None
    scheduler_state: Optional[dict[str, Any]] = None
    if optimizer is not None and scheduler is not None:
        optimizer_state = fsdp_utils.gather_full_optimizer_state_dict(net, optimizer)
        scheduler_state = scheduler.state_dict()

    scaler_state = scaler.state_dict() if scaler is not None else None
    model_base_state = model_base.state_dict() if model_base is not None else None

    if birder_training_utils.is_global_primary(args) is True:
        fs_ops.checkpoint_model_from_state_dicts(
            network_name,
            epoch,
            model_state,
            net.task,
            signature,
            rgb_stats,
            optimizer_state,
            scheduler_state,
            scaler_state,
            model_base_state,
            external_config=external_config,
            **extra_states,
        )

    # Hold the other ranks until the primary has finished writing
    if birder_training_utils.is_dist_available_and_initialized() is True:
        dist.barrier()
|
|
@@ -24,10 +24,10 @@ def decode_caption(caption: Any, caption_json_key: str = "caption") -> str:
|
|
|
24
24
|
if isinstance(caption, bytes):
|
|
25
25
|
caption = caption.decode("utf-8")
|
|
26
26
|
|
|
27
|
-
if isinstance(caption, str)
|
|
27
|
+
if not isinstance(caption, str):
|
|
28
28
|
raise TypeError(f"WebDataset caption must be a string, got {type(caption).__name__}")
|
|
29
29
|
|
|
30
|
-
return caption
|
|
30
|
+
return caption
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def tokenize_caption(caption: str, tokenizer: Tokenizer) -> torch.Tensor:
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Zero-shot text embedding helpers
|
|
3
|
+
|
|
4
|
+
Zero-shot classification compares image features against one text feature per
|
|
5
|
+
candidate class. When multiple prompt templates are used, this module follows
|
|
6
|
+
the OpenCLIP/OpenAI CLIP convention: encode every class/template prompt,
|
|
7
|
+
normalize prompt embeddings, average them per class and normalize the averaged
|
|
8
|
+
class embedding again.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import sys
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
from collections.abc import Iterator
|
|
14
|
+
from collections.abc import Sequence
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import numpy.typing as npt
|
|
19
|
+
import torch
|
|
20
|
+
import torch.nn.functional as F
|
|
21
|
+
from torch.utils.data import DataLoader
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
|
|
24
|
+
from birder_clip.net.base import BaseNet
|
|
25
|
+
from birder_clip.tokenizers.base import Tokenizer
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def render_prompts(class_names: Sequence[str], templates: Sequence[str]) -> list[str]:
    """
    Fill every template with every class name using ``str.format``.

    The output is class-major: all templates for the first class, then all templates for the
    second class, and so on. The reshape to (num_classes, num_templates, -1) in
    build_class_text_embeddings relies on this order.
    """

    prompts: list[str] = []
    for name in class_names:
        prompts.extend(tmpl.format(name) for tmpl in templates)

    return prompts
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def build_class_text_embeddings(
    net: BaseNet,
    tokenizer: Tokenizer,
    class_names: Sequence[str],
    templates: Sequence[str],
    *,
    device: torch.device,
    context_length: Optional[int] = None,
    batch_size: Optional[int] = None,
    amp: bool = False,
    amp_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Build one normalized text embedding per class for zero-shot classification.

    For each class, every template is rendered and encoded with ``normalize=True``. The
    prompt embeddings are then averaged per class, and the average is L2-normalized again.

    Parameters
    ----------
    net
        Model providing ``encode_text``.
    tokenizer
        Callable mapping a list of prompts to a token tensor.
    class_names
        Class names; row i of the result corresponds to class_names[i].
    templates
        Format strings with a single ``{}`` placeholder for the class name.
    device
        Device the tokens are moved to before encoding.
    context_length
        Forwarded to the tokenizer.
    batch_size
        Number of classes (not prompts) per forward pass. Each pass encodes
        batch_size * len(templates) prompts. Defaults to all classes in one pass.
    amp, amp_dtype
        Autocast settings for the text encoder.

    Returns
    -------
    Tensor of shape (len(class_names), embed_dim), assuming encode_text returns one row per
    prompt.

    Raises
    ------
    ValueError
        If class_names or templates is empty, or batch_size is smaller than 1.
    """

    # Fail early with a clear message; previously these inputs surfaced as a zero range() step
    # or as reshape/concat errors deep inside the loop
    num_classes = len(class_names)
    num_templates = len(templates)
    if num_classes == 0:
        raise ValueError("class_names must not be empty")
    if num_templates == 0:
        raise ValueError("templates must not be empty")

    if batch_size is None:
        batch_size = num_classes
    elif batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    class_text_embeddings = []
    with torch.inference_mode():
        for start in range(0, num_classes, batch_size):
            batch_class_names = class_names[start : start + batch_size]
            prompts = render_prompts(batch_class_names, templates)
            tokens = tokenizer(prompts, context_length=context_length).to(device)
            with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
                prompt_embeddings = net.encode_text(tokens, normalize=True)

            # Prompts are class-major, so each class's templates form one contiguous group
            class_embeddings = prompt_embeddings.reshape(len(batch_class_names), num_templates, -1).mean(dim=1)
            class_embeddings = F.normalize(class_embeddings, dim=-1)
            class_text_embeddings.append(class_embeddings)

    return torch.concat(class_text_embeddings, dim=0)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# One chunk yielded by infer_dataloader_iter: (sample paths, per-sample class scores, labels).
# The scores are softmax probabilities, or raw logits when return_logits is True.
DataloaderInferenceResult = tuple[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def infer_dataloader_iter(
    device: torch.device,
    net: BaseNet | torch.ScriptModule,
    dataloader: DataLoader,
    text_embeddings: torch.Tensor,
    return_logits: bool = False,
    model_dtype: torch.dtype = torch.float32,
    amp: bool = False,
    amp_dtype: Optional[torch.dtype] = None,
    num_samples: Optional[int] = None,
    batch_callback: Optional[Callable[[list[str], npt.NDArray[np.float32], npt.NDArray[np.int64]], None]] = None,
    chunk_size: Optional[float] = None,
) -> Iterator[DataloaderInferenceResult]:
    """
    Run zero-shot classification over a dataloader and yield the results in chunks.

    Parameters
    ----------
    device
        Device the model and input batches are moved to.
    net
        Model providing ``encode_image`` and ``forward_logits``. It is moved in place to
        *device* / *model_dtype*.
    dataloader
        Must yield ``(file_paths, inputs, targets)`` batches.
    text_embeddings
        Class text embeddings passed to ``forward_logits``, presumably built by
        build_class_text_embeddings. They are not moved here, so they should already live
        on *device*.
    return_logits
        Output raw logits instead of softmax probabilities.
    model_dtype
        Dtype for the model and the input batches.
    amp, amp_dtype
        Autocast settings for the forward pass.
    num_samples
        Total for the progress bar only.
    batch_callback
        Called after every batch with (file_paths, outputs, labels).
    chunk_size
        Yield once at least this many samples have accumulated. None yields a single chunk
        at the end.

    Yields
    ------
    DataloaderInferenceResult tuples of (sample paths, outputs, labels).
    """

    if chunk_size is None:
        chunk_size = float("inf")

    net.to(device, dtype=model_dtype)
    out_list: list[npt.NDArray[np.float32]] = []
    labels_list: list[npt.NDArray[np.int64]] = []
    sample_paths: list[str] = []
    sample_count = 0
    with tqdm(total=num_samples, initial=0, unit="images", unit_scale=True, leave=False) as progress:
        for file_paths, inputs, targets in dataloader:
            batch_size = inputs.size(0)

            # Inference. inference_mode keeps autograd from recording the forward pass. Without it,
            # outputs of a model with trainable parameters require grad, .numpy() raises and the
            # graph wastes memory. The context is not held across the yield below.
            with torch.inference_mode():
                inputs = inputs.to(device, dtype=model_dtype)
                with torch.amp.autocast(device.type, enabled=amp, dtype=amp_dtype):
                    image_embeddings = net.encode_image(inputs, normalize=True)
                    logits = net.forward_logits(image_embeddings, text_embeddings)

                # Upcast before softmax so half-precision models still normalize in float32
                logits = logits.float()
                if return_logits is True:
                    out = logits.cpu().numpy()
                else:
                    out = F.softmax(logits, dim=-1).cpu().numpy()

            out_list.append(out)

            # Set labels and sample list
            batch_labels = targets.cpu().numpy()
            labels_list.append(batch_labels)
            sample_paths.extend(file_paths)

            if batch_callback is not None:
                batch_callback(file_paths, out, batch_labels)

            # Update progress bar
            progress.update(n=batch_size)

            # Yield results when we reach chunk_size
            sample_count += batch_size
            if sample_count >= chunk_size:
                with tqdm.external_write_mode(file=sys.stderr):
                    yield (sample_paths, np.concatenate(out_list, axis=0), np.concatenate(labels_list))

                # Reset for next chunk
                out_list = []
                labels_list = []
                sample_paths = []
                sample_count = 0

    if len(out_list) > 0:
        yield (sample_paths, np.concatenate(out_list, axis=0), np.concatenate(labels_list))
|