birder-clip 0.0.2.dev3__tar.gz → 0.0.2.dev5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68):
  1. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/PKG-INFO +3 -3
  2. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/fs_ops.py +100 -3
  3. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/lib.py +11 -3
  4. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/training_cli.py +167 -100
  5. birder_clip-0.0.2.dev5/birder_clip/common/training_utils.py +99 -0
  6. birder_clip-0.0.2.dev5/birder_clip/data/datasets/webdataset.py +106 -0
  7. birder_clip-0.0.2.dev5/birder_clip/inference/zero_shot.py +128 -0
  8. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/loss/contrastive.py +11 -0
  9. birder_clip-0.0.2.dev5/birder_clip/net/clip.py +610 -0
  10. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/text/base.py +5 -0
  11. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/text/transformer.py +2 -0
  12. birder_clip-0.0.2.dev5/birder_clip/scripts/train.py +991 -0
  13. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/scripts/zero_shot.py +249 -180
  14. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/__init__.py +2 -0
  15. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/base.py +3 -0
  16. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/hf.py +13 -0
  17. birder_clip-0.0.2.dev5/birder_clip/tokenizers/openvision.py +64 -0
  18. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/simple_tokenizer.py +1 -0
  19. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/__main__.py +13 -0
  20. birder_clip-0.0.2.dev5/birder_clip/tools/convert_model.py +268 -0
  21. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/download_tokenizer.py +5 -4
  22. birder_clip-0.0.2.dev5/birder_clip/tools/list_models.py +102 -0
  23. birder_clip-0.0.2.dev5/birder_clip/tools/model_info.py +145 -0
  24. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/show_iterator.py +77 -11
  25. birder_clip-0.0.2.dev5/birder_clip/tools/stats.py +210 -0
  26. birder_clip-0.0.2.dev5/birder_clip/version.py +1 -0
  27. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/PKG-INFO +3 -3
  28. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/SOURCES.txt +8 -0
  29. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/requires.txt +2 -2
  30. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/requirements/_requirements-dev.txt +1 -1
  31. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/requirements/requirements.txt +1 -1
  32. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_common.py +1 -1
  33. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_datasets.py +6 -0
  34. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_net.py +1 -1
  35. birder_clip-0.0.2.dev3/birder_clip/inference/zero_shot.py +0 -54
  36. birder_clip-0.0.2.dev3/birder_clip/net/clip.py +0 -263
  37. birder_clip-0.0.2.dev3/birder_clip/version.py +0 -1
  38. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/LICENSE +0 -0
  39. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/README.md +0 -0
  40. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/__init__.py +0 -0
  41. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/common/__init__.py +0 -0
  42. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/conf/__init__.py +0 -0
  43. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/conf/settings.py +0 -0
  44. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/__init__.py +0 -0
  45. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/__init__.py +0 -0
  46. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/csv.py +0 -0
  47. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/data/datasets/fake.py +0 -0
  48. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/inference/__init__.py +0 -0
  49. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/inference/zero_shot_templates.py +0 -0
  50. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/loss/__init__.py +0 -0
  51. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/__init__.py +0 -0
  52. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/manifest.py +0 -0
  53. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/model_registry/model_registry.py +0 -0
  54. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/__init__.py +0 -0
  55. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/base.py +0 -0
  56. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/net/text/__init__.py +0 -0
  57. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/py.typed +0 -0
  58. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/scripts/__init__.py +0 -0
  59. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/bpe_simple_vocab_16e6.txt.gz +0 -0
  60. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tokenizers/registry.py +0 -0
  61. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip/tools/__init__.py +0 -0
  62. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/dependency_links.txt +0 -0
  63. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/birder_clip.egg-info/top_level.txt +0 -0
  64. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/pyproject.toml +0 -0
  65. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/setup.cfg +0 -0
  66. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_loss.py +0 -0
  67. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_model_registry.py +0 -0
  68. {birder_clip-0.0.2.dev3 → birder_clip-0.0.2.dev5}/tests/test_tokenizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder_clip
3
- Version: 0.0.2.dev3
3
+ Version: 0.0.2.dev5
4
4
  Summary: A Birder extension for CLIP-style image-text modeling and multimodal computer vision workflows.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -24,7 +24,7 @@ Classifier: Typing :: Typed
24
24
  Requires-Python: >=3.11
25
25
  Description-Content-Type: text/markdown
