megadetector 5.0.6-py3-none-any.whl → 5.0.7-py3-none-any.whl
This diff shows the changes between these two publicly released package versions, as they appear in the public registry, and is provided for informational purposes only.
- api/batch_processing/data_preparation/manage_local_batch.py +278 -197
- api/batch_processing/data_preparation/manage_video_batch.py +7 -2
- api/batch_processing/postprocessing/add_max_conf.py +1 -0
- api/batch_processing/postprocessing/compare_batch_results.py +110 -60
- api/batch_processing/postprocessing/load_api_results.py +55 -69
- api/batch_processing/postprocessing/md_to_labelme.py +1 -0
- api/batch_processing/postprocessing/postprocess_batch_results.py +158 -50
- api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
- api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
- api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
- api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +222 -74
- api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
- api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
- classification/prepare_classification_script.py +191 -191
- data_management/coco_to_yolo.py +65 -44
- data_management/databases/integrity_check_json_db.py +7 -5
- data_management/generate_crops_from_cct.py +1 -1
- data_management/importers/animl_results_to_md_results.py +2 -2
- data_management/importers/noaa_seals_2019.py +1 -1
- data_management/importers/zamba_results_to_md_results.py +2 -2
- data_management/labelme_to_coco.py +34 -6
- data_management/labelme_to_yolo.py +1 -1
- data_management/lila/create_lila_blank_set.py +474 -0
- data_management/lila/create_lila_test_set.py +2 -1
- data_management/lila/create_links_to_md_results_files.py +1 -1
- data_management/lila/download_lila_subset.py +46 -21
- data_management/lila/generate_lila_per_image_labels.py +23 -14
- data_management/lila/get_lila_annotation_counts.py +16 -10
- data_management/lila/lila_common.py +14 -11
- data_management/lila/test_lila_metadata_urls.py +116 -0
- data_management/resize_coco_dataset.py +12 -10
- data_management/yolo_output_to_md_output.py +40 -13
- data_management/yolo_to_coco.py +34 -21
- detection/process_video.py +36 -14
- detection/pytorch_detector.py +1 -1
- detection/run_detector.py +73 -18
- detection/run_detector_batch.py +104 -24
- detection/run_inference_with_yolov5_val.py +127 -26
- detection/run_tiled_inference.py +153 -43
- detection/video_utils.py +3 -1
- md_utils/ct_utils.py +79 -3
- md_utils/md_tests.py +253 -15
- md_utils/path_utils.py +129 -24
- md_utils/process_utils.py +26 -7
- md_utils/split_locations_into_train_val.py +215 -0
- md_utils/string_utils.py +10 -0
- md_utils/url_utils.py +0 -2
- md_utils/write_html_image_list.py +1 -0
- md_visualization/visualization_utils.py +17 -2
- md_visualization/visualize_db.py +8 -0
- md_visualization/visualize_detector_output.py +185 -104
- {megadetector-5.0.6.dist-info → megadetector-5.0.7.dist-info}/METADATA +2 -2
- {megadetector-5.0.6.dist-info → megadetector-5.0.7.dist-info}/RECORD +62 -58
- {megadetector-5.0.6.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
- taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
- taxonomy_mapping/map_new_lila_datasets.py +43 -39
- taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
- taxonomy_mapping/preview_lila_taxonomy.py +27 -27
- taxonomy_mapping/species_lookup.py +33 -13
- taxonomy_mapping/taxonomy_csv_checker.py +7 -5
- {megadetector-5.0.6.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
- {megadetector-5.0.6.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
detection/process_video.py
CHANGED

@@ -26,12 +26,17 @@ from detection.video_utils import frame_results_to_video_results
 from detection.video_utils import video_folder_to_frames
 from uuid import uuid1
 
+from detection.video_utils import default_fourcc
+
 
 #%% Options classes
 
 class ProcessVideoOptions:
 
-    model_file = ''
+    # Can be a model filename (.pt or .pb) or a model name (e.g. "MDV5A")
+    model_file = 'MDV5A'
+
+    # Can be a file or a folder
     input_video_file = ''
 
     output_json_file = None
@@ -72,9 +77,10 @@ class ProcessVideoOptions:
 
     recursive = False
     verbose = False
+
     fourcc = None
 
-    rendering_confidence_threshold =
+    rendering_confidence_threshold = None
    json_confidence_threshold = 0.005
     frame_sample = None
 

@@ -175,8 +181,14 @@ def process_video(options):
         confidence_threshold=options.rendering_confidence_threshold)
 
     # Combine into a video
-
-
+    if options.frame_sample is None:
+        rendering_fs = Fs
+    else:
+        rendering_fs = Fs / options.frame_sample
+
+    print('Rendering video to {} at {} fps (original video {} fps)'.format(
+        options.output_video_file,rendering_fs,Fs))
+    frames_to_video(detected_frame_files, rendering_fs, options.output_video_file, codec_spec=options.fourcc)
 
     # Delete the temporary directory we used for detection images
     if not options.keep_rendered_frames:
@@ -344,11 +356,19 @@ def process_video_folder(options):
         output_video_folder = options.input_video_file
 
     # For each video
+    #
+    # TODO: parallelize this loop
+    #
     # i_video=0; input_video_file_abs = video_filenames[i_video]
     for i_video,input_video_file_abs in enumerate(video_filenames):
 
         video_fs = Fs[i_video]
 
+        if options.frame_sample is None:
+            rendering_fs = video_fs
+        else:
+            rendering_fs = video_fs / options.frame_sample
+
         input_video_file_relative = os.path.relpath(input_video_file_abs,options.input_video_file)
         video_frame_output_folder = os.path.join(frame_rendering_output_dir,input_video_file_relative)
         assert os.path.isdir(video_frame_output_folder), \

@@ -371,11 +391,10 @@ def process_video_folder(options):
         os.makedirs(os.path.dirname(video_output_file),exist_ok=True)
 
         # Create the output video
-        print('Rendering detections for video {} to {} at {} fps'.format(
-
-        frames_to_video(video_frame_files,
-
-
+        print('Rendering detections for video {} to {} at {} fps (original video {} fps)'.format(
+            input_video_file_relative,video_output_file,rendering_fs,video_fs))
+        frames_to_video(video_frame_files, rendering_fs, video_output_file, codec_spec=options.fourcc)
+
     # ...for each video
 
     # Possibly clean up rendered frames
@@ -525,12 +544,14 @@ if False:
 
 def main():
 
+    default_options = ProcessVideoOptions()
+
     parser = argparse.ArgumentParser(description=(
         'Run MegaDetector on each frame in a video (or every Nth frame), optionally '\
         'producing a new video with detections annotated'))
 
     parser.add_argument('model_file', type=str,
-        help='MegaDetector model file')
+        help='MegaDetector model file (.pt or .pb) or model name (e.g. "MDV5A")')
 
     parser.add_argument('input_video_file', type=str,
         help='video file (or folder) to process')

@@ -567,8 +588,8 @@ def main():
     parser.add_argument('--render_output_video', action='store_true',
         help='enable video output rendering (not rendered by default)')
 
-    parser.add_argument('--fourcc', default=
-        help='fourcc code to use for video encoding, only used if render_output_video is True')
+    parser.add_argument('--fourcc', default=default_fourcc,
+        help='fourcc code to use for video encoding (default {}), only used if render_output_video is True'.format(default_fourcc))
 
     parser.add_argument('--keep_rendered_frames',
         action='store_true', help='Disable the deletion of rendered (w/boxes) frames')

@@ -586,11 +607,12 @@ def main():
         'whether other files were present in the folder.')
 
     parser.add_argument('--rendering_confidence_threshold', type=float,
-        default=
+        default=None, help="don't render boxes with confidence below this threshold (defaults to choosing based on the MD version)")
 
     parser.add_argument('--json_confidence_threshold', type=float,
         default=0.0, help="don't include boxes in the .json file with confidence "\
-            'below this threshold'
+            'below this threshold (default {})'.format(
+                default_options.json_confidence_threshold))
 
     parser.add_argument('--n_cores', type=int,
         default=1, help='number of cores to use for frame separation and detection. '\
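A note on the frame-rate arithmetic in these hunks: when frame_sample is N, only every Nth frame is extracted and annotated, so rendering the surviving frames at the original rate would yield an N-times sped-up video; rendering at Fs / N keeps playback at real-time speed. A minimal standalone sketch of that rule (hypothetical helper, not part of the module):

# Sketch of the rendering-rate rule above: if only every Nth frame was
# sampled, render at (original fps / N) so one second of output still
# covers one second of input. Illustrative helper, not module code.
def compute_rendering_fps(original_fps, frame_sample=None):
    if frame_sample is None:
        return original_fps
    return original_fps / frame_sample

assert compute_rendering_fps(30.0) == 30.0
assert compute_rendering_fps(30.0, frame_sample=10) == 3.0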
detection/pytorch_detector.py
CHANGED

@@ -234,7 +234,7 @@ class PTDetector:
         if self.device == 'mps':
             # As of v1.13.0.dev20220824, nms is not implemented for MPS.
             #
-            # Send
+            # Send prediction back to the CPU to fix.
             pred = non_max_suppression(prediction=pred.cpu(), conf_thres=detection_threshold)
         else:
             pred = non_max_suppression(prediction=pred, conf_thres=detection_threshold)
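The reworded comment documents a general workaround: when an op has no MPS kernel (as NMS did in the PyTorch build cited above), move the tensors to the CPU for that step. A self-contained sketch of the same pattern, using torchvision's nms rather than the YOLOv5 non_max_suppression helper the module actually calls:

# Sketch of the MPS workaround pattern: run NMS on the CPU when boxes
# live on an MPS device, then move the kept indices back. torchvision's
# nms stands in for YOLOv5's non_max_suppression here.
import torch
from torchvision.ops import nms

def nms_any_device(boxes, scores, iou_threshold=0.45):
    if boxes.device.type == 'mps':
        keep = nms(boxes.cpu(), scores.cpu(), iou_threshold)
        return keep.to(boxes.device)
    return nms(boxes, scores, iou_threshold)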
detection/run_detector.py
CHANGED

@@ -10,12 +10,7 @@
 # This script is not a good way to process lots of images (tens of thousands,
 # say). It does not facilitate checkpointing the results so if it crashes you
 # would have to start from scratch. If you want to run a detector (e.g., ours)
-# on lots of images, you should check out
-#
-# 1) run_detector_batch.py (for local execution)
-#
-# 2) https://github.com/agentmorris/MegaDetector/tree/master/api/batch_processing
-#    (for running large jobs on Azure ML)
+# on lots of images, you should check out run_detector_batch.py.
 #
 # To run this script, we recommend you set up a conda virtual environment
 # following instructions in the Installation section on the main README, using

@@ -136,6 +131,33 @@ downloadable_models = {
     'MDV5B':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.0.pt'
 }
 
+model_string_to_model_version = {
+    'v2':'v2.0.0',
+    'v3':'v3.0.0',
+    'v4.1':'v4.1.0',
+    'v5a.0.0':'v5a.0.0',
+    'v5b.0.0':'v5b.0.0',
+    'mdv5a':'v5a.0.0',
+    'mdv5b':'v5b.0.0',
+    'mdv4':'v4.1.0',
+    'mdv3':'v3.0.0'
+}
+
+# Approximate inference speeds (in images per second) for MDv5 based on
+# benchmarks, only used for reporting very coarse expectations about inference time.
+device_token_to_mdv5_inference_speed = {
+    '4090':17.6,
+    '3090':11.4,
+    '3080':9.5,
+    '3050':4.2,
+    'P2000':2.1,
+    # These are written this way because they're MDv4 benchmarks, and MDv5
+    # is around 3.5x faster than MDv4.
+    'V100':2.79*3.5,
+    '2080':2.3*3.5,
+    '2060':1.6*3.5
+}
+
 
 #%% Utility functions
 
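Hoisting this table to module scope lets get_detector_version_from_filename() and the new estimate_md_images_per_second() below share one source of truth; the keys are now lowercase, which is why the filename lookup also lowercases its input. A self-contained sketch of the substring lookup (the dict is copied from the hunk; the helper is a simplified stand-in):

# Illustration of the table-driven version lookup: lowercase the
# filename, then scan for any known version token. Simplified stand-in
# for get_detector_version_from_filename().
import os

model_string_to_model_version = {
    'v2':'v2.0.0', 'v3':'v3.0.0', 'v4.1':'v4.1.0',
    'v5a.0.0':'v5a.0.0', 'v5b.0.0':'v5b.0.0',
    'mdv5a':'v5a.0.0', 'mdv5b':'v5b.0.0',
    'mdv4':'v4.1.0', 'mdv3':'v3.0.0'
}

def version_from_filename(detector_filename):
    fn = os.path.basename(detector_filename).lower()
    matches = [s for s in model_string_to_model_version if s in fn]
    if len(matches) == 0:
        return None
    if len(matches) > 1:
        return 'multiple'
    return model_string_to_model_version[matches[0]]

assert version_from_filename('/models/md_v5a.0.0.pt') == 'v5a.0.0'
assert version_from_filename('MDV5B') == 'v5b.0.0'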
@@ -190,18 +212,9 @@ def get_detector_version_from_filename(detector_filename):
     "v4.1.0", "v5a.0.0", and "v5b.0.0", respectively.
     """
 
-    fn = os.path.basename(detector_filename)
-    known_model_versions = {'v2':'v2.0.0',
-                            'v3':'v3.0.0',
-                            'v4.1':'v4.1.0',
-                            'v5a.0.0':'v5a.0.0',
-                            'v5b.0.0':'v5b.0.0',
-                            'MDV5A':'v5a.0.0',
-                            'MDV5B':'v5b.0.0',
-                            'MDV4':'v4.1.0',
-                            'MDV3':'v3.0.0'}
+    fn = os.path.basename(detector_filename).lower()
     matches = []
-    for s in
+    for s in model_string_to_model_version.keys():
         if s in fn:
             matches.append(s)
     if len(matches) == 0:

@@ -211,9 +224,51 @@ def get_detector_version_from_filename(detector_filename):
         print('Warning: multiple MegaDetector versions for model file {}'.format(detector_filename))
         return 'multiple'
     else:
-        return
+        return model_string_to_model_version[matches[0]]
 
 
+def estimate_md_images_per_second(model_file, device_name=None):
+    """
+    Estimate how fast MegaDetector will run based on benchmarks. Defaults to querying
+    the current device. Returns None if no data is available for the current card/model.
+    Estimates only available for a small handful of GPUs.
+    """
+
+    if device_name is None:
+        try:
+            import torch
+            device_name = torch.cuda.get_device_name()
+        except Exception as e:
+            print('Error querying device name: {}'.format(e))
+            return None
+
+    model_file = model_file.lower().strip()
+    if model_file in model_string_to_model_version.values():
+        model_version = model_file
+    else:
+        model_version = get_detector_version_from_filename(model_file)
+        if model_version not in model_string_to_model_version.values():
+            print('Error determining model version for model file {}'.format(model_file))
+            return None
+
+    mdv5_inference_speed = None
+    for device_token in device_token_to_mdv5_inference_speed.keys():
+        if device_token in device_name:
+            mdv5_inference_speed = device_token_to_mdv5_inference_speed[device_token]
+            break
+
+    if mdv5_inference_speed is None:
+        print('No speed estimate available for {}'.format(device_name))
+
+    if 'v5' in model_version:
+        return mdv5_inference_speed
+    elif 'v2' in model_version or 'v3' in model_version or 'v4' in model_version:
+        return mdv5_inference_speed / 3.5
+    else:
+        print('Could not estimate inference speed for model file {}'.format(model_file))
+        return None
+
+
 def get_typical_confidence_threshold_from_results(results):
     """
     Given the .json data loaded from a MD results file, determine a typical confidence
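A hedged usage sketch for the new estimator (passing device_name explicitly skips the torch.cuda query; the device match is a substring test, so 'NVIDIA GeForce RTX 3090' matches the '3090' token; the import path assumes the package layout shown in this diff):

# Hypothetical usage of estimate_md_images_per_second; it returns
# images/second, or None when it has no benchmark for the device or
# can't determine the model version.
from detection.run_detector import estimate_md_images_per_second

ips = estimate_md_images_per_second('md_v5a.0.0.pt',
                                    device_name='NVIDIA GeForce RTX 3090')
if ips is not None:
    n_images = 10000
    print('~{:.0f} minutes for {} images'.format(n_images / ips / 60, n_images))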
detection/run_detector_batch.py
CHANGED

@@ -751,17 +751,75 @@ if False:
 
     #%%
 
+    model_file = 'MDV5A'
+    image_dir = r'g:\camera_traps\camera_trap_images'
+    output_file = r'g:\temp\md-test.json'
+
+    recursive = True
+    output_relative_filenames = True
+    include_max_conf = False
+    quiet = True
+    image_size = None
+    use_image_queue = False
+    confidence_threshold = 0.0001
+    checkpoint_frequency = 5
     checkpoint_path = None
-
-
-    checkpoint_frequency = -1
-    results = None
+    resume_from_checkpoint = 'auto'
+    allow_checkpoint_overwrite = False
     ncores = 1
-
-
-
-
+    class_mapping_filename = None
+    include_image_size = True
+    include_image_timestamp = True
+    include_exif_data = True
+    overwrite_handling = None
+
+    # Generate a command line
+    cmd = 'python run_detector_batch.py "{}" "{}" "{}"'.format(
+        model_file,image_dir,output_file)
+
+    if recursive:
+        cmd += ' --recursive'
+    if output_relative_filenames:
+        cmd += ' --output_relative_filenames'
+    if include_max_conf:
+        cmd += ' --include_max_conf'
+    if quiet:
+        cmd += ' --quiet'
+    if image_size is not None:
+        cmd += ' --image_size {}'.format(image_size)
+    if use_image_queue:
+        cmd += ' --use_image_queue'
+    if confidence_threshold is not None:
+        cmd += ' --threshold {}'.format(confidence_threshold)
+    if checkpoint_frequency is not None:
+        cmd += ' --checkpoint_frequency {}'.format(checkpoint_frequency)
+    if checkpoint_path is not None:
+        cmd += ' --checkpoint_path "{}"'.format(checkpoint_path)
+    if resume_from_checkpoint is not None:
+        cmd += ' --resume_from_checkpoint "{}"'.format(resume_from_checkpoint)
+    if allow_checkpoint_overwrite:
+        cmd += ' --allow_checkpoint_overwrite'
+    if ncores is not None:
+        cmd += ' --ncores {}'.format(ncores)
+    if class_mapping_filename is not None:
+        cmd += ' --class_mapping_filename "{}"'.format(class_mapping_filename)
+    if include_image_size:
+        cmd += ' --include_image_size'
+    if include_image_timestamp:
+        cmd += ' --include_image_timestamp'
+    if include_exif_data:
+        cmd += ' --include_exif_data'
+    if overwrite_handling is not None:
+        cmd += ' --overwrite_handling {}'.format(overwrite_handling)
+
+    print(cmd)
+    import clipboard; clipboard.copy(cmd)
+
+
+    #%% Run inference interactively
+
     image_file_names = path_utils.find_images(image_dir, recursive=False)
+    results = None
 
     start_time = time.time()
 
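Tracing the string-building above with the defaults shown in this cell, the generated command comes out as (one line):

python run_detector_batch.py "MDV5A" "g:\camera_traps\camera_trap_images" "g:\temp\md-test.json" --recursive --output_relative_filenames --quiet --threshold 0.0001 --checkpoint_frequency 5 --resume_from_checkpoint "auto" --ncores 1 --include_image_size --include_image_timestamp --include_exif_data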
@@ -840,12 +898,15 @@ def main():
         '--checkpoint_path',
         type=str,
         default=None,
-        help='File name to which checkpoints will be written if checkpoint_frequency is > 0'
+        help='File name to which checkpoints will be written if checkpoint_frequency is > 0, ' + \
+            'defaults to md_checkpoint_[date].json in the same folder as the output file')
     parser.add_argument(
         '--resume_from_checkpoint',
         type=str,
         default=None,
-        help='Path to a JSON checkpoint file to resume from'
+        help='Path to a JSON checkpoint file to resume from, or "auto" to ' + \
+            'find the most recent checkpoint in the same folder as the output file. "auto" uses' + \
+            'checkpoint_path (rather than searching the output folder) if checkpoint_path is specified.')
     parser.add_argument(
         '--allow_checkpoint_overwrite',
         action='store_true',

@@ -897,7 +958,7 @@ def main():
 
     assert os.path.exists(args.detector_file), \
         'detector file {} does not exist'.format(args.detector_file)
-    assert 0.0
+    assert 0.0 <= args.threshold <= 1.0, 'Confidence threshold needs to be between 0 and 1'
     assert args.output_file.endswith('.json'), 'output_file specified needs to end with .json'
     if args.checkpoint_frequency != -1:
         assert args.checkpoint_frequency > 0, 'Checkpoint_frequency needs to be > 0 or == -1'
@@ -919,19 +980,42 @@ def main():
     else:
         raise ValueError('Illegal overwrite handling string {}'.format(args.overwrite_handling))
 
+    output_dir = os.path.dirname(args.output_file)
+
+    if len(output_dir) > 0:
+        os.makedirs(output_dir,exist_ok=True)
+
+    assert not os.path.isdir(args.output_file), 'Specified output file is a directory'
+
     if args.class_mapping_filename is not None:
         load_custom_class_mapping(args.class_mapping_filename)
-
+
     # Load the checkpoint if available
     #
     # Relative file names are only output at the end; all file paths in the checkpoint are
-    # still
+    # still absolute paths.
     if args.resume_from_checkpoint is not None:
-        assert os.path.exists(args.resume_from_checkpoint), \
+        if args.resume_from_checkpoint == 'auto':
+            checkpoint_files = os.listdir(output_dir)
+            checkpoint_files = [fn for fn in checkpoint_files if \
+                (fn.startswith('md_checkpoint') and fn.endswith('.json'))]
+            if len(checkpoint_files) == 0:
+                raise ValueError('resume_from_checkpoint set to "auto", but no checkpoints found in {}'.format(
+                    output_dir))
+            else:
+                if len(checkpoint_files) > 1:
+                    print('Warning: found {} checkpoints in {}, using the latest'.format(
+                        len(checkpoint_files),output_dir))
+                checkpoint_files = sorted(checkpoint_files)
+                checkpoint_file_relative = checkpoint_files[-1]
+                checkpoint_file = os.path.join(output_dir,checkpoint_file_relative)
+        else:
+            checkpoint_file = args.resume_from_checkpoint
+        assert os.path.exists(checkpoint_file), \
             'File at resume_from_checkpoint specified does not exist'
-        with open(
+        with open(checkpoint_file) as f:
             print('Loading previous results from checkpoint file {}'.format(
-                args.resume_from_checkpoint))
+                checkpoint_file))
             saved = json.load(f)
             assert 'images' in saved, \
                 'The checkpoint file does not have the correct fields; cannot be restored'

@@ -982,13 +1066,6 @@ def main():
     assert os.path.exists(image_file_names[0]), \
         'The first image to be processed does not exist at {}'.format(image_file_names[0])
 
-    output_dir = os.path.dirname(args.output_file)
-
-    if len(output_dir) > 0:
-        os.makedirs(output_dir,exist_ok=True)
-
-    assert not os.path.isdir(args.output_file), 'Specified output file is a directory'
-
     # Test that we can write to the output_file's dir if checkpointing requested
     if args.checkpoint_frequency != -1:
 

@@ -996,7 +1073,7 @@ def main():
             checkpoint_path = args.checkpoint_path
         else:
             checkpoint_path = os.path.join(output_dir,
-                '
+                'md_checkpoint_{}.json'.format(
                 datetime.utcnow().strftime("%Y%m%d%H%M%S")))
 
     # Don't overwrite existing checkpoint files, this is a sure-fire way to eventually

@@ -1023,6 +1100,9 @@ def main():
 
     else:
 
+        if args.checkpoint_path is not None:
+            print('Warning: checkpointing disabled because checkpoint_frequency is -1, ' + \
+                'but a checkpoint path was specified')
         checkpoint_path = None
 
     start_time = time.time()
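The "auto" resume logic works because of the default checkpoint naming added in the same release: checkpoints are written as md_checkpoint_[timestamp].json with a UTC %Y%m%d%H%M%S timestamp, which is zero-padded and most-significant-digit-first, so the sorted(...)[-1] selection of the lexicographically last filename is also the chronologically newest. A minimal sketch with made-up file names:

# Lexicographic order equals chronological order for this timestamp
# format. File names below are hypothetical.
checkpoint_files = [
    'md_checkpoint_20240101120000.json',
    'md_checkpoint_20240315093011.json',
    'md_checkpoint_20231224180500.json',
]
latest = sorted(checkpoint_files)[-1]
assert latest == 'md_checkpoint_20240315093011.json'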