birder-clip 0.0.2.dev3__tar.gz → 0.0.2.dev5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/PKG-INFO +3 -3
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/fs_ops.py +100 -3
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/lib.py +11 -3
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/training_cli.py +167 -100
- birder_clip-0.0.2.dev5/birder_clip/common/training_utils.py +99 -0
- birder_clip-0.0.2.dev5/birder_clip/data/datasets/webdataset.py +106 -0
- birder_clip-0.0.2.dev5/birder_clip/inference/zero_shot.py +128 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/loss/contrastive.py +11 -0
- birder_clip-0.0.2.dev5/birder_clip/net/clip.py +610 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/text/base.py +5 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/text/transformer.py +2 -0
- birder_clip-0.0.2.dev5/birder_clip/scripts/train.py +991 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/scripts/zero_shot.py +249 -180
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/__init__.py +2 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/base.py +3 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/hf.py +13 -0
- birder_clip-0.0.2.dev5/birder_clip/tokenizers/openvision.py +64 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/simple_tokenizer.py +1 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/__main__.py +13 -0
- birder_clip-0.0.2.dev5/birder_clip/tools/convert_model.py +268 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/download_tokenizer.py +5 -4
- birder_clip-0.0.2.dev5/birder_clip/tools/list_models.py +102 -0
- birder_clip-0.0.2.dev5/birder_clip/tools/model_info.py +145 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/show_iterator.py +77 -11
- birder_clip-0.0.2.dev5/birder_clip/tools/stats.py +210 -0
- birder_clip-0.0.2.dev5/birder_clip/version.py +1 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/PKG-INFO +3 -3
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/SOURCES.txt +8 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/requires.txt +2 -2
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/requirements/_requirements-dev.txt +1 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/requirements/requirements.txt +1 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_common.py +1 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_datasets.py +6 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_net.py +1 -1
- birder_clip-0.0.2.dev3/birder_clip/inference/zero_shot.py +0 -54
- birder_clip-0.0.2.dev3/birder_clip/net/clip.py +0 -263
- birder_clip-0.0.2.dev3/birder_clip/version.py +0 -1
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/LICENSE +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/README.md +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/conf/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/conf/settings.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/csv.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/fake.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/inference/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/inference/zero_shot_templates.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/loss/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/manifest.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/model_registry.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/base.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/text/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/py.typed +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/scripts/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/registry.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/__init__.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/dependency_links.txt +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/top_level.txt +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/pyproject.toml +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/setup.cfg +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_loss.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_model_registry.py +0 -0
- {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_tokenizers.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder_clip
|
|
3
|
-
Version: 0.0.2.dev3
|
|
3
|
+
Version: 0.0.2.dev5
|
|
4
4
|
Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
|
|
|
24
24
|
Requires-Python: >=3.11
|
|
25
25
|
Description-Content-Type: text/markdown
|
|
26
26
|
License-File: LICENSE
|
|
27
|
-
Requires-Dist: birder>=0.5.
|
|
27
|
+
Requires-Dist: birder>=0.5.6
|
|
28
28
|
Requires-Dist: ftfy>=6.3.1
|
|
29
29
|
Requires-Dist: regex>=2025.7.29
|
|
30
30
|
Requires-Dist: tqdm>=4.67.0
|
|
@@ -37,7 +37,7 @@ Provides-Extra: dev
|
|
|
37
37
|
Requires-Dist: bandit~=1.9.4; extra == "dev"
|
|
38
38
|
Requires-Dist: black~=26.5.0; extra == "dev"
|
|
39
39
|
Requires-Dist: build~=1.5.0; extra == "dev"
|
|
40
|
-
Requires-Dist: bumpver~=
|
|
40
|
+
Requires-Dist: bumpver~=2026.1132; extra == "dev"
|
|
41
41
|
Requires-Dist: coverage~=7.14.1; extra == "dev"
|
|
42
42
|
Requires-Dist: debugpy; extra == "dev"
|
|
43
43
|
Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
|
|
@@ -67,7 +67,7 @@ def model_path(
|
|
|
67
67
|
network_name: str,
|
|
68
68
|
*,
|
|
69
69
|
epoch: Optional[int | str] = None,
|
|
70
|
-
|
|
70
|
+
st: bool = False,
|
|
71
71
|
states: bool = False,
|
|
72
72
|
) -> Path:
|
|
73
73
|
if epoch is not None:
|
|
@@ -77,8 +77,10 @@ def model_path(
|
|
|
77
77
|
|
|
78
78
|
if states is True:
|
|
79
79
|
file_name = f"{file_name}_states.pt"
|
|
80
|
+
elif st is True:
|
|
81
|
+
file_name = f"{file_name}.safetensors"
|
|
80
82
|
else:
|
|
81
|
-
file_name = f"{file_name}.
|
|
83
|
+
file_name = f"{file_name}.pt"
|
|
82
84
|
|
|
83
85
|
return settings.MODELS_DIR.joinpath(file_name)
|
|
84
86
|
|
|
@@ -109,6 +111,30 @@ def _checkpoint_states(
|
|
|
109
111
|
torch.save(kwargs, states_path)
|
|
110
112
|
|
|
111
113
|
|
|
114
|
+
def _checkpoint_states_from_state_dicts(
|
|
115
|
+
states_path: Path,
|
|
116
|
+
optimizer_state: Optional[dict[str, Any]],
|
|
117
|
+
scheduler_state: Optional[dict[str, Any]],
|
|
118
|
+
scaler_state: Optional[dict[str, Any]],
|
|
119
|
+
model_base_state: Optional[dict[str, Any]],
|
|
120
|
+
**extra_states: Optional[dict[str, Any]],
|
|
121
|
+
) -> None:
|
|
122
|
+
if optimizer_state is None or scheduler_state is None:
|
|
123
|
+
return
|
|
124
|
+
|
|
125
|
+
logger.info(f"Saving checkpoint states {states_path}...")
|
|
126
|
+
torch.save(
|
|
127
|
+
{
|
|
128
|
+
"optimizer_state": optimizer_state,
|
|
129
|
+
"scheduler_state": scheduler_state,
|
|
130
|
+
"scaler_state": scaler_state,
|
|
131
|
+
"model_base_state": model_base_state,
|
|
132
|
+
**extra_states,
|
|
133
|
+
},
|
|
134
|
+
states_path,
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
|
|
112
138
|
class TrainingStates(NamedTuple):
|
|
113
139
|
optimizer_state: Optional[dict[str, Any]]
|
|
114
140
|
scheduler_state: Optional[dict[str, Any]]
|
|
@@ -182,6 +208,50 @@ def checkpoint_model(
|
|
|
182
208
|
_checkpoint_states(states_path, optimizer, scheduler, scaler, model_base, **extra_states)
|
|
183
209
|
|
|
184
210
|
|
|
211
|
+
def checkpoint_model_from_state_dicts(
|
|
212
|
+
network_name: str,
|
|
213
|
+
epoch: int,
|
|
214
|
+
model_state: dict[str, Any],
|
|
215
|
+
task: Any,
|
|
216
|
+
signature: SignatureType,
|
|
217
|
+
rgb_stats: RGBType,
|
|
218
|
+
optimizer_state: Optional[dict[str, Any]],
|
|
219
|
+
scheduler_state: Optional[dict[str, Any]],
|
|
220
|
+
scaler_state: Optional[dict[str, Any]],
|
|
221
|
+
model_base_state: Optional[dict[str, Any]],
|
|
222
|
+
*,
|
|
223
|
+
external_config: Optional[dict[str, Any]] = None,
|
|
224
|
+
**extra_states: Optional[dict[str, Any]],
|
|
225
|
+
) -> None:
|
|
226
|
+
kwargs = {}
|
|
227
|
+
if external_config is not None:
|
|
228
|
+
kwargs["config"] = external_config
|
|
229
|
+
|
|
230
|
+
path = model_path(network_name, epoch=epoch)
|
|
231
|
+
states_path = model_path(network_name, epoch=epoch, states=True)
|
|
232
|
+
logger.info(f"Saving model checkpoint {path}...")
|
|
233
|
+
torch.save(
|
|
234
|
+
{
|
|
235
|
+
"state": model_state,
|
|
236
|
+
"birder_clip_version": __version__,
|
|
237
|
+
"task": task,
|
|
238
|
+
"signature": signature,
|
|
239
|
+
"rgb_stats": rgb_stats,
|
|
240
|
+
**kwargs,
|
|
241
|
+
},
|
|
242
|
+
path,
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
_checkpoint_states_from_state_dicts(
|
|
246
|
+
states_path,
|
|
247
|
+
optimizer_state,
|
|
248
|
+
scheduler_state,
|
|
249
|
+
scaler_state,
|
|
250
|
+
model_base_state,
|
|
251
|
+
**extra_states,
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
|
|
185
255
|
def clean_checkpoints(network_name: str, keep_last: int) -> None:
|
|
186
256
|
epoch = "*[0-9]"
|
|
187
257
|
models_glob = str(model_path(network_name, epoch=epoch))
|
|
@@ -314,7 +384,7 @@ def load_model(
|
|
|
314
384
|
embed_dim=embed_dim,
|
|
315
385
|
tokenizer=tokenizer,
|
|
316
386
|
)
|
|
317
|
-
path = model_path(_network_name, epoch=epoch,
|
|
387
|
+
path = model_path(_network_name, epoch=epoch, st=st)
|
|
318
388
|
|
|
319
389
|
logger.info(f"Loading model from {path} on device {device}...")
|
|
320
390
|
|
|
@@ -589,6 +659,33 @@ def load_pretrained_tokenizer(weights: str, *, download: bool = True, **kwargs:
|
|
|
589
659
|
return get_tokenizer(tokenizer_name, **tokenizer_kwargs)
|
|
590
660
|
|
|
591
661
|
|
|
662
|
+
def save_st(
|
|
663
|
+
net: torch.nn.Module,
|
|
664
|
+
dst: str,
|
|
665
|
+
task: str,
|
|
666
|
+
signature: SignatureType,
|
|
667
|
+
rgb_stats: RGBType,
|
|
668
|
+
*,
|
|
669
|
+
external_config: Optional[dict[str, Any]] = None,
|
|
670
|
+
) -> None:
|
|
671
|
+
assert _HAS_SAFETENSORS, "'pip install safetensors' to use .safetensors"
|
|
672
|
+
kwargs = {}
|
|
673
|
+
if external_config is not None:
|
|
674
|
+
kwargs["config"] = json.dumps(external_config)
|
|
675
|
+
|
|
676
|
+
safetensors.torch.save_model(
|
|
677
|
+
net,
|
|
678
|
+
str(dst),
|
|
679
|
+
{
|
|
680
|
+
"birder_clip_version": __version__,
|
|
681
|
+
"task": task,
|
|
682
|
+
"signature": json.dumps(signature),
|
|
683
|
+
"rgb_stats": json.dumps(rgb_stats),
|
|
684
|
+
**kwargs,
|
|
685
|
+
},
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
|
|
592
689
|
def download_model_by_weights(
|
|
593
690
|
weights: str, *, dst: Optional[str | Path] = None, file_format: FileFormatType = "pt", progress_bar: bool = True
|
|
594
691
|
) -> None:
|
|
@@ -43,9 +43,17 @@ def get_image_text_network_name(
|
|
|
43
43
|
parts = [network]
|
|
44
44
|
if image_encoder is not None:
|
|
45
45
|
parts.append(image_encoder)
|
|
46
|
-
if text_encoder is not None:
|
|
46
|
+
if text_encoder is not None and text_encoder != "text_transformer":
|
|
47
47
|
parts.append(text_encoder)
|
|
48
|
-
|
|
48
|
+
|
|
49
|
+
if registry.exists(network) is True:
|
|
50
|
+
default_tokenizer = registry.get_default_tokenizer(network)
|
|
51
|
+
else:
|
|
52
|
+
default_tokenizer = "simple_tokenizer"
|
|
53
|
+
if default_tokenizer is None:
|
|
54
|
+
default_tokenizer = "simple_tokenizer"
|
|
55
|
+
|
|
56
|
+
if tokenizer is not None and tokenizer != default_tokenizer:
|
|
49
57
|
parts.append(tokenizer)
|
|
50
58
|
if embed_dim is not None:
|
|
51
59
|
parts.append(f"d{embed_dim}")
|
|
@@ -78,7 +86,7 @@ def get_image_text_model_config(
|
|
|
78
86
|
|
|
79
87
|
if config is not None:
|
|
80
88
|
for key, value in config.items():
|
|
81
|
-
if key in {"image", "text"} and isinstance(value, dict)
|
|
89
|
+
if key in {"image", "text"} and isinstance(value, dict):
|
|
82
90
|
model_config[key] = {**model_config.get(key, {}), **value}
|
|
83
91
|
else:
|
|
84
92
|
model_config[key] = value
|
|
@@ -17,32 +17,6 @@ from birder_clip.model_registry import Task
|
|
|
17
17
|
from birder_clip.model_registry import registry
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
def add_compile_args(parser: argparse.ArgumentParser) -> None:
|
|
21
|
-
group = parser.add_argument_group("Compilation parameters")
|
|
22
|
-
group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
|
|
23
|
-
group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
|
|
24
|
-
group.add_argument(
|
|
25
|
-
"--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
|
|
26
|
-
)
|
|
27
|
-
group.add_argument(
|
|
28
|
-
"--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
|
|
29
|
-
)
|
|
30
|
-
group.add_argument(
|
|
31
|
-
"--compile-recompile-limit",
|
|
32
|
-
type=int,
|
|
33
|
-
default=torch.compiler.config.recompile_limit,
|
|
34
|
-
metavar="N",
|
|
35
|
-
help="maximum recompilations per compiled function before eager fallback",
|
|
36
|
-
)
|
|
37
|
-
group.add_argument(
|
|
38
|
-
"--compile-accumulated-recompile-limit",
|
|
39
|
-
type=int,
|
|
40
|
-
default=torch.compiler.config.accumulated_recompile_limit,
|
|
41
|
-
metavar="N",
|
|
42
|
-
help="maximum total recompilations across compiled functions",
|
|
43
|
-
)
|
|
44
|
-
|
|
45
|
-
|
|
46
20
|
def add_model_args(parser: argparse.ArgumentParser) -> None:
|
|
47
21
|
parser.add_argument("-n", "--network", type=str, help="the image-text network to train")
|
|
48
22
|
parser.add_argument("-t", "--tag", type=str, help="add model tag")
|
|
@@ -92,6 +66,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
|
|
|
92
66
|
metavar="N",
|
|
93
67
|
help="number of iterations to accumulate gradients per optimizer step",
|
|
94
68
|
)
|
|
69
|
+
# NOTE: Add flag for negative sample caching in grad accum mode
|
|
95
70
|
|
|
96
71
|
|
|
97
72
|
def add_lr_wd_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -129,14 +104,14 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
|
|
|
129
104
|
"--lr-scheduler-update",
|
|
130
105
|
type=str,
|
|
131
106
|
choices=["epoch", "step"],
|
|
132
|
-
default="
|
|
107
|
+
default="epoch",
|
|
133
108
|
help="when to apply learning rate scheduler update: epoch (once per epoch), step (each optimizer step)",
|
|
134
109
|
)
|
|
135
110
|
group.add_argument(
|
|
136
111
|
"--lr-scheduler",
|
|
137
112
|
type=str,
|
|
138
113
|
choices=list(get_args(SchedulerType)),
|
|
139
|
-
default="
|
|
114
|
+
default="constant",
|
|
140
115
|
help="learning rate scheduler",
|
|
141
116
|
)
|
|
142
117
|
group.add_argument(
|
|
@@ -175,15 +150,6 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
|
|
|
175
150
|
)
|
|
176
151
|
|
|
177
152
|
|
|
178
|
-
def add_input_args(parser: argparse.ArgumentParser) -> None:
|
|
179
|
-
group = parser.add_argument_group("Input parameters")
|
|
180
|
-
group.add_argument(
|
|
181
|
-
"--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
|
|
182
|
-
)
|
|
183
|
-
group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
|
|
184
|
-
group.add_argument("--context-length", type=int, metavar="N", help="text context length")
|
|
185
|
-
|
|
186
|
-
|
|
187
153
|
def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
|
|
188
154
|
group = parser.add_argument_group("Training schedule parameters")
|
|
189
155
|
group.add_argument("--epochs", type=int, default=default_epochs, metavar="N", help="number of training epochs")
|
|
@@ -204,6 +170,37 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
|
|
|
204
170
|
)
|
|
205
171
|
|
|
206
172
|
|
|
173
|
+
def add_ema_args(
|
|
174
|
+
parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
|
|
175
|
+
) -> None:
|
|
176
|
+
group = parser.add_argument_group("Exponential moving average parameters")
|
|
177
|
+
group.add_argument(
|
|
178
|
+
"--model-ema",
|
|
179
|
+
default=False,
|
|
180
|
+
action="store_true",
|
|
181
|
+
help="enable tracking exponential moving average of model parameters",
|
|
182
|
+
)
|
|
183
|
+
group.add_argument(
|
|
184
|
+
"--model-ema-steps",
|
|
185
|
+
type=int,
|
|
186
|
+
default=default_ema_steps,
|
|
187
|
+
metavar="N",
|
|
188
|
+
help="number of optimizer steps between EMA updates",
|
|
189
|
+
)
|
|
190
|
+
group.add_argument(
|
|
191
|
+
"--model-ema-decay",
|
|
192
|
+
type=float,
|
|
193
|
+
default=default_ema_decay,
|
|
194
|
+
help="decay factor for exponential moving average of model parameters",
|
|
195
|
+
)
|
|
196
|
+
group.add_argument(
|
|
197
|
+
"--model-ema-warmup",
|
|
198
|
+
type=int,
|
|
199
|
+
metavar="N",
|
|
200
|
+
help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
|
|
207
204
|
def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
|
|
208
205
|
group = parser.add_argument_group("Batch normalization parameters")
|
|
209
206
|
group.add_argument(
|
|
@@ -215,6 +212,15 @@ def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
|
|
|
215
212
|
group.add_argument("--sync-bn", default=False, action="store_true", help="use synchronized BatchNorm")
|
|
216
213
|
|
|
217
214
|
|
|
215
|
+
def add_input_args(parser: argparse.ArgumentParser) -> None:
|
|
216
|
+
group = parser.add_argument_group("Input parameters")
|
|
217
|
+
group.add_argument(
|
|
218
|
+
"--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
|
|
219
|
+
)
|
|
220
|
+
group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
|
|
221
|
+
group.add_argument("--context-length", type=int, metavar="N", help="text context length")
|
|
222
|
+
|
|
223
|
+
|
|
218
224
|
def add_data_aug_args(
|
|
219
225
|
parser: argparse.ArgumentParser,
|
|
220
226
|
default_level: int = 4,
|
|
@@ -260,7 +266,7 @@ def add_data_aug_args(
|
|
|
260
266
|
"--rgb-mode",
|
|
261
267
|
type=str,
|
|
262
268
|
choices=list(typing.get_args(RGBMode)),
|
|
263
|
-
default="
|
|
269
|
+
default="birder",
|
|
264
270
|
help="RGB mean and std to use for normalization",
|
|
265
271
|
)
|
|
266
272
|
group.add_argument(
|
|
@@ -279,67 +285,6 @@ def add_data_aug_args(
|
|
|
279
285
|
)
|
|
280
286
|
|
|
281
287
|
|
|
282
|
-
def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
|
|
283
|
-
group = parser.add_argument_group("Checkpoint parameters")
|
|
284
|
-
group.add_argument(
|
|
285
|
-
"--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
|
|
286
|
-
)
|
|
287
|
-
group.add_argument(
|
|
288
|
-
"--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
|
|
289
|
-
)
|
|
290
|
-
group.add_argument(
|
|
291
|
-
"--pretrained",
|
|
292
|
-
default=False,
|
|
293
|
-
action="store_true",
|
|
294
|
-
help="start with pretrained version of specified network (will download if not found locally)",
|
|
295
|
-
)
|
|
296
|
-
group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
|
|
297
|
-
group.add_argument(
|
|
298
|
-
"--non-strict-weights",
|
|
299
|
-
default=False,
|
|
300
|
-
action="store_true",
|
|
301
|
-
help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
|
|
302
|
-
)
|
|
303
|
-
group.add_argument(
|
|
304
|
-
"--load-states",
|
|
305
|
-
default=False,
|
|
306
|
-
action="store_true",
|
|
307
|
-
help="load optimizer, scheduler and scaler states when resuming",
|
|
308
|
-
)
|
|
309
|
-
group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
def add_ema_args(
|
|
313
|
-
parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
|
|
314
|
-
) -> None:
|
|
315
|
-
group = parser.add_argument_group("Exponential moving average parameters")
|
|
316
|
-
group.add_argument(
|
|
317
|
-
"--model-ema",
|
|
318
|
-
default=False,
|
|
319
|
-
action="store_true",
|
|
320
|
-
help="enable tracking exponential moving average of model parameters",
|
|
321
|
-
)
|
|
322
|
-
group.add_argument(
|
|
323
|
-
"--model-ema-steps",
|
|
324
|
-
type=int,
|
|
325
|
-
default=default_ema_steps,
|
|
326
|
-
metavar="N",
|
|
327
|
-
help="number of optimizer steps between EMA updates",
|
|
328
|
-
)
|
|
329
|
-
group.add_argument(
|
|
330
|
-
"--model-ema-decay",
|
|
331
|
-
type=float,
|
|
332
|
-
default=default_ema_decay,
|
|
333
|
-
help="decay factor for exponential moving average of model parameters",
|
|
334
|
-
)
|
|
335
|
-
group.add_argument(
|
|
336
|
-
"--model-ema-warmup",
|
|
337
|
-
type=int,
|
|
338
|
-
metavar="N",
|
|
339
|
-
help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
|
|
340
|
-
)
|
|
341
|
-
|
|
342
|
-
|
|
343
288
|
def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
|
|
344
289
|
group = parser.add_argument_group("Dataloader parameters")
|
|
345
290
|
group.add_argument(
|
|
@@ -375,7 +320,13 @@ def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
|
|
|
375
320
|
action="store_true",
|
|
376
321
|
help="keep dataloader worker processes alive between epochs",
|
|
377
322
|
)
|
|
378
|
-
group.add_argument(
|
|
323
|
+
group.add_argument(
|
|
324
|
+
"--no-drop-last",
|
|
325
|
+
dest="drop_last",
|
|
326
|
+
default=True,
|
|
327
|
+
action="store_false",
|
|
328
|
+
help="do not drop the last incomplete batch",
|
|
329
|
+
)
|
|
379
330
|
|
|
380
331
|
|
|
381
332
|
def add_precision_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -405,12 +356,106 @@ def add_precision_args(parser: argparse.ArgumentParser) -> None:
|
|
|
405
356
|
)
|
|
406
357
|
|
|
407
358
|
|
|
359
|
+
def add_compile_args(parser: argparse.ArgumentParser) -> None:
|
|
360
|
+
group = parser.add_argument_group("Compilation parameters")
|
|
361
|
+
group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
|
|
362
|
+
group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
|
|
363
|
+
group.add_argument(
|
|
364
|
+
"--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
|
|
365
|
+
)
|
|
366
|
+
group.add_argument(
|
|
367
|
+
"--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
|
|
368
|
+
)
|
|
369
|
+
group.add_argument(
|
|
370
|
+
"--compile-recompile-limit",
|
|
371
|
+
type=int,
|
|
372
|
+
default=torch.compiler.config.recompile_limit,
|
|
373
|
+
metavar="N",
|
|
374
|
+
help="maximum recompilations per compiled function before eager fallback",
|
|
375
|
+
)
|
|
376
|
+
group.add_argument(
|
|
377
|
+
"--compile-accumulated-recompile-limit",
|
|
378
|
+
type=int,
|
|
379
|
+
default=torch.compiler.config.accumulated_recompile_limit,
|
|
380
|
+
metavar="N",
|
|
381
|
+
help="maximum total recompilations across compiled functions",
|
|
382
|
+
)
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
|
|
386
|
+
group = parser.add_argument_group("Checkpoint parameters")
|
|
387
|
+
group.add_argument(
|
|
388
|
+
"--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
|
|
389
|
+
)
|
|
390
|
+
group.add_argument(
|
|
391
|
+
"--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
|
|
392
|
+
)
|
|
393
|
+
group.add_argument(
|
|
394
|
+
"--pretrained",
|
|
395
|
+
default=False,
|
|
396
|
+
action="store_true",
|
|
397
|
+
help="start with pretrained version of specified network (will download if not found locally)",
|
|
398
|
+
)
|
|
399
|
+
group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
|
|
400
|
+
group.add_argument(
|
|
401
|
+
"--non-strict-weights",
|
|
402
|
+
default=False,
|
|
403
|
+
action="store_true",
|
|
404
|
+
help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
|
|
405
|
+
)
|
|
406
|
+
group.add_argument(
|
|
407
|
+
"--load-states",
|
|
408
|
+
default=False,
|
|
409
|
+
action="store_true",
|
|
410
|
+
help="load optimizer, scheduler and scaler states when resuming",
|
|
411
|
+
)
|
|
412
|
+
group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
|
|
413
|
+
|
|
414
|
+
|
|
408
415
|
def add_distributed_args(parser: argparse.ArgumentParser) -> None:
|
|
409
416
|
group = parser.add_argument_group("Distributed training parameters")
|
|
410
417
|
group.add_argument("--world-size", type=int, default=1, metavar="N", help="number of distributed processes")
|
|
411
418
|
group.add_argument("--local-rank", type=int, metavar="N", help="local rank")
|
|
412
419
|
group.add_argument("--dist-url", type=str, default="env://", help="URL used to initialize distributed training")
|
|
413
420
|
group.add_argument("--dist-backend", type=str, default="nccl", help="distributed backend")
|
|
421
|
+
group.add_argument(
|
|
422
|
+
"--distributed-mode", type=str, choices=["ddp", "fsdp"], default="ddp", help="distributed training mode"
|
|
423
|
+
)
|
|
424
|
+
group.add_argument(
|
|
425
|
+
"--fsdp-sharding-strategy",
|
|
426
|
+
type=str,
|
|
427
|
+
choices=["shard-grad-op", "full-shard"],
|
|
428
|
+
default="shard-grad-op",
|
|
429
|
+
help="FSDP sharding strategy",
|
|
430
|
+
)
|
|
431
|
+
group.add_argument(
|
|
432
|
+
"--fsdp-param-dtype",
|
|
433
|
+
type=str,
|
|
434
|
+
choices=["float32", "float16", "bfloat16"],
|
|
435
|
+
help="FSDP mixed precision parameter dtype",
|
|
436
|
+
)
|
|
437
|
+
group.add_argument(
|
|
438
|
+
"--fsdp-reduce-dtype",
|
|
439
|
+
type=str,
|
|
440
|
+
choices=["float32", "float16", "bfloat16"],
|
|
441
|
+
help="FSDP mixed precision gradient reduction dtype",
|
|
442
|
+
)
|
|
443
|
+
group.add_argument(
|
|
444
|
+
"--fsdp-wrap-policy",
|
|
445
|
+
type=str,
|
|
446
|
+
choices=["block-group-regex", "min-num-params"],
|
|
447
|
+
default="block-group-regex",
|
|
448
|
+
help="FSDP module wrapping policy",
|
|
449
|
+
)
|
|
450
|
+
group.add_argument(
|
|
451
|
+
"--fsdp-wrap-min-num-params",
|
|
452
|
+
type=float,
|
|
453
|
+
metavar="M",
|
|
454
|
+
help="minimum module parameter count in millions for wrapping when using --fsdp-wrap-policy min-num-params",
|
|
455
|
+
)
|
|
456
|
+
group.add_argument(
|
|
457
|
+
"--fsdp-offload-policy", type=str, choices=["none", "cpu"], default="none", help="FSDP parameter offload policy"
|
|
458
|
+
)
|
|
414
459
|
group.add_argument(
|
|
415
460
|
"--find-unused-parameters",
|
|
416
461
|
default=False,
|
|
@@ -558,5 +603,27 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
558
603
|
raise cli.ValidationError("--embed-dim must be positive")
|
|
559
604
|
if args.context_length is not None and args.context_length <= 0:
|
|
560
605
|
raise cli.ValidationError("--context-length must be positive")
|
|
606
|
+
if args.grad_accum_steps < 1:
|
|
607
|
+
raise cli.ValidationError("--grad-accum-steps must be >= 1")
|
|
561
608
|
if args.model_ema_steps < 1:
|
|
562
609
|
raise cli.ValidationError("--model-ema-steps must be >= 1")
|
|
610
|
+
|
|
611
|
+
if args.distributed_mode == "fsdp":
|
|
612
|
+
if args.sync_bn is True:
|
|
613
|
+
raise cli.ValidationError("--sync-bn cannot be used with --distributed-mode fsdp")
|
|
614
|
+
if args.find_unused_parameters is True:
|
|
615
|
+
raise cli.ValidationError("--find-unused-parameters cannot be used with --distributed-mode fsdp")
|
|
616
|
+
if args.compile_opt is True:
|
|
617
|
+
raise cli.ValidationError("--compile-opt cannot be used with --distributed-mode fsdp")
|
|
618
|
+
if args.compile_fullgraph is True:
|
|
619
|
+
raise cli.ValidationError("--compile-fullgraph cannot be used with --distributed-mode fsdp")
|
|
620
|
+
if args.cpu is True:
|
|
621
|
+
raise cli.ValidationError("--cpu cannot be used with --distributed-mode fsdp")
|
|
622
|
+
if args.model_ema is True:
|
|
623
|
+
raise cli.ValidationError("--model-ema cannot be used with --distributed-mode fsdp")
|
|
624
|
+
if args.fsdp_wrap_policy == "min-num-params" and args.fsdp_wrap_min_num_params is None:
|
|
625
|
+
raise cli.ValidationError(
|
|
626
|
+
"--fsdp-wrap-min-num-params is required when --fsdp-wrap-policy is min-num-params"
|
|
627
|
+
)
|
|
628
|
+
if args.fsdp_wrap_min_num_params is not None and args.fsdp_wrap_min_num_params <= 0:
|
|
629
|
+
raise cli.ValidationError("--fsdp-wrap-min-num-params must be > 0")
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.distributed as dist
|
|
9
|
+
from birder.common import fsdp_utils
|
|
10
|
+
from birder.common import training_utils as birder_training_utils
|
|
11
|
+
|
|
12
|
+
from birder_clip.common import fs_ops
|
|
13
|
+
from birder_clip.conf import settings
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
|
|
17
|
+
file_handler = logging.FileHandler(log_file_path)
|
|
18
|
+
formatter = logging.Formatter(
|
|
19
|
+
fmt="{message}",
|
|
20
|
+
style="{",
|
|
21
|
+
)
|
|
22
|
+
file_handler.setFormatter(formatter)
|
|
23
|
+
file_handler.setLevel(settings.LOG_LEVEL)
|
|
24
|
+
|
|
25
|
+
logging.getLogger("birder").addHandler(file_handler)
|
|
26
|
+
logging.getLogger("birder_clip").addHandler(file_handler)
|
|
27
|
+
|
|
28
|
+
return file_handler
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def save_training_checkpoint(
|
|
32
|
+
args: argparse.Namespace,
|
|
33
|
+
network_name: str,
|
|
34
|
+
epoch: int,
|
|
35
|
+
net: torch.nn.Module,
|
|
36
|
+
signature: Any,
|
|
37
|
+
rgb_stats: Any,
|
|
38
|
+
optimizer: Optional[torch.optim.Optimizer],
|
|
39
|
+
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
|
|
40
|
+
scaler: Optional[torch.amp.grad_scaler.GradScaler],
|
|
41
|
+
model_base: Optional[torch.nn.Module],
|
|
42
|
+
*,
|
|
43
|
+
fsdp_mode: bool = False,
|
|
44
|
+
fsdp_model_state: Optional[dict[str, Any]] = None,
|
|
45
|
+
external_config: Optional[dict[str, Any]] = None,
|
|
46
|
+
**extra_states: Optional[dict[str, Any]],
|
|
47
|
+
) -> None:
|
|
48
|
+
if fsdp_mode is True:
|
|
49
|
+
if fsdp_model_state is not None:
|
|
50
|
+
model_state = fsdp_model_state
|
|
51
|
+
else:
|
|
52
|
+
model_state = fsdp_utils.gather_full_model_state_dict(net)
|
|
53
|
+
|
|
54
|
+
optimizer_state = None
|
|
55
|
+
scheduler_state = None
|
|
56
|
+
scaler_state = None
|
|
57
|
+
model_base_state = None
|
|
58
|
+
if optimizer is not None and scheduler is not None:
|
|
59
|
+
optimizer_state = fsdp_utils.gather_full_optimizer_state_dict(net, optimizer)
|
|
60
|
+
scheduler_state = scheduler.state_dict()
|
|
61
|
+
if scaler is not None:
|
|
62
|
+
scaler_state = scaler.state_dict()
|
|
63
|
+
if model_base is not None:
|
|
64
|
+
model_base_state = model_base.state_dict()
|
|
65
|
+
|
|
66
|
+
if birder_training_utils.is_global_primary(args) is True:
|
|
67
|
+
fs_ops.checkpoint_model_from_state_dicts(
|
|
68
|
+
network_name,
|
|
69
|
+
epoch,
|
|
70
|
+
model_state,
|
|
71
|
+
net.task,
|
|
72
|
+
signature,
|
|
73
|
+
rgb_stats,
|
|
74
|
+
optimizer_state,
|
|
75
|
+
scheduler_state,
|
|
76
|
+
scaler_state,
|
|
77
|
+
model_base_state,
|
|
78
|
+
external_config=external_config,
|
|
79
|
+
**extra_states,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
if birder_training_utils.is_dist_available_and_initialized() is True:
|
|
83
|
+
dist.barrier()
|
|
84
|
+
|
|
85
|
+
else:
|
|
86
|
+
if birder_training_utils.is_global_primary(args) is True:
|
|
87
|
+
fs_ops.checkpoint_model(
|
|
88
|
+
network_name,
|
|
89
|
+
epoch,
|
|
90
|
+
net,
|
|
91
|
+
signature,
|
|
92
|
+
rgb_stats,
|
|
93
|
+
optimizer,
|
|
94
|
+
scheduler,
|
|
95
|
+
scaler,
|
|
96
|
+
model_base,
|
|
97
|
+
external_config=external_config,
|
|
98
|
+
**extra_states,
|
|
99
|
+
)
|