birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +18 -0
  3. birder/common/training_utils.py +123 -10
  4. birder/data/collators/detection.py +10 -3
  5. birder/data/datasets/coco.py +8 -10
  6. birder/data/transforms/detection.py +30 -13
  7. birder/inference/detection.py +108 -4
  8. birder/inference/wbf.py +226 -0
  9. birder/net/__init__.py +8 -0
  10. birder/net/detection/efficientdet.py +65 -86
  11. birder/net/detection/rt_detr_v1.py +1 -0
  12. birder/net/detection/yolo_anchors.py +205 -0
  13. birder/net/detection/yolo_v2.py +25 -24
  14. birder/net/detection/yolo_v3.py +39 -40
  15. birder/net/detection/yolo_v4.py +28 -26
  16. birder/net/detection/yolo_v4_tiny.py +24 -20
  17. birder/net/fasternet.py +1 -1
  18. birder/net/gc_vit.py +671 -0
  19. birder/net/lit_v1.py +472 -0
  20. birder/net/lit_v1_tiny.py +342 -0
  21. birder/net/lit_v2.py +436 -0
  22. birder/net/mobilenet_v4_hybrid.py +1 -1
  23. birder/net/resnet_v1.py +1 -1
  24. birder/net/resnext.py +67 -25
  25. birder/net/se_resnet_v1.py +46 -0
  26. birder/net/se_resnext.py +3 -0
  27. birder/net/simple_vit.py +2 -2
  28. birder/net/vit.py +0 -15
  29. birder/net/vovnet_v2.py +31 -1
  30. birder/scripts/benchmark.py +90 -21
  31. birder/scripts/predict.py +1 -0
  32. birder/scripts/predict_detection.py +18 -11
  33. birder/scripts/train.py +10 -34
  34. birder/scripts/train_barlow_twins.py +10 -34
  35. birder/scripts/train_byol.py +10 -34
  36. birder/scripts/train_capi.py +10 -35
  37. birder/scripts/train_data2vec.py +9 -34
  38. birder/scripts/train_data2vec2.py +9 -34
  39. birder/scripts/train_detection.py +48 -40
  40. birder/scripts/train_dino_v1.py +10 -34
  41. birder/scripts/train_dino_v2.py +9 -34
  42. birder/scripts/train_dino_v2_dist.py +9 -34
  43. birder/scripts/train_franca.py +9 -34
  44. birder/scripts/train_i_jepa.py +9 -34
  45. birder/scripts/train_ibot.py +9 -34
  46. birder/scripts/train_kd.py +156 -64
  47. birder/scripts/train_mim.py +10 -34
  48. birder/scripts/train_mmcr.py +10 -34
  49. birder/scripts/train_rotnet.py +10 -34
  50. birder/scripts/train_simclr.py +10 -34
  51. birder/scripts/train_vicreg.py +10 -34
  52. birder/tools/auto_anchors.py +20 -1
  53. birder/tools/pack.py +172 -103
  54. birder/tools/show_det_iterator.py +10 -1
  55. birder/version.py +1 -1
  56. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
  57. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
  58. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
  59. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
  60. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
  61. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
@@ -33,7 +33,6 @@ from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_mim_network_name
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import make_image_dataset
@@ -67,41 +66,13 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.init_distributed_mode(args)
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.size is None:
         args.size = registry.get_default_size(args.network)

     logger.info(f"Using size={args.size}")

-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     #
     # Data
     #
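The hunk above collapses the per-script initialization boilerplate (device selection, determinism, seeding, tqdm gating, anomaly detection) into a single training_utils.init_training(args, logger) call. The following is a rough sketch of what that helper plausibly does, assembled from the removed lines; the real implementation lives in birder/common/training_utils.py, where init_distributed_mode, log_git_info, set_random_seeds and is_local_primary are assumed to be in scope, and it may differ in details.

```python
# Sketch only: reconstructed from the lines removed above, not the packaged code.
import sys

import torch

import birder


def init_training(args, logger):
    # Distributed setup, version banner and git info (previously inlined in every script)
    init_distributed_mode(args)  # assumed birder helper
    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
    log_git_info()  # assumed birder helper

    # Device selection
    if args.cpu is True:
        (device, device_id) = (torch.device("cpu"), 0)
    else:
        (device, device_id) = (torch.device("cuda"), torch.cuda.current_device())

    # Determinism, seeding and autograd anomaly detection
    torch.backends.cudnn.benchmark = not args.use_deterministic_algorithms
    if args.use_deterministic_algorithms is True:
        torch.use_deterministic_algorithms(True)
    if args.seed is not None:
        set_random_seeds(args.seed)  # assumed birder helper
    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

    # Progress bars only on an interactive local primary
    disable_tqdm = args.non_interactive or not is_local_primary(args) or not sys.stderr.isatty()

    return (device, device_id, disable_tqdm)
```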
@@ -148,7 +119,7 @@ def train(args: argparse.Namespace) -> None:

     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
-    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -189,6 +160,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize network
     #
@@ -236,22 +209,25 @@ def train(args: argparse.Namespace) -> None:
     # Optimizer, learning rate scheduler and training parameter groups
     #

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         net,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
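The last hunk moves training_utils.scale_lr ahead of optimizer_parameter_groups so the scaled value can be passed in as base_lr, letting the new per-group knobs (custom_layer_wd, custom_layer_lr_scale) and layer decay resolve against the final learning rate. A hedged sketch of the usual linear-scaling convention follows; whether birder's scale_lr uses exactly this rule (or a sqrt variant, or a different reference batch size) is an assumption, not something the diff confirms.

```python
# Illustrative only: the common "linear scaling rule" applied to the effective
# batch size that the logger.debug line above reports.
def scale_lr_sketch(base_lr: float, batch_size: int, grad_accum_steps: int, world_size: int, reference: int = 256) -> float:
    effective_batch_size = batch_size * grad_accum_steps * world_size
    return base_lr * effective_batch_size / reference


# 64 images/GPU, 2 accumulation steps, 4 GPUs -> effective batch of 512 -> LR doubled
print(scale_lr_sketch(0.1, batch_size=64, grad_accum_steps=2, world_size=4))  # 0.2
```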
@@ -36,7 +36,6 @@ from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_mim_network_name
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import make_image_dataset
@@ -70,41 +69,13 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.init_distributed_mode(args)
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.size is None:
         args.size = registry.get_default_size(args.network)

     logger.info(f"Using size={args.size}")

-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     #
     # Data
     #
@@ -151,7 +122,7 @@ def train(args: argparse.Namespace) -> None:

     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
-    logger.debug(f"Effective batch size = {args.batch_size * grad_accum_steps * args.world_size}")
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -192,6 +163,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize network
     #
@@ -242,22 +215,25 @@ def train(args: argparse.Namespace) -> None:
     # Loss criteria, optimizer, learning rate scheduler and training parameter groups
     #

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         net,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
birder/tools/auto_anchors.py CHANGED
@@ -242,6 +242,7 @@ def _validate_args(
     return (size, num_scales, num_anchors, output_format, strides)


+# pylint: disable=too-many-locals
 def auto_anchors(args: argparse.Namespace) -> None:
     (size, num_scales, num_anchors, output_format, strides) = _validate_args(args)

@@ -272,6 +273,7 @@ def auto_anchors(args: argparse.Namespace) -> None:
     logger.info(f"Mean IoU: {best_iou.mean().item():.4f}")

     formatted_groups = _format_anchor_groups(anchor_groups, args.precision)
+    anchors_output = None
     if output_format == "pixels":
         if num_scales == 1:
             formatted_anchors: Any = formatted_groups[0]
@@ -280,6 +282,7 @@ def auto_anchors(args: argparse.Namespace) -> None:

         print("Anchors (pixels):")
         print(pformat(formatted_anchors))
+        anchors_output = formatted_anchors

     if output_format == "grid":
         grid_groups: list[torch.Tensor] = []
@@ -297,6 +300,21 @@ def auto_anchors(args: argparse.Namespace) -> None:

         print("Anchors (grid units):")
         print(pformat(formatted_grid_output))
+        anchors_output = formatted_grid_output
+
+    if args.output is not None:
+        payload = {
+            "anchors": anchors_output,
+            "format": output_format,
+            "size": [size[0], size[1]],
+        }
+        if output_format == "grid":
+            payload["strides"] = strides
+
+        with open(args.output, "w", encoding="utf-8") as handle:
+            json.dump(payload, handle, indent=2)
+
+        logger.info(f"Wrote anchors to {args.output}")


 def set_parser(subparsers: Any) -> None:
@@ -312,7 +330,7 @@ def set_parser(subparsers: Any) -> None:
         "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 --format pixels "
         "--coco-json-path data/detection_data/training_annotations_coco.json\n"
         "python -m birder.tools auto-anchors --preset yolo_v4_tiny --size 416 416 "
-        "--coco-json-path ~/Datasets/cocodataset/annotations/instances_train2017.json\n"
+        "--coco-json-path ~/Datasets/cocodataset/annotations/instances_train2017.json --output anchors.json\n"
         "python -m birder.tools auto-anchors --preset yolo_v2 --stride 32 "
         "--coco-json-path data/detection_data/training_annotations_coco.json\n"
         "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 "
@@ -354,6 +372,7 @@ def set_parser(subparsers: Any) -> None:
         default=f"{settings.TRAINING_DETECTION_ANNOTATIONS_PATH}_coco.json",
         help="training COCO json path",
     )
+    subparser.add_argument("--output", type=str, help="write anchors as JSON to this path")
    subparser.set_defaults(func=main)

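The new --output flag persists the computed anchors as JSON using the payload dict built above ("anchors", "format", "size", plus "strides" for grid output). A small illustrative reader, assuming a file produced with `--output anchors.json`; the anchor values hinted in the comment are made up, only the key names come from the diff.

```python
# Illustrative only: consuming the JSON written by the new --output option.
import json

with open("anchors.json", "r", encoding="utf-8") as handle:
    payload = json.load(handle)

# e.g. payload == {"anchors": [...], "format": "grid", "size": [416, 416], "strides": [...]}
anchors = payload["anchors"]
if payload["format"] == "grid":
    strides = payload["strides"]  # present only for grid-format output
```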
birder/tools/pack.py CHANGED
@@ -3,6 +3,7 @@ import json
 import logging
 import multiprocessing
 import os
+import queue
 import signal
 import time
 from collections.abc import Callable
@@ -67,39 +68,43 @@ def _save_classes(pack_path: Path, class_to_idx: dict[str, int]) -> None:


 def _encode_image(path: str, file_format: str, size: Optional[int] = None) -> bytes:
-    image: Image.Image = Image.open(path)
-    if size is not None and size < min(image.size):
-        if image.size[0] > image.size[1]:
-            ratio = image.size[1] / size
-        else:
-            ratio = image.size[0] / size
-
-        width = round(image.size[0] / ratio)
-        height = round(image.size[1] / ratio)
-        if max(width, height) > MAX_SIZE:
-            if width > height:
-                ratio = width / MAX_SIZE
+    image: Image.Image
+    with Image.open(path) as image:
+        if file_format.lower() in ("jpeg", "jpg") and image.mode in ("RGBA", "P"):
+            image = image.convert("RGB")
+
+        if size is not None and size < min(image.size):
+            if image.size[0] > image.size[1]:
+                ratio = image.size[1] / size
             else:
-                ratio = height / MAX_SIZE
+                ratio = image.size[0] / size
+
+            width = round(image.size[0] / ratio)
+            height = round(image.size[1] / ratio)
+            if max(width, height) > MAX_SIZE:
+                if width > height:
+                    ratio = width / MAX_SIZE
+                else:
+                    ratio = height / MAX_SIZE

-            width = round(width / ratio)
-            height = round(height / ratio)
+                width = round(width / ratio)
+                height = round(height / ratio)

-        image = image.resize((width, height), Image.Resampling.BICUBIC)
+            image = image.resize((width, height), Image.Resampling.BICUBIC)

-    elif max(image.size) > MAX_SIZE:
-        if image.size[0] > image.size[1]:
-            ratio = image.size[0] / MAX_SIZE
-        else:
-            ratio = image.size[1] / MAX_SIZE
+        elif max(image.size) > MAX_SIZE:
+            if image.size[0] > image.size[1]:
+                ratio = image.size[0] / MAX_SIZE
+            else:
+                ratio = image.size[1] / MAX_SIZE

-        width = round(image.size[0] / ratio)
-        height = round(image.size[1] / ratio)
-        image = image.resize((width, height), Image.Resampling.BICUBIC)
+            width = round(image.size[0] / ratio)
+            height = round(image.size[1] / ratio)
+            image = image.resize((width, height), Image.Resampling.BICUBIC)

-    sample_buffer = BytesIO()
-    image.save(sample_buffer, format=file_format, quality=85)
-    return sample_buffer.getvalue()
+        sample_buffer = BytesIO()
+        image.save(sample_buffer, format=file_format, quality=85)
+        return sample_buffer.getvalue()


 def read_worker(q_in: Any, q_out: Any, error_event: Any, size: Optional[int], file_format: str) -> None:
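The _encode_image rewrite does two things: it opens the source image in a context manager so the file handle is released promptly, and it converts palette ("P") and alpha ("RGBA") images to RGB before a JPEG save, since Pillow cannot encode those modes as JPEG. A minimal standalone sketch of the same pattern (not the packaged function, and without birder's resizing logic):

```python
# Sketch of the pattern from the hunk above: context-managed open plus an
# RGB conversion guard before JPEG encoding.
from io import BytesIO

from PIL import Image


def encode_jpeg(path: str, quality: int = 85) -> bytes:
    with Image.open(path) as image:
        if image.mode in ("RGBA", "P"):
            image = image.convert("RGB")  # JPEG has no alpha or palette support

        buffer = BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
        return buffer.getvalue()
```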
@@ -162,38 +167,43 @@ def wds_write_worker(
     count = 0
     buf = {}
     more = True
-    with tqdm(total=total, initial=0, unit="images", unit_scale=True, leave=False) as progress:
-        while more:
-            deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
-            if deq is not None:
-                (idx, sample, suffix, target) = deq
-                buf[idx] = (sample, suffix, target)
+    try:
+        with tqdm(total=total, initial=0, unit="images", unit_scale=True, leave=False) as progress:
+            while more:
+                deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
+                if deq is not None:
+                    (idx, sample, suffix, target) = deq
+                    buf[idx] = (sample, suffix, target)

-            else:
-                more = False
+                else:
+                    more = False

-            # Ensures ordered write
-            while count in buf:
-                (sample, suffix, target) = buf[count]
-                del buf[count]
+                # Ensures ordered write
+                while count in buf:
+                    (sample, suffix, target) = buf[count]
+                    del buf[count]

-                if args.no_cls is True:
-                    cls = {}
-                else:
-                    cls = {"cls": target}
+                    if args.no_cls is True:
+                        cls = {}
+                    else:
+                        cls = {"cls": target}

-                sink.write(
-                    {
-                        "__key__": f"sample{count:09d}",
-                        suffix: sample,
-                        **cls,
-                    }
-                )
+                    sink.write(
+                        {
+                            "__key__": f"sample{count:09d}",
+                            suffix: sample,
+                            **cls,
+                        }
+                    )

-                count += 1
+                    count += 1

-                # Update progress bar
-                progress.update(n=1)
+                    # Update progress bar
+                    progress.update(n=1)
+
+    except Exception:
+        error_event.set()
+        raise

     sink.close()

@@ -218,35 +228,42 @@ def wds_write_worker(


 def directory_write_worker(
-    q_out: Any, _error_event: Any, pack_path: Path, total: int, _: argparse.Namespace, idx_to_class: dict[int, str]
+    q_out: Any, error_event: Any, pack_path: Path, total: int, _: argparse.Namespace, idx_to_class: dict[int, str]
 ) -> None:
     count = 0
     buf = {}
     more = True
-    with tqdm(total=total, initial=0, unit="images", unit_scale=True, leave=False) as progress:
-        while more:
-            deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
-            if deq is not None:
-                (idx, sample, suffix, target) = deq
-                buf[idx] = (sample, suffix, target)
+    try:
+        with tqdm(total=total, initial=0, unit="images", unit_scale=True, leave=False) as progress:
+            while more:
+                deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
+                if deq is not None:
+                    (idx, sample, suffix, target) = deq
+                    buf[idx] = (sample, suffix, target)

-            else:
-                more = False
+                else:
+                    more = False
+
+                # Ensures ordered write
+                while count in buf:
+                    (sample, suffix, target) = buf[count]
+                    del buf[count]
+                    with open(
+                        pack_path.joinpath(idx_to_class[target]).joinpath(f"{count:06d}.{suffix}"), "wb"
+                    ) as handle:
+                        handle.write(sample)

-            # Ensures ordered write
-            while count in buf:
-                (sample, suffix, target) = buf[count]
-                del buf[count]
-                with open(pack_path.joinpath(idx_to_class[target]).joinpath(f"{count:06d}.{suffix}"), "wb") as handle:
-                    handle.write(sample)
+                    count += 1

-                count += 1
+                    # Update progress bar
+                    progress.update(n=1)

-                # Update progress bar
-                progress.update(n=1)
+    except Exception:
+        error_event.set()
+        raise


-# pylint: disable=too-many-locals,too-many-branches
+# pylint: disable=too-many-locals,too-many-branches,too-many-statements
 def pack(args: argparse.Namespace, pack_path: Path) -> None:
     if args.sampling_file is not None:
         with open(args.sampling_file, "r", encoding="utf-8") as handle:
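Both write workers now receive the shared error event and set it before re-raising, so a crashed writer no longer fails silently while the parent keeps feeding the queues. A minimal sketch of that propagation pattern, with illustrative names (`worker`, `process`) that are not birder identifiers:

```python
# Sketch of the worker-side error propagation added above: set the shared
# multiprocessing.Event on any failure before re-raising.
import multiprocessing


def worker(q_out: multiprocessing.Queue, error_event) -> None:
    try:
        while True:
            item = q_out.get()
            if item is None:
                break  # sentinel: producer is done

            process(item)  # hypothetical per-item work

    except Exception:
        error_event.set()  # let the parent notice the failure
        raise
```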
@@ -308,9 +325,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
     read_processes: list[multiprocessing.Process] = []
     for idx in range(args.jobs):
         read_processes.append(
-            multiprocessing.Process(
-                target=read_worker, args=(q_in[idx], q_out, error_event, args.size, args.format), daemon=True
-            )
+            multiprocessing.Process(target=read_worker, args=(q_in[idx], q_out, error_event, args.size, args.format))
         )

     for p in read_processes:
@@ -326,53 +341,107 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
         raise ValueError("Unknown pack type")

     write_process = multiprocessing.Process(
-        target=target_writer, args=(q_out, error_event, pack_path, len(dataset), args, idx_to_class), daemon=True
+        target=target_writer, args=(q_out, error_event, pack_path, len(dataset), args, idx_to_class)
     )
     write_process.start()

+    # Flag to prevent signal handler re-entry
+    cleanup_in_progress = False
+
+    def cleanup_processes() -> None:
+        nonlocal cleanup_in_progress
+        if cleanup_in_progress is True:
+            return
+
+        cleanup_in_progress = True
+
+        # Cancel queue join threads to prevent blocking during cleanup
+        for q in q_in:
+            q.cancel_join_thread()
+
+        q_out.cancel_join_thread()
+
+        # Terminate child processes
+        for p in read_processes:
+            if p.is_alive():
+                p.terminate()
+
+        if write_process.is_alive():
+            write_process.terminate()
+
+        # Wait briefly for termination
+        for p in read_processes:
+            p.join(timeout=1)
+
+        write_process.join(timeout=1)
+
     def signal_handler(signum, _frame) -> None:  # type: ignore
         logger.info(f"Received signal: {signum} at {multiprocessing.current_process().name}, aborting...")
         error_event.set()
+        cleanup_processes()
         raise SystemExit(1)

     signal.signal(signal.SIGINT, signal_handler)

-    tic = time.time()
-    for idx, sample_idx in enumerate(indices):
-        if idx % 1000 == 0:
-            if error_event.is_set() is True:
-                raise RuntimeError()
-
-        (path, target) = dataset[sample_idx]
-        q_in[idx % len(q_in)].put((idx, path, target), block=True, timeout=None)
-
-    for q in q_in:
-        q.put(None, block=True, timeout=None)
-
-    for p in read_processes:
+    try:
+        tic = time.time()
+        for idx, sample_idx in enumerate(indices):
+            if idx % 1000 == 0:
+                if error_event.is_set() is True:
+                    cleanup_processes()
+                    raise RuntimeError()
+
+            (path, target) = dataset[sample_idx]
+
+            while True:
+                try:
+                    q_in[idx % len(q_in)].put((idx, path, target), block=True, timeout=1)
+                    break
+                except queue.Full:
+                    if error_event.is_set() is True:
+                        cleanup_processes()
+                        raise RuntimeError()  # pylint: disable=raise-missing-from
+
+        for q in q_in:
+            q.put(None, block=True, timeout=None)
+
+        for p in read_processes:
+            while True:
+                p.join(timeout=2)
+                if p.is_alive() is False:
+                    break
+
+            if error_event.is_set() is True:
+                cleanup_processes()
+                raise RuntimeError()
+
+        q_out.put(None, block=True, timeout=None)
         while True:
-            p.join(timeout=2)
-            if p.is_alive() is False:
+            write_process.join(timeout=2)
+            if write_process.is_alive() is False:
                 break

         if error_event.is_set() is True:
+            cleanup_processes()
             raise RuntimeError()

-    q_out.put(None, block=True, timeout=None)
-    write_process.join()
+        if error_event.is_set() is True:
+            cleanup_processes()
+            raise RuntimeError()

-    if error_event.is_set() is True:
-        raise RuntimeError()
+        if args.type == "wds":
+            (wds_path, num_shards) = fs_ops.wds_braces_from_path(pack_path, prefix=f"{args.suffix}-{args.split}")
+            logger.info(f"Packed {len(dataset):,} samples into {num_shards} shards at {wds_path}")
+        elif args.type == "directory":
+            logger.info(f"Packed {len(dataset):,} samples")

-    if args.type == "wds":
-        (wds_path, num_shards) = fs_ops.wds_braces_from_path(pack_path, prefix=f"{args.suffix}-{args.split}")
-        logger.info(f"Packed {len(dataset):,} samples into {num_shards} shards at {wds_path}")
-    elif args.type == "directory":
-        logger.info(f"Packed {len(dataset):,} samples")
+        toc = time.time()
+        rate = len(dataset) / (toc - tic)
+        logger.info(f"{format_duration(toc-tic)} to pack {len(dataset):,} samples ({rate:.2f} samples/sec)")

-    toc = time.time()
-    rate = len(dataset) / (toc - tic)
-    logger.info(f"{format_duration(toc-tic)} to pack {len(dataset):,} samples ({rate:.2f} samples/sec)")
+    except Exception:
+        cleanup_processes()
+        raise


 def set_parser(subparsers: Any) -> None:
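The producer loop above no longer blocks forever on a full input queue: it puts with a one-second timeout and, on queue.Full, re-checks the shared error event so a dead reader aborts the run instead of wedging it, with cleanup_processes() cancelling queue join threads and terminating children. A hedged sketch of just that retry loop, with illustrative names (`feed`, `cleanup`) standing in for the diff's identifiers:

```python
# Sketch of the bounded-put retry loop added above. multiprocessing.Queue.put
# raises queue.Full when the timeout expires, which is the hook for checking
# whether a worker has already failed.
import queue


def feed(q_in, item, error_event, cleanup) -> None:
    while True:
        try:
            q_in.put(item, block=True, timeout=1)
            break
        except queue.Full:
            if error_event.is_set():
                cleanup()
                raise RuntimeError("worker failed, aborting") from None
```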
birder/tools/show_det_iterator.py CHANGED
@@ -43,6 +43,7 @@ def show_det_iterator(args: argparse.Namespace) -> None:
         args.dynamic_size,
         args.multiscale,
         args.max_size,
+        args.multiscale_min_size,
     )
     mosaic_transforms = training_preset(
         args.size,
@@ -52,6 +53,7 @@ def show_det_iterator(args: argparse.Namespace) -> None:
         args.dynamic_size,
         args.multiscale,
         args.max_size,
+        args.multiscale_min_size,
         post_mosaic=True,
     )
     if args.mosaic_prob > 0.0:
@@ -160,7 +162,9 @@ def show_det_iterator(args: argparse.Namespace) -> None:

     else:
         if args.batch_multiscale is True:
-            data_collate_fn: Any = BatchRandomResizeCollator(offset, args.size)
+            data_collate_fn: Any = BatchRandomResizeCollator(
+                offset, args.size, multiscale_min_size=args.multiscale_min_size
+            )
         else:
             data_collate_fn = collate_fn

@@ -259,6 +263,11 @@ def set_parser(subparsers: Any) -> None:
         help="allow variable image sizes while preserving aspect ratios",
     )
     subparser.add_argument("--multiscale", default=False, action="store_true", help="enable random scale per image")
+    subparser.add_argument(
+        "--multiscale-min-size",
+        type=int,
+        help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
+    )
     subparser.add_argument(
         "--batch-multiscale",
         default=False,
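The help text for the new --multiscale-min-size flag says the value is rounded up to the nearest multiple of 32. The usual way to express that rounding is shown below; whether birder's transform code uses exactly this expression is an assumption.

```python
# Illustrative rounding helper, not birder code.
def round_up_to_multiple(value: int, multiple: int = 32) -> int:
    return ((value + multiple - 1) // multiple) * multiple


assert round_up_to_multiple(300) == 320
assert round_up_to_multiple(320) == 320
```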
birder/version.py CHANGED
@@ -1 +1 @@
-__version__ = "v0.2.2"
+__version__ = "v0.2.3"