26
26
  License-File: LICENSE
27
- Requires-Dist: birder>=0.5.2
27
+ Requires-Dist: birder>=0.5.6
28
28
  Requires-Dist: ftfy>=6.3.1
29
29
  Requires-Dist: regex>=2025.7.29
30
30
  Requires-Dist: tqdm>=4.67.0
@@ -37,7 +37,7 @@ Provides-Extra: dev
37
37
  Requires-Dist: bandit~=1.9.4; extra == "dev"
38
38
  Requires-Dist: black~=26.5.0; extra == "dev"
39
39
  Requires-Dist: build~=1.5.0; extra == "dev"
40
- Requires-Dist: bumpver~=2025.1131; extra == "dev"
40
+ Requires-Dist: bumpver~=2026.1132; extra == "dev"
41
41
  Requires-Dist: coverage~=7.14.1; extra == "dev"
42
42
  Requires-Dist: debugpy; extra == "dev"
43
43
  Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
@@ -67,7 +67,7 @@ def model_path(
67
67
  network_name: str,
68
68
  *,
69
69
  epoch: Optional[int | str] = None,
70
- file_format: FileFormatType = "pt",
70
+ st: bool = False,
71
71
  states: bool = False,
72
72
  ) -> Path:
73
73
  if epoch is not None:
