birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +18 -0
- birder/common/training_utils.py +123 -10
- birder/data/collators/detection.py +10 -3
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/net/__init__.py +8 -0
- birder/net/detection/efficientdet.py +65 -86
- birder/net/detection/rt_detr_v1.py +1 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/fasternet.py +1 -1
- birder/net/gc_vit.py +671 -0
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +342 -0
- birder/net/lit_v2.py +436 -0
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +2 -2
- birder/net/vit.py +0 -15
- birder/net/vovnet_v2.py +31 -1
- birder/scripts/benchmark.py +90 -21
- birder/scripts/predict.py +1 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +10 -34
- birder/scripts/train_barlow_twins.py +10 -34
- birder/scripts/train_byol.py +10 -34
- birder/scripts/train_capi.py +10 -35
- birder/scripts/train_data2vec.py +9 -34
- birder/scripts/train_data2vec2.py +9 -34
- birder/scripts/train_detection.py +48 -40
- birder/scripts/train_dino_v1.py +10 -34
- birder/scripts/train_dino_v2.py +9 -34
- birder/scripts/train_dino_v2_dist.py +9 -34
- birder/scripts/train_franca.py +9 -34
- birder/scripts/train_i_jepa.py +9 -34
- birder/scripts/train_ibot.py +9 -34
- birder/scripts/train_kd.py +156 -64
- birder/scripts/train_mim.py +10 -34
- birder/scripts/train_mmcr.py +10 -34
- birder/scripts/train_rotnet.py +10 -34
- birder/scripts/train_simclr.py +10 -34
- birder/scripts/train_vicreg.py +10 -34
- birder/tools/auto_anchors.py +20 -1
- birder/tools/pack.py +172 -103
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/scripts/train_simclr.py
CHANGED
```diff
@@ -33,7 +33,6 @@ from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_mim_network_name
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import make_image_dataset
@@ -67,41 +66,13 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.size is None:
         args.size = registry.get_default_size(args.network)

     logger.info(f"Using size={args.size}")

-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     #
     # Data
     #
@@ -148,7 +119,7 @@ def train(args: argparse.Namespace) -> None:

     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
-    logger.debug(f"Effective batch size = {
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -189,6 +160,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize network
     #
@@ -236,22 +209,25 @@ def train(args: argparse.Namespace) -> None:
     # Optimizer, learning rate scheduler and training parameter groups
     #

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         net,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
```
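Every training script in this release drops the same initialization boilerplate in favor of a single `training_utils.init_training(args, logger)` call. The diff does not show the helper itself; the sketch below only reassembles the removed call-site lines into the shape the new signature implies, so treat it as a hypothetical reconstruction rather than birder's actual implementation. The same replacement repeats in each `train_*.py` file listed above, including `train_vicreg.py` below.

```python
# Hypothetical reconstruction of training_utils.init_training(), assembled
# from the boilerplate removed from each training script; the real birder
# helper may differ in details.
import argparse
import logging
import sys

import torch

import birder
from birder.common import training_utils
from birder.common.lib import set_random_seeds


def init_training(args: argparse.Namespace, logger: logging.Logger) -> tuple[torch.device, int, bool]:
    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
    training_utils.log_git_info()

    if args.cpu is True:
        device = torch.device("cpu")
        device_id = 0
    else:
        device = torch.device("cuda")
        device_id = torch.cuda.current_device()

    if args.use_deterministic_algorithms is True:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    if args.seed is not None:
        set_random_seeds(args.seed)

    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
        disable_tqdm = True
    else:
        disable_tqdm = sys.stderr.isatty() is False

    # Enable or disable the autograd anomaly detection
    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

    return (device, device_id, disable_tqdm)
```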
birder/scripts/train_vicreg.py
CHANGED
```diff
@@ -36,7 +36,6 @@ from birder.common import training_utils
 from birder.common.lib import format_duration
 from birder.common.lib import get_mim_network_name
 from birder.common.lib import get_network_name
-from birder.common.lib import set_random_seeds
 from birder.conf import settings
 from birder.data.dataloader.webdataset import make_wds_loader
 from birder.data.datasets.directory import make_image_dataset
@@ -70,41 +69,13 @@ def train(args: argparse.Namespace) -> None:
     #
     # Initialize
     #
-    training_utils.
-    logger.info(f"Starting training, birder version: {birder.__version__}, pytorch version: {torch.__version__}")
-    training_utils.log_git_info()
+    (device, device_id, disable_tqdm) = training_utils.init_training(args, logger)

     if args.size is None:
         args.size = registry.get_default_size(args.network)

     logger.info(f"Using size={args.size}")

-    if args.cpu is True:
-        device = torch.device("cpu")
-        device_id = 0
-    else:
-        device = torch.device("cuda")
-        device_id = torch.cuda.current_device()
-
-    if args.use_deterministic_algorithms is True:
-        torch.backends.cudnn.benchmark = False
-        torch.use_deterministic_algorithms(True)
-    else:
-        torch.backends.cudnn.benchmark = True
-
-    if args.seed is not None:
-        set_random_seeds(args.seed)
-
-    if args.non_interactive is True or training_utils.is_local_primary(args) is False:
-        disable_tqdm = True
-    elif sys.stderr.isatty() is False:
-        disable_tqdm = True
-    else:
-        disable_tqdm = False
-
-    # Enable or disable the autograd anomaly detection
-    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
-
     #
     # Data
     #
@@ -151,7 +122,7 @@ def train(args: argparse.Namespace) -> None:

     batch_size: int = args.batch_size
     grad_accum_steps: int = args.grad_accum_steps
-    logger.debug(f"Effective batch size = {
+    logger.debug(f"Effective batch size = {batch_size * grad_accum_steps * args.world_size}")

     # Data loaders and samplers
     if args.distributed is True:
@@ -192,6 +163,8 @@ def train(args: argparse.Namespace) -> None:
     else:
         args.stop_epoch += 1

+    logging.debug(f"Epoch has {last_batch_idx+1} iterations ({optimizer_steps_per_epoch} steps)")
+
     #
     # Initialize network
     #
@@ -242,22 +215,25 @@ def train(args: argparse.Namespace) -> None:
     # Loss criteria, optimizer, learning rate scheduler and training parameter groups
     #

+    # Learning rate scaling
+    lr = training_utils.scale_lr(args)
+
     # Training parameter groups
     custom_keys_weight_decay = training_utils.get_wd_custom_keys(args)
     parameters = training_utils.optimizer_parameter_groups(
         net,
         args.wd,
+        base_lr=lr,
         norm_weight_decay=args.norm_wd,
         custom_keys_weight_decay=custom_keys_weight_decay,
+        custom_layer_weight_decay=args.custom_layer_wd,
         layer_decay=args.layer_decay,
         layer_decay_min_scale=args.layer_decay_min_scale,
         layer_decay_no_opt_scale=args.layer_decay_no_opt_scale,
         bias_lr=args.bias_lr,
+        custom_layer_lr_scale=args.custom_layer_lr_scale,
     )

-    # Learning rate scaling
-    lr = training_utils.scale_lr(args)
-
     if args.lr_scheduler_update == "epoch":
         step_update = False
         scheduler_steps_per_epoch = 1
```
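The restored debug line makes the effective batch size explicit, and moving `scale_lr()` above the parameter-group construction lets the scaled value flow in as `base_lr`. A small worked example of the arithmetic; the linear scaling rule shown here is the conventional one and only an assumption about what `training_utils.scale_lr()` computes:

```python
# Effective batch size as logged by the scripts, plus the conventional
# linear LR scaling rule. The reference batch size of 256 is an assumption,
# not necessarily what training_utils.scale_lr() uses.
batch_size = 64
grad_accum_steps = 4
world_size = 2  # number of distributed processes

effective_batch_size = batch_size * grad_accum_steps * world_size
print(effective_batch_size)  # 512

base_lr = 0.1
scaled_lr = base_lr * effective_batch_size / 256
print(scaled_lr)  # 0.2
```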
birder/tools/auto_anchors.py
CHANGED
```diff
@@ -242,6 +242,7 @@ def _validate_args(
     return (size, num_scales, num_anchors, output_format, strides)


+# pylint: disable=too-many-locals
 def auto_anchors(args: argparse.Namespace) -> None:
     (size, num_scales, num_anchors, output_format, strides) = _validate_args(args)

@@ -272,6 +273,7 @@ def auto_anchors(args: argparse.Namespace) -> None:
     logger.info(f"Mean IoU: {best_iou.mean().item():.4f}")

     formatted_groups = _format_anchor_groups(anchor_groups, args.precision)
+    anchors_output = None
     if output_format == "pixels":
         if num_scales == 1:
             formatted_anchors: Any = formatted_groups[0]
@@ -280,6 +282,7 @@ def auto_anchors(args: argparse.Namespace) -> None:

         print("Anchors (pixels):")
         print(pformat(formatted_anchors))
+        anchors_output = formatted_anchors

     if output_format == "grid":
         grid_groups: list[torch.Tensor] = []
@@ -297,6 +300,21 @@ def auto_anchors(args: argparse.Namespace) -> None:

         print("Anchors (grid units):")
         print(pformat(formatted_grid_output))
+        anchors_output = formatted_grid_output
+
+    if args.output is not None:
+        payload = {
+            "anchors": anchors_output,
+            "format": output_format,
+            "size": [size[0], size[1]],
+        }
+        if output_format == "grid":
+            payload["strides"] = strides
+
+        with open(args.output, "w", encoding="utf-8") as handle:
+            json.dump(payload, handle, indent=2)
+
+        logger.info(f"Wrote anchors to {args.output}")


 def set_parser(subparsers: Any) -> None:
@@ -312,7 +330,7 @@ def set_parser(subparsers: Any) -> None:
         "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 --format pixels "
         "--coco-json-path data/detection_data/training_annotations_coco.json\n"
         "python -m birder.tools auto-anchors --preset yolo_v4_tiny --size 416 416 "
-        "--coco-json-path ~/Datasets/cocodataset/annotations/instances_train2017.json\n"
+        "--coco-json-path ~/Datasets/cocodataset/annotations/instances_train2017.json --output anchors.json\n"
         "python -m birder.tools auto-anchors --preset yolo_v2 --stride 32 "
         "--coco-json-path data/detection_data/training_annotations_coco.json\n"
         "python -m birder.tools auto-anchors --size 640 --num-anchors 9 --num-scales 3 "
@@ -354,6 +372,7 @@ def set_parser(subparsers: Any) -> None:
         default=f"{settings.TRAINING_DETECTION_ANNOTATIONS_PATH}_coco.json",
         help="training COCO json path",
     )
+    subparser.add_argument("--output", type=str, help="write anchors as JSON to this path")
     subparser.set_defaults(func=main)

```
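The new `--output` flag serializes the computed anchors. The payload keys (`anchors`, `format`, `size`, and `strides` for the grid format) are taken directly from the diff above; reading the file back is then plain JSON. The value shapes noted in the comments are assumptions:

```python
# Consuming the file written by `auto-anchors ... --output anchors.json`.
import json

with open("anchors.json", "r", encoding="utf-8") as handle:
    payload = json.load(handle)

print(payload["format"])  # "pixels" or "grid"
print(payload["size"])    # the two calibration dimensions
for group in payload["anchors"]:  # presumably one group per scale
    print(group)

if payload["format"] == "grid":
    print(payload["strides"])  # stride list, present only for grid output
```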
birder/tools/pack.py
CHANGED
```diff
@@ -3,6 +3,7 @@ import json
 import logging
 import multiprocessing
 import os
+import queue
 import signal
 import time
 from collections.abc import Callable
@@ -67,39 +68,43 @@ def _save_classes(pack_path: Path, class_to_idx: dict[str, int]) -> None:


 def _encode_image(path: str, file_format: str, size: Optional[int] = None) -> bytes:
-    image: Image.Image
-
-    if
-
-
-
-
-
-            height = round(image.size[1] / ratio)
-            if max(width, height) > MAX_SIZE:
-                if width > height:
-                    ratio = width / MAX_SIZE
+    image: Image.Image
+    with Image.open(path) as image:
+        if file_format.lower() in ("jpeg", "jpg") and image.mode in ("RGBA", "P"):
+            image = image.convert("RGB")
+
+        if size is not None and size < min(image.size):
+            if image.size[0] > image.size[1]:
+                ratio = image.size[1] / size
             else:
-                    ratio =
+                ratio = image.size[0] / size
+
+            width = round(image.size[0] / ratio)
+            height = round(image.size[1] / ratio)
+            if max(width, height) > MAX_SIZE:
+                if width > height:
+                    ratio = width / MAX_SIZE
+                else:
+                    ratio = height / MAX_SIZE

-
-
+                width = round(width / ratio)
+                height = round(height / ratio)

-
+            image = image.resize((width, height), Image.Resampling.BICUBIC)

-
-
-
-
-
+        elif max(image.size) > MAX_SIZE:
+            if image.size[0] > image.size[1]:
+                ratio = image.size[0] / MAX_SIZE
+            else:
+                ratio = image.size[1] / MAX_SIZE

-
-
-
+            width = round(image.size[0] / ratio)
+            height = round(image.size[1] / ratio)
+            image = image.resize((width, height), Image.Resampling.BICUBIC)

-
-
-
+        sample_buffer = BytesIO()
+        image.save(sample_buffer, format=file_format, quality=85)
+        return sample_buffer.getvalue()


 def read_worker(q_in: Any, q_out: Any, error_event: Any, size: Optional[int], file_format: str) -> None:
```
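`_encode_image()` now opens the image in a context manager, converts palette/alpha images before JPEG encoding, and resizes in two steps: scale the short edge down to `size`, then clamp so the long edge does not exceed `MAX_SIZE`. The arithmetic, isolated; the `MAX_SIZE` value below is a placeholder, not birder's actual constant:

```python
# The two-step resize from _encode_image(), as plain arithmetic.
MAX_SIZE = 4096  # placeholder; birder defines its own constant

width, height = 8000, 3000  # the short edge is height
size = 1500

# Step 1: bring the short edge down to `size`
ratio = min(width, height) / size  # 2.0
width, height = round(width / ratio), round(height / ratio)  # (4000, 1500)

# Step 2: clamp the long edge to MAX_SIZE, preserving aspect ratio
if max(width, height) > MAX_SIZE:
    ratio = max(width, height) / MAX_SIZE
    width, height = round(width / ratio), round(height / ratio)

print(width, height)  # (4000, 1500) -- already within MAX_SIZE here
```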
```diff
@@ -162,38 +167,43 @@ wds_write_worker(
     count = 0
     buf = {}
     more = True
-
-
-
-
-
-
+    try:
+        with tqdm(total=total, initial=0, unit="images", unit_scale=True, leave=False) as progress:
+            while more:
+                deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
+                if deq is not None:
+                    (idx, sample, suffix, target) = deq
+                    buf[idx] = (sample, suffix, target)

-
-
+                else:
+                    more = False

-
-
-
-
+                # Ensures ordered write
+                while count in buf:
+                    (sample, suffix, target) = buf[count]
+                    del buf[count]

-
-
-
-
+                    if args.no_cls is True:
+                        cls = {}
+                    else:
+                        cls = {"cls": target}

-
-
-
-
-
-
+                    sink.write(
+                        {
+                            "__key__": f"sample{count:09d}",
+                            suffix: sample,
+                            **cls,
+                        }
+                    )

-
+                    count += 1

-
-
+                    # Update progress bar
+                    progress.update(n=1)
+
+    except Exception:
+        error_event.set()
+        raise

     sink.close()

@@ -218,35 +228,42 @@ wds_write_worker(


 def directory_write_worker(
-    q_out: Any,
+    q_out: Any, error_event: Any, pack_path: Path, total: int, _: argparse.Namespace, idx_to_class: dict[int, str]
 ) -> None:
     count = 0
     buf = {}
     more = True
-
-
-
-
-
-
+    try:
+        with tqdm(total=total, initial=0, unit="images", unit_scale=True, leave=False) as progress:
+            while more:
+                deq: Optional[tuple[int, bytes, str, int]] = q_out.get()
+                if deq is not None:
+                    (idx, sample, suffix, target) = deq
+                    buf[idx] = (sample, suffix, target)

-
-
+                else:
+                    more = False
+
+                # Ensures ordered write
+                while count in buf:
+                    (sample, suffix, target) = buf[count]
+                    del buf[count]
+                    with open(
+                        pack_path.joinpath(idx_to_class[target]).joinpath(f"{count:06d}.{suffix}"), "wb"
+                    ) as handle:
+                        handle.write(sample)

-
-        while count in buf:
-            (sample, suffix, target) = buf[count]
-            del buf[count]
-            with open(pack_path.joinpath(idx_to_class[target]).joinpath(f"{count:06d}.{suffix}"), "wb") as handle:
-                handle.write(sample)
+                    count += 1

-
+                    # Update progress bar
+                    progress.update(n=1)

-
-
+    except Exception:
+        error_event.set()
+        raise


-# pylint: disable=too-many-locals,too-many-branches
+# pylint: disable=too-many-locals,too-many-branches,too-many-statements
 def pack(args: argparse.Namespace, pack_path: Path) -> None:
     if args.sampling_file is not None:
         with open(args.sampling_file, "r", encoding="utf-8") as handle:
```
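Both write workers above share the same "ensures ordered write" pattern: results from the parallel readers arrive out of order, get buffered by index, and are flushed strictly in sequence. A standalone distillation of that pattern:

```python
# Minimal distillation of the ordered-write buffering used by both workers.
from collections.abc import Iterable, Iterator


def ordered_flush(results: Iterable[tuple[int, str]]) -> Iterator[str]:
    buf: dict[int, str] = {}
    count = 0
    for idx, payload in results:
        buf[idx] = payload
        while count in buf:  # flush the next contiguous run of indices
            yield buf.pop(count)
            count += 1


assert list(ordered_flush([(2, "c"), (0, "a"), (1, "b")])) == ["a", "b", "c"]
```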
```diff
@@ -308,9 +325,7 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
     read_processes: list[multiprocessing.Process] = []
     for idx in range(args.jobs):
         read_processes.append(
-            multiprocessing.Process(
-                target=read_worker, args=(q_in[idx], q_out, error_event, args.size, args.format), daemon=True
-            )
+            multiprocessing.Process(target=read_worker, args=(q_in[idx], q_out, error_event, args.size, args.format))
         )

     for p in read_processes:
```
```diff
@@ -326,53 +341,107 @@ def pack(args: argparse.Namespace, pack_path: Path) -> None:
         raise ValueError("Unknown pack type")

     write_process = multiprocessing.Process(
-            target=target_writer, args=(q_out, error_event, pack_path, len(dataset), args, idx_to_class)
+        target=target_writer, args=(q_out, error_event, pack_path, len(dataset), args, idx_to_class)
     )
     write_process.start()

+    # Flag to prevent signal handler re-entry
+    cleanup_in_progress = False
+
+    def cleanup_processes() -> None:
+        nonlocal cleanup_in_progress
+        if cleanup_in_progress is True:
+            return
+
+        cleanup_in_progress = True
+
+        # Cancel queue join threads to prevent blocking during cleanup
+        for q in q_in:
+            q.cancel_join_thread()
+
+        q_out.cancel_join_thread()
+
+        # Terminate child processes
+        for p in read_processes:
+            if p.is_alive():
+                p.terminate()
+
+        if write_process.is_alive():
+            write_process.terminate()
+
+        # Wait briefly for termination
+        for p in read_processes:
+            p.join(timeout=1)
+
+        write_process.join(timeout=1)
+
     def signal_handler(signum, _frame) -> None:  # type: ignore
         logger.info(f"Received signal: {signum} at {multiprocessing.current_process().name}, aborting...")
         error_event.set()
+        cleanup_processes()
         raise SystemExit(1)

     signal.signal(signal.SIGINT, signal_handler)

-
-
-
-    if
-
-
-
-
-
-
-
-
-
+    try:
+        tic = time.time()
+        for idx, sample_idx in enumerate(indices):
+            if idx % 1000 == 0:
+                if error_event.is_set() is True:
+                    cleanup_processes()
+                    raise RuntimeError()
+
+            (path, target) = dataset[sample_idx]
+
+            while True:
+                try:
+                    q_in[idx % len(q_in)].put((idx, path, target), block=True, timeout=1)
+                    break
+                except queue.Full:
+                    if error_event.is_set() is True:
+                        cleanup_processes()
+                        raise RuntimeError()  # pylint: disable=raise-missing-from
+
+        for q in q_in:
+            q.put(None, block=True, timeout=None)
+
+        for p in read_processes:
+            while True:
+                p.join(timeout=2)
+                if p.is_alive() is False:
+                    break
+
+                if error_event.is_set() is True:
+                    cleanup_processes()
+                    raise RuntimeError()
+
+        q_out.put(None, block=True, timeout=None)
         while True:
-
-            if
+            write_process.join(timeout=2)
+            if write_process.is_alive() is False:
                 break

             if error_event.is_set() is True:
+                cleanup_processes()
                 raise RuntimeError()

-
-
+        if error_event.is_set() is True:
+            cleanup_processes()
+            raise RuntimeError()

-
-
+        if args.type == "wds":
+            (wds_path, num_shards) = fs_ops.wds_braces_from_path(pack_path, prefix=f"{args.suffix}-{args.split}")
+            logger.info(f"Packed {len(dataset):,} samples into {num_shards} shards at {wds_path}")
+        elif args.type == "directory":
+            logger.info(f"Packed {len(dataset):,} samples")

-
-        (
-        logger.info(f"
-        elif args.type == "directory":
-            logger.info(f"Packed {len(dataset):,} samples")
+        toc = time.time()
+        rate = len(dataset) / (toc - tic)
+        logger.info(f"{format_duration(toc-tic)} to pack {len(dataset):,} samples ({rate:.2f} samples/sec)")

-
-
-
+    except Exception:
+        cleanup_processes()
+        raise


 def set_parser(subparsers: Any) -> None:
```
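The main feeder loop in `pack()` now uses a bounded put with a one-second timeout so it can notice `error_event` instead of blocking forever on a full queue, with `cleanup_processes()` tearing everything down on any failure path. The shape of that backpressure loop, reduced to its essentials; the names here are illustrative stand-ins for birder's `q_in`/`error_event`:

```python
# Sketch of the bounded-put backpressure loop from pack(); `work_queue` and
# `abort` stand in for birder's q_in / error_event.
import multiprocessing
import queue

abort = multiprocessing.Event()
work_queue = multiprocessing.Queue(maxsize=8)


def feed(items: list) -> None:
    for idx, item in enumerate(items):
        while True:
            try:
                work_queue.put((idx, item), block=True, timeout=1)
                break
            except queue.Full:
                if abort.is_set():  # a worker failed; stop feeding
                    raise RuntimeError("aborted") from None
```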
birder/tools/show_det_iterator.py
CHANGED
```diff
@@ -43,6 +43,7 @@ def show_det_iterator(args: argparse.Namespace) -> None:
         args.dynamic_size,
         args.multiscale,
         args.max_size,
+        args.multiscale_min_size,
     )
     mosaic_transforms = training_preset(
         args.size,
@@ -52,6 +53,7 @@ def show_det_iterator(args: argparse.Namespace) -> None:
         args.dynamic_size,
         args.multiscale,
         args.max_size,
+        args.multiscale_min_size,
         post_mosaic=True,
     )
     if args.mosaic_prob > 0.0:
@@ -160,7 +162,9 @@ def show_det_iterator(args: argparse.Namespace) -> None:

     else:
         if args.batch_multiscale is True:
-            data_collate_fn: Any = BatchRandomResizeCollator(
+            data_collate_fn: Any = BatchRandomResizeCollator(
+                offset, args.size, multiscale_min_size=args.multiscale_min_size
+            )
         else:
             data_collate_fn = collate_fn

@@ -259,6 +263,11 @@ def set_parser(subparsers: Any) -> None:
         help="allow variable image sizes while preserving aspect ratios",
     )
     subparser.add_argument("--multiscale", default=False, action="store_true", help="enable random scale per image")
+    subparser.add_argument(
+        "--multiscale-min-size",
+        type=int,
+        help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
+    )
     subparser.add_argument(
         "--batch-multiscale",
         default=False,
```
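The new `--multiscale-min-size` flag bounds the smallest entry in the multiscale candidate lists, with the help text noting the value is rounded up to the nearest multiple of 32. A hypothetical illustration of what such a list could look like; birder's actual list construction is not shown in this diff:

```python
# Hypothetical multiscale size list honoring a minimum short-edge size,
# rounded up to the nearest multiple of 32 (per the flag's help text).
def multiscale_sizes(base_size: int, min_size: int, step: int = 32) -> list[int]:
    min_size = ((min_size + step - 1) // step) * step  # round up to a multiple of step
    return list(range(min_size, base_size + 1, step))


print(multiscale_sizes(640, 300))  # [320, 352, 384, ..., 640]
```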
birder/version.py
CHANGED
```diff
@@ -1 +1 @@
-__version__ = "v0.2.2"
+__version__ = "v0.2.3"
```