@@ -77,8 +77,10 @@ def model_path(
77
77
 
78
78
  if states is True:
79
79
  file_name = f"{file_name}_states.pt"
80
+ elif st is True:
81
+ file_name = f"{file_name}.safetensors"
80
82
  else:
81
- file_name = f"{file_name}.{file_format}"
83
+ file_name = f"{file_name}.pt"
82
84
 
83
85
  return settings.MODELS_DIR.joinpath(file_name)
84
86
 
@@ -109,6 +111,30 @@ def _checkpoint_states(
109
111
  torch.save(kwargs, states_path)
110
112
 
111
113
 
114
+ def _checkpoint_states_from_state_dicts(
115
+ states_path: Path,
116
+ optimizer_state: Optional[dict[str, Any]],
117
+ scheduler_state: Optional[dict[str, Any]],
118
+ scaler_state: Optional[dict[str, Any]],
119
+ model_base_state: Optional[dict[str, Any]],
120
+ **extra_states: Optional[dict[str, Any]],
121
+ ) -> None:
122
+ if optimizer_state is None or scheduler_state is None:
123
+ return
124
+
125
+ logger.info(f"Saving checkpoint states {states_path}...")
126
+ torch.save(
127
+ {
128
+ "optimizer_state": optimizer_state,
129
+ "scheduler_state": scheduler_state,
130
+ "scaler_state": scaler_state,
131
+ "model_base_state": model_base_state,
132
+ **extra_states,
133
+ },
134
+ states_path,
135
+ )
136
+
137
+
112
138
  class TrainingStates(NamedTuple):
113
139
  optimizer_state: Optional[dict[str, Any]]
114
140
  scheduler_state: Optional[dict[str, Any]]
@@ -182,6 +208,50 @@ def checkpoint_model(
182
208
  _checkpoint_states(states_path, optimizer, scheduler, scaler, model_base, **extra_states)
183
209
 
184
210
 
211
+ def checkpoint_model_from_state_dicts(
212
+ network_name: str,
213
+ epoch: int,
214
+ model_state: dict[str, Any],
215
+ task: Any,
216
+ signature: SignatureType,
217
+ rgb_stats: RGBType,
218
+ optimizer_state: Optional[dict[str, Any]],
219
+ scheduler_state: Optional[dict[str, Any]],
220
+ scaler_state: Optional[dict[str, Any]],
221
+ model_base_state: Optional[dict[str, Any]],
222
+ *,
223
+ external_config: Optional[dict[str, Any]] = None,
224
+ **extra_states: Optional[dict[str, Any]],
225
+ ) -> None:
226
+ kwargs = {}
227
+ if external_config is not None:
228
+ kwargs["config"] = external_config
229
+
230
+ path = model_path(network_name, epoch=epoch)
231
+ states_path = model_path(network_name, epoch=epoch, states=True)
232
+ logger.info(f"Saving model checkpoint {path}...")
233
+ torch.save(
234
+ {
235
+ "state": model_state,
236
+ "birder_clip_version": __version__,
237
+ "task": task,
238
+ "signature": signature,
239
+ "rgb_stats": rgb_stats,
240
+ **kwargs,
241
+ },
242
+ path,
243
+ )
244
+
245
+ _checkpoint_states_from_state_dicts(
246
+ states_path,
247
+ optimizer_state,
248
+ scheduler_state,
249
+ scaler_state,
250
+ model_base_state,
251
+ **extra_states,
252
+ )
253
+
254
+
185
255
  def clean_checkpoints(network_name: str, keep_last: int) -> None:
186
256
  epoch = "*[0-9]"
187
257
  models_glob = str(model_path(network_name, epoch=epoch))
@@ -314,7 +384,7 @@ def load_model(
314
384
  embed_dim=embed_dim,
315
385
  tokenizer=tokenizer,
316
386
  )
317
- path = model_path(_network_name, epoch=epoch, file_format="safetensors" if st is True else "pt")
387
+ path = model_path(_network_name, epoch=epoch, st=st)
318
388
 
319
389
  logger.info(f"Loading model from {path} on device {device}...")
320
390
 
@@ -589,6 +659,33 @@ def load_pretrained_tokenizer(weights: str, *, download: bool = True, **kwargs:
589
659
  return get_tokenizer(tokenizer_name, **tokenizer_kwargs)
590
660
 
591
661
 
662
+ def save_st(
663
+ net: torch.nn.Module,
664
+ dst: str,
665
+ task: str,
666
+ signature: SignatureType,
667
+ rgb_stats: RGBType,
668
+ *,
669
+ external_config: Optional[dict[str, Any]] = None,
670
+ ) -> None:
671
+ assert _HAS_SAFETENSORS, "'pip install safetensors' to use .safetensors"
672
+ kwargs = {}
673
+ if external_config is not None:
674
+ kwargs["config"] = json.dumps(external_config)
675
+
676
+ safetensors.torch.save_model(
677
+ net,
678
+ str(dst),
679
+ {
680
+ "birder_clip_version": __version__,
681
+ "task": task,
682
+ "signature": json.dumps(signature),
683
+ "rgb_stats": json.dumps(rgb_stats),
684
+ **kwargs,
685
+ },
686
+ )
687
+
688
+
592
689
  def download_model_by_weights(
593
690
  weights: str, *, dst: Optional[str | Path] = None, file_format: FileFormatType = "pt", progress_bar: bool = True
594
691
  ) -> None:
@@ -43,9 +43,17 @@ def get_image_text_network_name(
43
43
  parts = [network]
44
44
  if image_encoder is not None:
45
45
  parts.append(image_encoder)
46
- if text_encoder is not None:
46
+ if text_encoder is not None and text_encoder != "text_transformer":
47
47
  parts.append(text_encoder)
48
- if tokenizer is not None:
48
+
49
+ if registry.exists(network) is True:
50
+ default_tokenizer = registry.get_default_tokenizer(network)
51
+ else:
52
+ default_tokenizer = "simple_tokenizer"
53
+ if default_tokenizer is None:
54
+ default_tokenizer = "simple_tokenizer"
55
+
56
+ if tokenizer is not None and tokenizer != default_tokenizer:
49
57
  parts.append(tokenizer)
50
58
  if embed_dim is not None:
51
59
  parts.append(f"d{embed_dim}")
@@ -78,7 +86,7 @@ def get_image_text_model_config(
78
86
 
79
87
  if config is not None:
80
88
  for key, value in config.items():
81
- if key in {"image", "text"} and isinstance(value, dict) is True:
89
+ if key in {"image", "text"} and isinstance(value, dict):
82
90
  model_config[key] = {**model_config.get(key, {}), **value}
83
91
  else:
84
92
  model_config[key] = value
@@ -17,32 +17,6 @@ from birder_clip.model_registry import Task
17
17
  from birder_clip.model_registry import registry
18
18
 
19
19
 
20
- def add_compile_args(parser: argparse.ArgumentParser) -> None:
21
- group = parser.add_argument_group("Compilation parameters")
22
- group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
23
- group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
24
- group.add_argument(
25
- "--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
26
- )
27
- group.add_argument(
28
- "--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
29
- )
30
- group.add_argument(
31
- "--compile-recompile-limit",
32
- type=int,
33
- default=torch.compiler.config.recompile_limit,
34
- metavar="N",
35
- help="maximum recompilations per compiled function before eager fallback",
36
- )
37
- group.add_argument(
38
- "--compile-accumulated-recompile-limit",
39
- type=int,
40
- default=torch.compiler.config.accumulated_recompile_limit,
41
- metavar="N",
42
- help="maximum total recompilations across compiled functions",
43
- )
44
-
45
-
46
20
  def add_model_args(parser: argparse.ArgumentParser) -> None:
47
21
  parser.add_argument("-n", "--network", type=str, help="the image-text network to train")
48
22
  parser.add_argument("-t", "--tag", type=str, help="add model tag")
@@ -92,6 +66,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
92
66
  metavar="N",
93
67
  help="number of iterations to accumulate gradients per optimizer step",
94
68
  )
69
+ # NOTE: Add flag for negative sample caching in grad accum mode
95
70
 
96
71
 
97
72
  def add_lr_wd_args(parser: argparse.ArgumentParser) -> None:
@@ -129,14 +104,14 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
129
104
  "--lr-scheduler-update",
130
105
  type=str,
131
106
  choices=["epoch", "step"],
132
- default="step",
107
+ default="epoch",
133
108
  help="when to apply learning rate scheduler update: epoch (once per epoch), step (each optimizer step)",
134
109
  )
135
110
  group.add_argument(
136
111
  "--lr-scheduler",
137
112
  type=str,
138
113
  choices=list(get_args(SchedulerType)),
139
- default="cosine",
114
+ default="constant",
140
115
  help="learning rate scheduler",
141
116
  )
142
117
  group.add_argument(
@@ -175,15 +150,6 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
175
150
  )
176
151
 
177
152
 
178
- def add_input_args(parser: argparse.ArgumentParser) -> None:
179
- group = parser.add_argument_group("Input parameters")
180
- group.add_argument(
181
- "--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
182
- )
183
- group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
184
- group.add_argument("--context-length", type=int, metavar="N", help="text context length")
185
-
186
-
187
153
  def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
188
154
  group = parser.add_argument_group("Training schedule parameters")
189
155
  group.add_argument("--epochs", type=int, default=default_epochs, metavar="N", help="number of training epochs")
@@ -204,6 +170,37 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
204
170
  )
205
171
 
206
172
 
173
+ def add_ema_args(
174
+ parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
175
+ ) -> None:
176
+ group = parser.add_argument_group("Exponential moving average parameters")
177
+ group.add_argument(
178
+ "--model-ema",
179
+ default=False,
180
+ action="store_true",
181
+ help="enable tracking exponential moving average of model parameters",
182
+ )
183
+ group.add_argument(
184
+ "--model-ema-steps",
185
+ type=int,
186
+ default=default_ema_steps,
187
+ metavar="N",
188
+ help="number of optimizer steps between EMA updates",
189
+ )
190
+ group.add_argument(
191
+ "--model-ema-decay",
192
+ type=float,
193
+ default=default_ema_decay,
194
+ help="decay factor for exponential moving average of model parameters",
195
+ )
196
+ group.add_argument(
197
+ "--model-ema-warmup",
198
+ type=int,
199
+ metavar="N",
200
+ help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
201
+ )
202
+
203
+
207
204
  def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
208
205
  group = parser.add_argument_group("Batch normalization parameters")
209
206
  group.add_argument(
@@ -215,6 +212,15 @@ def add_batch_norm_args(parser: argparse.ArgumentParser) -> None:
215
212
  group.add_argument("--sync-bn", default=False, action="store_true", help="use synchronized BatchNorm")
216
213
 
217
214
 
215
+ def add_input_args(parser: argparse.ArgumentParser) -> None:
216
+ group = parser.add_argument_group("Input parameters")
217
+ group.add_argument(
218
+ "--channels", type=int, default=settings.DEFAULT_NUM_CHANNELS, metavar="N", help="no. of image channels"
219
+ )
220
+ group.add_argument("--size", type=int, nargs="+", metavar=("H", "W"), help="image size")
221
+ group.add_argument("--context-length", type=int, metavar="N", help="text context length")
222
+
223
+
218
224
  def add_data_aug_args(
219
225
  parser: argparse.ArgumentParser,
220
226
  default_level: int = 4,
@@ -260,7 +266,7 @@ def add_data_aug_args(
260
266
  "--rgb-mode",
261
267
  type=str,
262
268
  choices=list(typing.get_args(RGBMode)),
263
- default="clip",
269
+ default="birder",
264
270
  help="RGB mean and std to use for normalization",
265
271
  )
266
272
  group.add_argument(
@@ -279,67 +285,6 @@ def add_data_aug_args(
279
285
  )
280
286
 
281
287
 
282
- def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
283
- group = parser.add_argument_group("Checkpoint parameters")
284
- group.add_argument(
285
- "--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
286
- )
287
- group.add_argument(
288
- "--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
289
- )
290
- group.add_argument(
291
- "--pretrained",
292
- default=False,
293
- action="store_true",
294
- help="start with pretrained version of specified network (will download if not found locally)",
295
- )
296
- group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
297
- group.add_argument(
298
- "--non-strict-weights",
299
- default=False,
300
- action="store_true",
301
- help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
302
- )
303
- group.add_argument(
304
- "--load-states",
305
- default=False,
306
- action="store_true",
307
- help="load optimizer, scheduler and scaler states when resuming",
308
- )
309
- group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
310
-
311
-
312
- def add_ema_args(
313
- parser: argparse.ArgumentParser, default_ema_steps: int = 1, default_ema_decay: float = 0.9999
314
- ) -> None:
315
- group = parser.add_argument_group("Exponential moving average parameters")
316
- group.add_argument(
317
- "--model-ema",
318
- default=False,
319
- action="store_true",
320
- help="enable tracking exponential moving average of model parameters",
321
- )
322
- group.add_argument(
323
- "--model-ema-steps",
324
- type=int,
325
- default=default_ema_steps,
326
- metavar="N",
327
- help="number of optimizer steps between EMA updates",
328
- )
329
- group.add_argument(
330
- "--model-ema-decay",
331
- type=float,
332
- default=default_ema_decay,
333
- help="decay factor for exponential moving average of model parameters",
334
- )
335
- group.add_argument(
336
- "--model-ema-warmup",
337
- type=int,
338
- metavar="N",
339
- help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
340
- )
341
-
342
-
343
288
  def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
344
289
  group = parser.add_argument_group("Dataloader parameters")
345
290
  group.add_argument(
@@ -375,7 +320,13 @@ def add_dataloader_args(parser: argparse.ArgumentParser) -> None:
375
320
  action="store_true",
376
321
  help="keep dataloader worker processes alive between epochs",
377
322
  )
378
- group.add_argument("--drop-last", default=False, action="store_true", help="drop the last incomplete batch")
323
+ group.add_argument(
324
+ "--no-drop-last",
325
+ dest="drop_last",
326
+ default=True,
327
+ action="store_false",
328
+ help="do not drop the last incomplete batch",
329
+ )
379
330
 
380
331
 
381
332
  def add_precision_args(parser: argparse.ArgumentParser) -> None:
@@ -405,12 +356,106 @@ def add_precision_args(parser: argparse.ArgumentParser) -> None:
405
356
  )
406
357
 
407
358
 
359
+ def add_compile_args(parser: argparse.ArgumentParser) -> None:
360
+ group = parser.add_argument_group("Compilation parameters")
361
+ group.add_argument("--compile", default=False, action="store_true", help="enable compilation")
362
+ group.add_argument("--compile-fullgraph", default=False, action="store_true", help="compile using fullgraph=True")
363
+ group.add_argument(
364
+ "--compile-mode", type=str, choices=list(torch._inductor.list_mode_options().keys()), help="torch.compile mode"
365
+ )
366
+ group.add_argument(
367
+ "--compile-opt", default=False, action="store_true", help="enable compilation for optimizer step"
368
+ )
369
+ group.add_argument(
370
+ "--compile-recompile-limit",
371
+ type=int,
372
+ default=torch.compiler.config.recompile_limit,
373
+ metavar="N",
374
+ help="maximum recompilations per compiled function before eager fallback",
375
+ )
376
+ group.add_argument(
377
+ "--compile-accumulated-recompile-limit",
378
+ type=int,
379
+ default=torch.compiler.config.accumulated_recompile_limit,
380
+ metavar="N",
381
+ help="maximum total recompilations across compiled functions",
382
+ )
383
+
384
+
385
+ def add_checkpoint_args(parser: argparse.ArgumentParser, default_save_frequency: int = 1) -> None:
386
+ group = parser.add_argument_group("Checkpoint parameters")
387
+ group.add_argument(
388
+ "--save-frequency", type=int, default=default_save_frequency, metavar="N", help="frequency of model saving"
389
+ )
390
+ group.add_argument(
391
+ "--keep-last", type=int, metavar="N", help="number of recent checkpoints to keep (older ones are deleted)"
392
+ )
393
+ group.add_argument(
394
+ "--pretrained",
395
+ default=False,
396
+ action="store_true",
397
+ help="start with pretrained version of specified network (will download if not found locally)",
398
+ )
399
+ group.add_argument("--resume-epoch", type=int, metavar="N", help="epoch number to resume training from")
400
+ group.add_argument(
401
+ "--non-strict-weights",
402
+ default=False,
403
+ action="store_true",
404
+ help="allow non-strict loading of model weights (missing or unexpected keys in state_dict)",
405
+ )
406
+ group.add_argument(
407
+ "--load-states",
408
+ default=False,
409
+ action="store_true",
410
+ help="load optimizer, scheduler and scaler states when resuming",
411
+ )
412
+ group.add_argument("--load-scheduler", default=False, action="store_true", help="load only scheduler when resuming")
413
+
414
+
408
415
  def add_distributed_args(parser: argparse.ArgumentParser) -> None:
409
416
  group = parser.add_argument_group("Distributed training parameters")
410
417
  group.add_argument("--world-size", type=int, default=1, metavar="N", help="number of distributed processes")
411
418
  group.add_argument("--local-rank", type=int, metavar="N", help="local rank")
412
419
  group.add_argument("--dist-url", type=str, default="env://", help="URL used to initialize distributed training")
413
420
  group.add_argument("--dist-backend", type=str, default="nccl", help="distributed backend")
421
+ group.add_argument(
422
+ "--distributed-mode", type=str, choices=["ddp", "fsdp"], default="ddp", help="distributed training mode"
423
+ )
424
+ group.add_argument(
425
+ "--fsdp-sharding-strategy",
426
+ type=str,
427
+ choices=["shard-grad-op", "full-shard"],
428
+ default="shard-grad-op",
429
+ help="FSDP sharding strategy",
430
+ )
431
+ group.add_argument(
432
+ "--fsdp-param-dtype",
433
+ type=str,
434
+ choices=["float32", "float16", "bfloat16"],
435
+ help="FSDP mixed precision parameter dtype",
436
+ )
437
+ group.add_argument(
438
+ "--fsdp-reduce-dtype",
439
+ type=str,
440
+ choices=["float32", "float16", "bfloat16"],
441
+ help="FSDP mixed precision gradient reduction dtype",
442
+ )
443
+ group.add_argument(
444
+ "--fsdp-wrap-policy",
445
+ type=str,
446
+ choices=["block-group-regex", "min-num-params"],
447
+ default="block-group-regex",
448
+ help="FSDP module wrapping policy",
449
+ )
450
+ group.add_argument(
451
+ "--fsdp-wrap-min-num-params",
452
+ type=float,
453
+ metavar="M",
454
+ help="minimum module parameter count in millions for wrapping when using --fsdp-wrap-policy min-num-params",
455
+ )
456
+ group.add_argument(
457
+ "--fsdp-offload-policy", type=str, choices=["none", "cpu"], default="none", help="FSDP parameter offload policy"
458
+ )
414
459
  group.add_argument(
415
460
  "--find-unused-parameters",
416
461
  default=False,
@@ -558,5 +603,27 @@ def common_args_validation(args: argparse.Namespace) -> None:
558
603
  raise cli.ValidationError("--embed-dim must be positive")
559
604
  if args.context_length is not None and args.context_length <= 0:
560
605
  raise cli.ValidationError("--context-length must be positive")
606
+ if args.grad_accum_steps < 1:
607
+ raise cli.ValidationError("--grad-accum-steps must be >= 1")
561
608
  if args.model_ema_steps < 1:
562
609
  raise cli.ValidationError("--model-ema-steps must be >= 1")
610
+
611
+ if args.distributed_mode == "fsdp":
612
+ if args.sync_bn is True:
613
+ raise cli.ValidationError("--sync-bn cannot be used with --distributed-mode fsdp")
614
+ if args.find_unused_parameters is True:
615
+ raise cli.ValidationError("--find-unused-parameters cannot be used with --distributed-mode fsdp")
616
+ if args.compile_opt is True:
617
+ raise cli.ValidationError("--compile-opt cannot be used with --distributed-mode fsdp")
618
+ if args.compile_fullgraph is True:
619
+ raise cli.ValidationError("--compile-fullgraph cannot be used with --distributed-mode fsdp")
620
+ if args.cpu is True:
621
+ raise cli.ValidationError("--cpu cannot be used with --distributed-mode fsdp")
622
+ if args.model_ema is True:
623
+ raise cli.ValidationError("--model-ema cannot be used with --distributed-mode fsdp")
624
+ if args.fsdp_wrap_policy == "min-num-params" and args.fsdp_wrap_min_num_params is None:
625
+ raise cli.ValidationError(
626
+ "--fsdp-wrap-min-num-params is required when --fsdp-wrap-policy is min-num-params"
627
+ )
628
+ if args.fsdp_wrap_min_num_params is not None and args.fsdp_wrap_min_num_params <= 0:
629
+ raise cli.ValidationError("--fsdp-wrap-min-num-params must be > 0")
@@ -0,0 +1,99 @@
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Any
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from birder.common import fsdp_utils
10
+ from birder.common import training_utils as birder_training_utils
11
+
12
+ from birder_clip.common import fs_ops
13
+ from birder_clip.conf import settings
14
+
15
+
16
+ def setup_file_logging(log_file_path: str | Path) -> logging.Handler:
17
+ file_handler = logging.FileHandler(log_file_path)
18
+ formatter = logging.Formatter(
19
+ fmt="{message}",
20
+ style="{",
21
+ )
22
+ file_handler.setFormatter(formatter)
23
+ file_handler.setLevel(settings.LOG_LEVEL)
24
+
25
+ logging.getLogger("birder").addHandler(file_handler)
26
+ logging.getLogger("birder_clip").addHandler(file_handler)
27
+
28
+ return file_handler
29
+
30
+
31
+ def save_training_checkpoint(
32
+ args: argparse.Namespace,
33
+ network_name: str,
34
+ epoch: int,
35
+ net: torch.nn.Module,
36
+ signature: Any,
37
+ rgb_stats: Any,
38
+ optimizer: Optional[torch.optim.Optimizer],
39
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
40
+ scaler: Optional[torch.amp.grad_scaler.GradScaler],
41
+ model_base: Optional[torch.nn.Module],
42
+ *,
43
+ fsdp_mode: bool = False,
44
+ fsdp_model_state: Optional[dict[str, Any]] = None,
45
+ external_config: Optional[dict[str, Any]] = None,
46
+ **extra_states: Optional[dict[str, Any]],
47
+ ) -> None:
48
+ if fsdp_mode is True:
49
+ if fsdp_model_state is not None:
50
+ model_state = fsdp_model_state
51
+ else:
52
+ model_state = fsdp_utils.gather_full_model_state_dict(net)
53
+
54
+ optimizer_state = None
55
+ scheduler_state = None
56
+ scaler_state = None
57
+ model_base_state = None
58
+ if optimizer is not None and scheduler is not None:
59
+ optimizer_state = fsdp_utils.gather_full_optimizer_state_dict(net, optimizer)
60
+ scheduler_state = scheduler.state_dict()
61
+ if scaler is not None:
62
+ scaler_state = scaler.state_dict()
63
+ if model_base is not None:
64
+ model_base_state = model_base.state_dict()
65
+
66
+ if birder_training_utils.is_global_primary(args) is True:
67
+ fs_ops.checkpoint_model_from_state_dicts(
68
+ network_name,
69
+ epoch,
70
+ model_state,
71
+ net.task,
72
+ signature,
73
+ rgb_stats,
74
+ optimizer_state,
75
+ scheduler_state,
76
+ scaler_state,
77
+ model_base_state,
78
+ external_config=external_config,
79
+ **extra_states,
80
+ )
81
+
82
+ if birder_training_utils.is_dist_available_and_initialized() is True:
83
+ dist.barrier()
84
+
85
+ else:
86
+ if birder_training_utils.is_global_primary(args) is True:
87
+ fs_ops.checkpoint_model(
88
+ network_name,
89
+ epoch,
90
+ net,
91
+ signature,
92
+ rgb_stats,
93
+ optimizer,
94
+ scheduler,
95
+ scaler,
96
+ model_base,
97
+ external_config=external_config,
98
+ **extra_states,
99
+ )