megadetector 10.0.2__py3-none-any.whl → 10.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (30) hide show
  1. megadetector/data_management/animl_to_md.py +158 -0
  2. megadetector/data_management/zamba_to_md.py +188 -0
  3. megadetector/detection/process_video.py +165 -946
  4. megadetector/detection/pytorch_detector.py +575 -276
  5. megadetector/detection/run_detector_batch.py +629 -202
  6. megadetector/detection/run_md_and_speciesnet.py +1319 -0
  7. megadetector/detection/video_utils.py +243 -107
  8. megadetector/postprocessing/classification_postprocessing.py +12 -1
  9. megadetector/postprocessing/combine_batch_outputs.py +2 -0
  10. megadetector/postprocessing/compare_batch_results.py +21 -2
  11. megadetector/postprocessing/merge_detections.py +16 -12
  12. megadetector/postprocessing/separate_detections_into_folders.py +1 -1
  13. megadetector/postprocessing/subset_json_detector_output.py +1 -3
  14. megadetector/postprocessing/validate_batch_results.py +25 -2
  15. megadetector/tests/__init__.py +0 -0
  16. megadetector/tests/test_nms_synthetic.py +335 -0
  17. megadetector/utils/ct_utils.py +69 -5
  18. megadetector/utils/extract_frames_from_video.py +303 -0
  19. megadetector/utils/md_tests.py +583 -524
  20. megadetector/utils/path_utils.py +4 -15
  21. megadetector/utils/wi_utils.py +20 -4
  22. megadetector/visualization/visualization_utils.py +1 -1
  23. megadetector/visualization/visualize_db.py +8 -22
  24. megadetector/visualization/visualize_detector_output.py +7 -5
  25. megadetector/visualization/visualize_video_output.py +607 -0
  26. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/METADATA +134 -135
  27. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/RECORD +30 -23
  28. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/top_level.txt +0 -0
  30. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/WHEEL +0 -0
@@ -5,7 +5,7 @@ run_detector_batch.py
5
5
  Module to run MegaDetector on lots of images, writing the results
6
6
  to a file in the MegaDetector results format.
7
7
 
8
- https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#megadetector-batch-output-format
8
+ https://lila.science/megadetector-output-format
9
9
 
10
10
  This enables the results to be used in our post-processing pipeline; see postprocess_batch_results.py.
11
11
 
@@ -23,7 +23,7 @@ is not supported when using a GPU.
23
23
 
24
24
  The lack of GPU multiprocessing support might sound annoying, but in practice we
25
25
  run a gazillion MegaDetector images on multiple GPUs using this script, we just only use
26
- one GPU *per invocation of this script*. Dividing a big batch of images into one chunk
26
+ one GPU *per invocation of this script*. Dividing a list of images into one chunk
27
27
  per GPU happens outside of this script.
28
28
 
29
29
  Does not have a command-line option to bind the process to a particular GPU, but you can
@@ -44,6 +44,7 @@ import sys
44
44
  import time
45
45
  import copy
46
46
  import shutil
47
+ import random
47
48
  import warnings
48
49
  import itertools
49
50
  import humanfriendly
@@ -106,6 +107,8 @@ exif_options = read_exif.ReadExifOptions()
106
107
  exif_options.processing_library = 'pil'
107
108
  exif_options.byte_handling = 'convert_to_string'
108
109
 
110
+ randomize_batch_order_during_testing = True
111
+
109
112
 
110
113
  #%% Support functions for multiprocessing
111
114
 
@@ -120,11 +123,22 @@ def _producer_func(q,
120
123
  """
121
124
  Producer function; only used when using the (optional) image queue.
122
125
 
123
- Reads up to images from disk and puts them on the blocking queue for
124
- processing. Each image is queued as a tuple of [filename,Image]. Sends
125
- "None" to the queue when finished.
126
+ Reads images from disk, optionally preprocesses them (depending on whether "preprocessor"
127
+ is None), then puts them on the blocking queue for processing. Each image is queued as a tuple of
128
+ [filename,Image]. Sends "None" to the queue when finished.
126
129
 
127
130
  The "detector" argument is only used for preprocessing.
131
+
132
+ Args:
133
+ q (Queue): multiprocessing queue to put loaded/preprocessed images into
134
+ image_files (list): list of image file paths to process
135
+ producer_id (int, optional): identifier for this producer worker (for logging)
136
+ preprocessor (str, optional): model file path/identifier for preprocessing, or None to skip preprocessing
137
+ detector_options (dict, optional): key/value pairs that are interpreted differently
138
+ by different detectors
139
+ verbose (bool, optional): enable additional debug output
140
+ image_size (int, optional): image size to use for preprocessing
141
+ augment (bool, optional): enable image augmentation during preprocessing
128
142
  """
129
143
 
130
144
  if verbose:
@@ -134,6 +148,8 @@ def _producer_func(q,
134
148
  if preprocessor is not None:
135
149
  assert isinstance(preprocessor,str)
136
150
  detector_options = deepcopy(detector_options)
151
+ # Tell the detector object it's being loaded as a preprocessor, so it
152
+ # shouldn't actually load model weights.
137
153
  detector_options['preprocess_only'] = True
138
154
  preprocessor = load_detector(preprocessor,
139
155
  detector_options=detector_options,
@@ -142,22 +158,15 @@ def _producer_func(q,
142
158
  for im_file in image_files:
143
159
 
144
160
  try:
145
- if verbose:
146
- print('Loading image {} on producer {}'.format(im_file,producer_id))
147
- sys.stdout.flush()
161
+
148
162
  image = vis_utils.load_image(im_file)
149
163
 
150
164
  if preprocessor is not None:
151
165
 
152
- image_info = preprocessor.generate_detections_one_image(
153
- image,
154
- im_file,
155
- detection_threshold=None,
156
- image_size=image_size,
157
- skip_image_resizing=False,
158
- augment=augment,
159
- preprocess_only=True,
160
- verbose=verbose)
166
+ image_info = preprocessor.preprocess_image(image,
167
+ image_id=im_file,
168
+ image_size=image_size,
169
+ verbose=verbose)
161
170
  if 'failure' in image_info:
162
171
  assert image_info['failure'] == run_detector.FAILURE_INFER
163
172
  raise
@@ -168,12 +177,10 @@ def _producer_func(q,
168
177
  print('Producer process: image {} cannot be loaded:\n{}'.format(im_file,str(e)))
169
178
  image = run_detector.FAILURE_IMAGE_OPEN
170
179
 
171
- if verbose:
172
- print('Queueing image {} from producer {}'.format(im_file,producer_id))
173
- sys.stdout.flush()
174
-
175
180
  q.put([im_file,image,producer_id])
176
181
 
182
+ # ...for each image
183
+
177
184
  # This is a signal to the consumer function that a worker has finished
178
185
  q.put(None)
179
186
 
@@ -196,13 +203,38 @@ def _consumer_func(q,
196
203
  augment=False,
197
204
  detector_options=None,
198
205
  preprocess_on_image_queue=default_preprocess_on_image_queue,
199
- n_total_images=None
206
+ n_total_images=None,
207
+ batch_size=1,
208
+ checkpoint_path=None,
209
+ checkpoint_frequency=-1
200
210
  ):
201
211
  """
202
212
  Consumer function; only used when using the (optional) image queue.
203
213
 
204
214
  Pulls images from a blocking queue and processes them. Returns when "None" has
205
215
  been read from each loader's queue.
216
+
217
+ Args:
218
+ q (Queue): multiprocessing queue to pull images from
219
+ return_queue (Queue): queue to put final results into
220
+ model_file (str or detector object): model file path/identifier or pre-loaded detector
221
+ confidence_threshold (float): only detections above this threshold are returned
222
+ loader_workers (int): number of producer workers (used to know when all are finished)
223
+ image_size (int, optional): image size to use for inference
224
+ include_image_size (bool, optional): include image dimensions in output
225
+ include_image_timestamp (bool, optional): include image timestamps in output
226
+ include_exif_data (bool, optional): include EXIF data in output
227
+ augment (bool, optional): enable image augmentation
228
+ detector_options (dict, optional): key/value pairs that are interpreted differently
229
+ by different detectors
230
+ preprocess_on_image_queue (bool, optional): whether images are already preprocessed on
231
+ the queue
232
+ n_total_images (int, optional): total number of images expected (for progress bar)
233
+ batch_size (int, optional): batch size for GPU inference
234
+ checkpoint_path (str, optional): path to write checkpoint files, None disables
235
+ checkpointing
236
+ checkpoint_frequency (int, optional): write checkpoint every N images, -1 disables
237
+ checkpointing
206
238
  """
207
239
 
208
240
  if verbose:
@@ -226,76 +258,189 @@ def _consumer_func(q,
226
258
 
227
259
  n_images_processed = 0
228
260
  n_queues_finished = 0
261
+ last_checkpoint_count = 0
262
+
263
+ def _should_write_checkpoint():
264
+ """
265
+ Check whether we should write a checkpoint. Returns True if we've crossed a
266
+ checkpoint boundary.
267
+ """
268
+
269
+ if (checkpoint_frequency <= 0) or (checkpoint_path is None):
270
+ return False
271
+
272
+ # Calculate the checkpoint threshold we should have crossed
273
+ current_checkpoint_threshold = \
274
+ (n_images_processed // checkpoint_frequency) * checkpoint_frequency
275
+ last_checkpoint_threshold = \
276
+ (last_checkpoint_count // checkpoint_frequency) * checkpoint_frequency
277
+
278
+ # We should write a checkpoint if we've crossed into a new checkpoint interval
279
+ return (current_checkpoint_threshold > last_checkpoint_threshold)
229
280
 
230
281
  pbar = None
231
282
  if n_total_images is not None:
232
283
  # TODO: in principle I should close this pbar
233
284
  pbar = tqdm(total=n_total_images)
234
285
 
286
+ # Batch processing state
287
+ if batch_size > 1:
288
+ current_batch_items = []
289
+
235
290
  while True:
236
291
 
237
292
  r = q.get()
238
293
 
239
294
  # Is this the last image in one of the producer queues?
240
295
  if r is None:
296
+
241
297
  n_queues_finished += 1
242
298
  q.task_done()
299
+
243
300
  if verbose:
244
301
  print('Consumer thread: {} of {} queues finished'.format(
245
302
  n_queues_finished,loader_workers))
303
+
304
+ # Was this the last worker to finish?
246
305
  if n_queues_finished == loader_workers:
306
+
307
+ # Do we have any leftover images?
308
+ if (batch_size > 1) and (len(current_batch_items) > 0):
309
+
310
+ # We should never have more than one batch of work left to do, so this loop
311
+ is not strictly necessary; it's a bit of future-proofing.
312
+ leftover_batches = _group_into_batches(current_batch_items, batch_size)
313
+
314
+ if len(leftover_batches) > 1:
315
+ print('Warning: after all producer queues finished, '
316
+ '{} images were left for processing, which is more than'
317
+ 'the batch size of {}'.format(len(current_batch_items),batch_size))
318
+
319
+ for leftover_batch in leftover_batches:
320
+
321
+ batch_results = _process_batch(leftover_batch,
322
+ detector,
323
+ confidence_threshold,
324
+ quiet=True,
325
+ image_size=image_size,
326
+ include_image_size=include_image_size,
327
+ include_image_timestamp=include_image_timestamp,
328
+ include_exif_data=include_exif_data,
329
+ augment=augment)
330
+ results.extend(batch_results)
331
+
332
+ if pbar is not None:
333
+ pbar.update(len(leftover_batch))
334
+
335
+ n_images_processed += len(leftover_batch)
336
+
337
+ # In theory we could write a checkpoint here, but because we're basically
338
+ # done at this point, there's not much upside to writing another checkpoint,
339
+ # so for simplicity, I'm skipping it.
340
+
341
+ # ...for each batch we have left to process
342
+
247
343
  return_queue.put(results)
248
344
  return
345
+
249
346
  else:
250
- continue
251
- n_images_processed += 1
252
- im_file = r[0]
253
- image = r[1]
254
347
 
255
- """
256
- result['img_processed'] = img
257
- result['img_original'] = img_original
258
- result['target_shape'] = target_shape
259
- result['scaling_shape'] = scaling_shape
260
- result['letterbox_ratio'] = letterbox_ratio
261
- result['letterbox_pad'] = letterbox_pad
262
- """
348
+ continue
263
349
 
264
- if pbar is not None:
265
- pbar.update(1)
350
+ # ...if we pulled the sentinel signal (None) telling us that a worker finished
266
351
 
267
- if False:
268
- if verbose or ((n_images_processed % n_queue_print) == 1):
269
- elapsed = time.time() - start_time
270
- images_per_second = n_images_processed / elapsed
271
- print('De-queued image {} ({:.2f}/s) ({})'.format(n_images_processed,
272
- images_per_second,
273
- im_file))
274
- sys.stdout.flush()
352
+ # At this point, we have a real image (i.e., not a sentinel indicating that a worker finished)
353
+ #
354
+ # "r" is always a tuple of (filename,image,producer_id)
355
+ #
356
+ # Image can be a PIL image (if the loader wasn't doing preprocessing) or a dict with
357
+ # a preprocessed image and associated metadata.
358
+ im_file = r[0]
359
+ image = r[1]
275
360
 
361
+ # Handle failed images immediately (don't batch them)
362
+ #
363
+ # Loader workers communicate failures by passing a string to
364
+ # the consumer, rather than an image.
276
365
  if isinstance(image,str):
277
- # This is how the producer function communicates read errors
366
+
278
367
  results.append({'file': im_file,
279
368
  'failure': image})
369
+ n_images_processed += 1
370
+
371
+ if pbar is not None:
372
+ pbar.update(1)
373
+
374
+ # This is a catastrophic internal failure; preprocessing workers should
375
+ # be passing the consumer dicts that represent processed images
280
376
  elif preprocess_on_image_queue and (not isinstance(image,dict)):
281
- print('Expected a dict, received an image of type {}'.format(type(image)))
282
- results.append({'file': im_file,
283
- 'failure': 'illegal image type'})
377
+
378
+ print('Expected a dict, received an image of type {}'.format(type(image)))
379
+ results.append({'file': im_file,
380
+ 'failure': 'illegal image type'})
381
+ n_images_processed += 1
382
+
383
+ if pbar is not None:
384
+ pbar.update(1)
284
385
 
285
386
  else:
286
- results.append(process_image(im_file=im_file,
287
- detector=detector,
288
- confidence_threshold=confidence_threshold,
289
- image=image,
290
- quiet=True,
291
- image_size=image_size,
292
- include_image_size=include_image_size,
293
- include_image_timestamp=include_image_timestamp,
294
- include_exif_data=include_exif_data,
295
- augment=augment,
296
- skip_image_resizing=preprocess_on_image_queue))
297
- if verbose:
298
- print('Processed image {}'.format(im_file)); sys.stdout.flush()
387
+
388
+ # At this point, "image" is either an image (if the producer workers are only
389
+ # doing loading) or a dict (if the producer workers are doing preprocessing)
390
+
391
+ if batch_size > 1:
392
+
393
+ # Add to current batch
394
+ current_batch_items.append([im_file, image, r[2]])
395
+
396
+ # Process batch when full
397
+ if len(current_batch_items) >= batch_size:
398
+ batch_results = _process_batch(current_batch_items,
399
+ detector,
400
+ confidence_threshold,
401
+ quiet=True,
402
+ image_size=image_size,
403
+ include_image_size=include_image_size,
404
+ include_image_timestamp=include_image_timestamp,
405
+ include_exif_data=include_exif_data,
406
+ augment=augment)
407
+ results.extend(batch_results)
408
+
409
+ if pbar is not None:
410
+ pbar.update(len(current_batch_items))
411
+
412
+ n_images_processed += len(current_batch_items)
413
+ current_batch_items = []
414
+ else:
415
+
416
+ # Process single image
417
+ result = _process_image(im_file=im_file,
418
+ detector=detector,
419
+ confidence_threshold=confidence_threshold,
420
+ image=image,
421
+ quiet=True,
422
+ image_size=image_size,
423
+ include_image_size=include_image_size,
424
+ include_image_timestamp=include_image_timestamp,
425
+ include_exif_data=include_exif_data,
426
+ augment=augment)
427
+ results.append(result)
428
+ n_images_processed += 1
429
+
430
+ if pbar is not None:
431
+ pbar.update(1)
432
+
433
+ # ...if we are/aren't doing batch processing
434
+
435
+ # Write checkpoint if necessary
436
+ if _should_write_checkpoint():
437
+ print('Consumer: writing checkpoint after {} images'.format(
438
+ n_images_processed))
439
+ write_checkpoint(checkpoint_path, results)
440
+ last_checkpoint_count = n_images_processed
441
+
442
+ # ...whether we received a string (indicating failure) or an image from the loader worker
443
+
299
444
  q.task_done()
300
445
 
301
446
  # ...while True (consumer loop)
@@ -303,23 +448,24 @@ def _consumer_func(q,
303
448
  # ...def _consumer_func(...)
304
449
 
305
450
 
306
- def run_detector_with_image_queue(image_files,
307
- model_file,
308
- confidence_threshold,
309
- quiet=False,
310
- image_size=None,
311
- include_image_size=False,
312
- include_image_timestamp=False,
313
- include_exif_data=False,
314
- augment=False,
315
- detector_options=None,
316
- loader_workers=default_loaders,
317
- preprocess_on_image_queue=default_preprocess_on_image_queue):
451
+ def _run_detector_with_image_queue(image_files,
452
+ model_file,
453
+ confidence_threshold,
454
+ quiet=False,
455
+ image_size=None,
456
+ include_image_size=False,
457
+ include_image_timestamp=False,
458
+ include_exif_data=False,
459
+ augment=False,
460
+ detector_options=None,
461
+ loader_workers=default_loaders,
462
+ preprocess_on_image_queue=default_preprocess_on_image_queue,
463
+ batch_size=1,
464
+ checkpoint_path=None,
465
+ checkpoint_frequency=-1):
318
466
  """
319
- Driver function for the (optional) multiprocessing-based image queue; only used
320
- when --use_image_queue is specified. Starts a reader process to read images from disk, but
321
- processes images in the process from which this function is called (i.e., does not currently
322
- spawn a separate consumer process).
467
+ Driver function for the (optional) multiprocessing-based image queue. Spawns workers to read and
468
+ preprocess images, runs the consumer function in the calling process.
323
469
 
324
470
  Args:
325
471
  image_files (list): list of absolute paths to images
@@ -339,6 +485,9 @@ def run_detector_with_image_queue(image_files,
339
485
  loader_workers (int, optional): number of loaders to use
340
486
  preprocess_on_image_queue (bool, optional): if the image queue is enabled, should it handle
341
487
  image loading and preprocessing (True), or just image loading (False)?
488
+ batch_size (int, optional): batch size for GPU processing
489
+ checkpoint_path (str, optional): path to write checkpoint files, None disables checkpointing
490
+ checkpoint_frequency (int, optional): write checkpoint every N images, -1 disables checkpointing
342
491
 
343
492
  Returns:
344
493
  list: list of dicts in the format returned by process_image()
@@ -408,7 +557,10 @@ def run_detector_with_image_queue(image_files,
408
557
  augment,
409
558
  detector_options,
410
559
  preprocess_on_image_queue,
411
- n_total_images))
560
+ n_total_images,
561
+ batch_size,
562
+ checkpoint_path,
563
+ checkpoint_frequency))
412
564
  else:
413
565
  consumer = Process(target=_consumer_func,args=(q,
414
566
  return_queue,
@@ -422,7 +574,10 @@ def run_detector_with_image_queue(image_files,
422
574
  augment,
423
575
  detector_options,
424
576
  preprocess_on_image_queue,
425
- n_total_images))
577
+ n_total_images,
578
+ batch_size,
579
+ checkpoint_path,
580
+ checkpoint_frequency))
426
581
  consumer.daemon = True
427
582
  consumer.start()
428
583
  else:
@@ -438,7 +593,10 @@ def run_detector_with_image_queue(image_files,
438
593
  augment,
439
594
  detector_options,
440
595
  preprocess_on_image_queue,
441
- n_total_images)
596
+ n_total_images,
597
+ batch_size,
598
+ checkpoint_path,
599
+ checkpoint_frequency)
442
600
 
443
601
  for i_producer,producer in enumerate(producers):
444
602
  producer.join()
@@ -461,7 +619,7 @@ def run_detector_with_image_queue(image_files,
461
619
 
462
620
  return results
463
621
 
464
- # ...def run_detector_with_image_queue(...)
622
+ # ...def _run_detector_with_image_queue(...)
465
623
 
466
624
 
467
625
  #%% Other support functions
@@ -481,9 +639,182 @@ def _chunks_by_number_of_chunks(ls, n):
481
639
  yield ls[i::n]
482
640
 
483
641
 
642
+ #%% Batch processing helper functions
643
+
644
+ def _group_into_batches(items, batch_size):
645
+ """
646
+ Group items into batches.
647
+
648
+ Args:
649
+ items (list): items to group into batches
650
+ batch_size (int): size of each batch
651
+
652
+ Returns:
653
+ list: list of batches, where each batch is a list of items
654
+ """
655
+
656
+ if batch_size <= 0:
657
+ raise ValueError('Batch size must be positive')
658
+
659
+ batches = []
660
+ for i_item in range(0, len(items), batch_size):
661
+ batch = items[i_item:i_item + batch_size]
662
+ batches.append(batch)
663
+
664
+ return batches
665
+
666
+
667
+ def _process_batch(image_items_batch,
668
+ detector,
669
+ confidence_threshold,
670
+ quiet=False,
671
+ image_size=None,
672
+ include_image_size=False,
673
+ include_image_timestamp=False,
674
+ include_exif_data=False,
675
+ augment=False):
676
+ """
677
+ Process a batch of images using generate_detections_one_batch(). Does not necessarily return
678
+ results in the same order in which they were supplied; in particular, images that fail preprocessing
679
+ will be returned out of order.
680
+
681
+ Args:
682
+ image_items_batch (list): list of image file paths (strings) or list of tuples [filename, image, producer_id]
683
+ detector: loaded detector object
684
+ confidence_threshold (float): confidence threshold for detections
685
+ quiet (bool, optional): suppress per-image output
686
+ image_size (int, optional): image size override
687
+ include_image_size (bool, optional): include image dimensions in results
688
+ include_image_timestamp (bool, optional): include image timestamps in results
689
+ include_exif_data (bool, optional): include EXIF data in results
690
+ augment (bool, optional): whether to use image augmentation
691
+
692
+ Returns:
693
+ list of dict: list of results for each image in the batch
694
+ """
695
+
696
+ # This will be the set of items we send for inference; it may be
697
+ # smaller than the input list (image_items_batch) if some images
698
+ # fail to load. [valid_images] will be either a list of PIL Image
699
+ # objects or a list of dicts containing preprocessed images.
700
+ valid_images = []
701
+ valid_image_filenames = []
702
+
703
+ batch_results = []
704
+
705
+ for i_image, item in enumerate(image_items_batch):
706
+
707
+ # Handle both filename strings and tuples
708
+ if isinstance(item, str):
709
+ im_file = item
710
+ try:
711
+ image = vis_utils.load_image(im_file)
712
+ except Exception as e:
713
+ print('Image {} cannot be loaded: {}'.format(im_file,str(e)))
714
+ failed_result = {
715
+ 'file': im_file,
716
+ 'failure': run_detector.FAILURE_IMAGE_OPEN
717
+ }
718
+ batch_results.append(failed_result)
719
+ continue
720
+ else:
721
+ assert len(item) == 3
722
+ im_file, image, producer_id = item
723
+
724
+ valid_images.append(image)
725
+ valid_image_filenames.append(im_file)
726
+
727
+ # ...for each image in the batch
728
+
729
+ assert len(valid_images) == len(valid_image_filenames)
730
+
731
+ valid_batch_results = []
732
+
733
+ # Process the batch if we have any valid images
734
+ if len(valid_images) > 0:
735
+
736
+ try:
737
+
738
+ batch_detections = \
739
+ detector.generate_detections_one_batch(valid_images, valid_image_filenames, verbose=verbose)
740
+
741
+ assert len(batch_detections) == len(valid_images)
742
+
743
+ # Apply confidence threshold and add metadata
744
+ for i_valid_image,image_result in enumerate(batch_detections):
745
+
746
+ assert valid_image_filenames[i_valid_image] == image_result['file']
747
+
748
+ if 'failure' not in image_result:
749
+
750
+ # Apply confidence threshold
751
+ image_result['detections'] = \
752
+ [det for det in image_result['detections'] if det['conf'] >= confidence_threshold]
753
+
754
+ if include_image_size or include_image_timestamp or include_exif_data:
755
+
756
+ image = valid_images[i_valid_image]
757
+
758
+ # If this was preprocessed by the producer thread, pull out the PIL version
759
+ if isinstance(image,dict):
760
+ image = image['img_original_pil']
761
+
762
+ if include_image_size:
763
+
764
+ image_result['width'] = image.width
765
+ image_result['height'] = image.height
766
+
767
+ if include_image_timestamp:
768
+
769
+ image_result['datetime'] = get_image_datetime(image)
770
+
771
+ if include_exif_data:
772
+
773
+ image_result['exif_metadata'] = read_exif.read_pil_exif(image,exif_options)
774
+
775
+ # ...if we need to store metadata
776
+
777
+ # ...if this image succeeded
778
+
779
+ # Failures here should be very rare; there's almost no reason an image would fail
780
+ # within a batch once it's been loaded
781
+ else:
782
+
783
+ print('Warning: within-batch processing failure for image {}'.format(
784
+ image_result['file']))
785
+
786
+ # Add to the list of results for the batch whether or not it succeeded
787
+ valid_batch_results.append(image_result)
788
+
789
+ # ...for each image in this batch
790
+
791
+ except Exception as e:
792
+
793
+ print('Batch processing failure for {} images: {}'.format(len(valid_images),str(e)))
794
+
795
+ # Throw out any successful results for this batch, this should almost never happen
796
+ valid_batch_results = []
797
+
798
+ for image_id in valid_image_filenames:
799
+ r = {'file':image_id,'failure': run_detector.FAILURE_INFER}
800
+ valid_batch_results.append(r)
801
+
802
+ # ...try/except
803
+
804
+ assert len(valid_batch_results) == len(valid_images)
805
+
806
+ # ...if we have valid images in this batch
807
+
808
+ batch_results.extend(valid_batch_results)
809
+
810
+ return batch_results
811
+
812
+ # ...def _process_batch(...)
813
+
814
+
484
815
  #%% Image processing functions
485
816
 
486
- def process_images(im_files,
817
+ def _process_images(im_files,
487
818
  detector,
488
819
  confidence_threshold,
489
820
  use_image_queue=False,
@@ -498,7 +829,8 @@ def process_images(im_files,
498
829
  loader_workers=default_loaders,
499
830
  preprocess_on_image_queue=default_preprocess_on_image_queue):
500
831
  """
501
- Runs a detector (typically MegaDetector) over a list of image files on a single thread.
832
+ Runs a detector (typically MegaDetector) over a list of image files, possibly using multiple
833
+ image loading workers, but not using multiple inference workers.
502
834
 
503
835
  Args:
504
836
  im_files (list): paths to image files
@@ -523,7 +855,7 @@ def process_images(im_files,
523
855
 
524
856
  Returns:
525
857
  list: list of dicts, in which each dict represents detections on one image,
526
- see the 'images' key in https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#batch-processing-api-output-format
858
+ see the 'images' key in https://lila.science/megadetector-output-format
527
859
  """
528
860
 
529
861
  if isinstance(detector, str):
@@ -533,14 +865,14 @@ def process_images(im_files,
533
865
  detector_options=detector_options,
534
866
  verbose=verbose)
535
867
  elapsed = time.time() - start_time
536
- print('Loaded model (batch level) in {}'.format(humanfriendly.format_timespan(elapsed)))
868
+ print('Loaded model (process_images) in {}'.format(humanfriendly.format_timespan(elapsed)))
537
869
 
538
870
  if detector_options is None:
539
871
  detector_options = {}
540
872
 
541
873
  if use_image_queue:
542
874
 
543
- run_detector_with_image_queue(im_files,
875
+ _run_detector_with_image_queue(im_files,
544
876
  detector,
545
877
  confidence_threshold,
546
878
  quiet=quiet,
@@ -557,7 +889,7 @@ def process_images(im_files,
557
889
 
558
890
  results = []
559
891
  for im_file in im_files:
560
- result = process_image(im_file,
892
+ result = _process_image(im_file,
561
893
  detector,
562
894
  confidence_threshold,
563
895
  quiet=quiet,
@@ -573,20 +905,21 @@ def process_images(im_files,
573
905
 
574
906
  return results
575
907
 
576
- # ...def process_images(...)
908
+ # ...if we are/aren't using the image queue
909
+
910
+ # ...def _process_images(...)
577
911
 
578
912
 
579
- def process_image(im_file,
580
- detector,
581
- confidence_threshold,
582
- image=None,
583
- quiet=False,
584
- image_size=None,
585
- include_image_size=False,
586
- include_image_timestamp=False,
587
- include_exif_data=False,
588
- skip_image_resizing=False,
589
- augment=False):
913
+ def _process_image(im_file,
914
+ detector,
915
+ confidence_threshold,
916
+ image=None,
917
+ quiet=False,
918
+ image_size=None,
919
+ include_image_size=False,
920
+ include_image_timestamp=False,
921
+ include_exif_data=False,
922
+ augment=False):
590
923
  """
591
924
  Runs a detector (typically MegaDetector) on a single image file.
592
925
 
@@ -595,8 +928,8 @@ def process_image(im_file,
595
928
  detector (detector object): loaded model, this can no longer be a string by the time
596
929
  you get this far down the pipeline
597
930
  confidence_threshold (float): only detections above this threshold are returned
598
- image (Image, optional): previously-loaded image, if available, used when a worker
599
- thread is handling image loads
931
+ image (Image or dict, optional): previously-loaded image, if available, used when a worker
932
+ thread is handling image loading (and possibly preprocessing)
600
933
  quiet (bool, optional): suppress per-image printouts
601
934
  image_size (int, optional): image size to use for inference, only mess with this
602
935
  if (a) you're using a model other than MegaDetector or (b) you know what you're
@@ -604,13 +937,11 @@ def process_image(im_file,
604
937
  include_image_size (bool, optional): should we include image size in the output for each image?
605
938
  include_image_timestamp (bool, optional): should we include image timestamps in the output for each image?
606
939
  include_exif_data (bool, optional): should we include EXIF data in the output for each image?
607
- skip_image_resizing (bool, optional): whether to skip internal image resizing and rely on external resizing
608
940
  augment (bool, optional): enable image augmentation
609
941
 
610
942
  Returns:
611
943
  dict: dict representing detections on one image,
612
- see the 'images' key in
613
- https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#batch-processing-api-output-format
944
+ see the 'images' key in https://lila.science/megadetector-output-format
614
945
  """
615
946
 
616
947
  if not quiet:
@@ -635,8 +966,9 @@ def process_image(im_file,
635
966
  im_file,
636
967
  detection_threshold=confidence_threshold,
637
968
  image_size=image_size,
638
- skip_image_resizing=skip_image_resizing,
639
- augment=augment)
969
+ augment=augment,
970
+ verbose=verbose)
971
+
640
972
  except Exception as e:
641
973
  if not quiet:
642
974
  print('Image {} cannot be processed. Exception: {}'.format(im_file, e))
@@ -646,6 +978,7 @@ def process_image(im_file,
646
978
  }
647
979
  return result
648
980
 
981
+ # If this image has already been preprocessed
649
982
  if isinstance(image,dict):
650
983
  image = image['img_original_pil']
651
984
 
@@ -661,15 +994,19 @@ def process_image(im_file,
661
994
 
662
995
  return result
663
996
 
664
- # ...def process_image(...)
997
+ # ...def _process_image(...)
665
998
 
666
999
 
667
1000
  def _load_custom_class_mapping(class_mapping_filename):
668
1001
  """
669
- This is an experimental hack to allow the use of non-MD YOLOv5 models through
670
- the same infrastructure; it disables the code that enforces MDv5-like class lists.
1002
+ Allows the use of non-MD models, disables the code that enforces MD-like class lists.
1003
+
1004
+ Args:
1005
+ class_mapping_filename (str): .json file that maps int-strings to strings, or a YOLOv5
1006
+ dataset.yaml file.
671
1007
 
672
- Should be a .json file that maps int-strings to strings, or a YOLOv5 dataset.yaml file.
1008
+ Returns:
1009
+ dict: maps class IDs (int-strings) to class names
673
1010
  """
674
1011
 
675
1012
  if class_mapping_filename is None:
@@ -712,7 +1049,8 @@ def load_and_run_detector_batch(model_file,
712
1049
  force_model_download=False,
713
1050
  detector_options=None,
714
1051
  loader_workers=default_loaders,
715
- preprocess_on_image_queue=default_preprocess_on_image_queue):
1052
+ preprocess_on_image_queue=default_preprocess_on_image_queue,
1053
+ batch_size=1):
716
1054
  """
717
1055
  Load a model file and run it on a list of images.
718
1056
 
@@ -748,6 +1086,7 @@ def load_and_run_detector_batch(model_file,
748
1086
  loader_workers (int, optional): number of loaders to use, only relevant when use_image_queue is True
749
1087
  preprocess_on_image_queue (bool, optional): if the image queue is enabled, should it handle
750
1088
  image loading and preprocessing (True), or just image loading (False)?
1089
+ batch_size (int, optional): batch size for GPU processing, automatically set to 1 for CPU processing
751
1090
 
752
1091
  Returns:
753
1092
  results: list of dicts; each dict represents detections on one image
@@ -815,9 +1154,11 @@ def load_and_run_detector_batch(model_file,
815
1154
  force_download=force_model_download,
816
1155
  verbose=verbose)
817
1156
 
818
- print('GPU available: {}'.format(is_gpu_available(model_file)))
1157
+ gpu_available = is_gpu_available(model_file)
1158
+
1159
+ print('GPU available: {}'.format(gpu_available))
819
1160
 
820
- if (n_cores > 1) and is_gpu_available(model_file):
1161
+ if (n_cores > 1) and gpu_available:
821
1162
 
822
1163
  print('Warning: multiple cores requested, but a GPU is available; parallelization across ' + \
823
1164
  'GPUs is not currently supported, defaulting to one GPU')
@@ -831,26 +1172,39 @@ def load_and_run_detector_batch(model_file,
831
1172
 
832
1173
  if use_image_queue:
833
1174
 
834
- assert checkpoint_frequency < 0, \
835
- 'Using an image queue is not currently supported when checkpointing is enabled'
836
- assert len(results) == 0, \
837
- 'Using an image queue with results loaded from a checkpoint is not currently supported'
838
1175
  assert n_cores <= 1
839
- results = run_detector_with_image_queue(image_file_names,
840
- model_file,
841
- confidence_threshold,
842
- quiet,
843
- image_size=image_size,
844
- include_image_size=include_image_size,
845
- include_image_timestamp=include_image_timestamp,
846
- include_exif_data=include_exif_data,
847
- augment=augment,
848
- detector_options=detector_options,
849
- loader_workers=loader_workers,
850
- preprocess_on_image_queue=preprocess_on_image_queue)
1176
+
1177
+ # Filter out already processed images
1178
+ images_to_process = [im_file for im_file in image_file_names
1179
+ if im_file not in already_processed]
1180
+
1181
+ if len(images_to_process) != len(image_file_names):
1182
+ print('Bypassing {} images that have already been processed'.format(
1183
+ len(image_file_names) - len(images_to_process)))
1184
+
1185
+ new_results = _run_detector_with_image_queue(images_to_process,
1186
+ model_file,
1187
+ confidence_threshold,
1188
+ quiet,
1189
+ image_size=image_size,
1190
+ include_image_size=include_image_size,
1191
+ include_image_timestamp=include_image_timestamp,
1192
+ include_exif_data=include_exif_data,
1193
+ augment=augment,
1194
+ detector_options=detector_options,
1195
+ loader_workers=loader_workers,
1196
+ preprocess_on_image_queue=preprocess_on_image_queue,
1197
+ batch_size=batch_size,
1198
+ checkpoint_path=checkpoint_path,
1199
+ checkpoint_frequency=checkpoint_frequency)
1200
+
1201
+ # Merge new results with existing results from checkpoint
1202
+ results.extend(new_results)
851
1203
 
852
1204
  elif n_cores <= 1:
853
1205
 
1206
+ # Single-threaded processing, no image queue
1207
+
854
1208
  # Load the detector
855
1209
  start_time = time.time()
856
1210
  detector = load_detector(model_file,
@@ -859,38 +1213,81 @@ def load_and_run_detector_batch(model_file,
859
1213
  elapsed = time.time() - start_time
860
1214
  print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
861
1215
 
862
- # This is only used for console reporting, so it's OK that it doesn't
863
- # include images we might have loaded from a previous checkpoint
864
- count = 0
1216
+ if (batch_size > 1) and (not gpu_available):
1217
+ print('Batch size of {} requested, but no GPU is available, using batch size 1'.format(
1218
+ batch_size))
1219
+ batch_size = 1
865
1220
 
866
- for im_file in tqdm(image_file_names):
1221
+ # Filter out already processed images
1222
+ images_to_process = [im_file for im_file in image_file_names
1223
+ if im_file not in already_processed]
867
1224
 
868
- # Will not add additional entries not in the starter checkpoint
869
- if im_file in already_processed:
870
- if not quiet:
871
- print('Bypassing image {}'.format(im_file))
872
- continue
1225
+ if len(images_to_process) != len(image_file_names):
1226
+ print('Bypassing {} images that have already been processed'.format(
1227
+ len(image_file_names) - len(images_to_process)))
873
1228
 
874
- count += 1
1229
+ image_count = 0
875
1230
 
876
- result = process_image(im_file,
877
- detector,
878
- confidence_threshold,
879
- quiet=quiet,
880
- image_size=image_size,
881
- include_image_size=include_image_size,
882
- include_image_timestamp=include_image_timestamp,
883
- include_exif_data=include_exif_data,
884
- augment=augment)
885
- results.append(result)
1231
+ if (batch_size > 1):
1232
+
1233
+ # During testing, randomize the order of images_to_process to help detect
1234
+ # non-deterministic batching issues
1235
+ if randomize_batch_order_during_testing and ('PYTEST_CURRENT_TEST' in os.environ):
1236
+ print('PyTest detected: randomizing batch order')
1237
+ random.seed(int(time.time()))
1238
+ debug_seed = random.randint(0, 2**31 - 1)
1239
+ print('Debug seed: {}'.format(debug_seed))
1240
+ random.seed(debug_seed)
1241
+ random.shuffle(images_to_process)
1242
+
1243
+ # Use batch processing
1244
+ image_batches = _group_into_batches(images_to_process, batch_size)
1245
+
1246
+ for batch in tqdm(image_batches):
1247
+ batch_results = _process_batch(batch,
1248
+ detector,
1249
+ confidence_threshold,
1250
+ quiet=quiet,
1251
+ image_size=image_size,
1252
+ include_image_size=include_image_size,
1253
+ include_image_timestamp=include_image_timestamp,
1254
+ include_exif_data=include_exif_data,
1255
+ augment=augment)
886
1256
 
887
- # Write a checkpoint if necessary
888
- if (checkpoint_frequency != -1) and ((count % checkpoint_frequency) == 0):
1257
+ results.extend(batch_results)
1258
+ image_count += len(batch)
889
1259
 
890
- print('Writing a new checkpoint after having processed {} images since '
891
- 'last restart'.format(count))
1260
+ # Write a checkpoint if necessary
1261
+ if (checkpoint_frequency != -1) and ((image_count % checkpoint_frequency) == 0):
1262
+ print('Writing a new checkpoint after having processed {} images since '
1263
+ 'last restart'.format(image_count))
1264
+ write_checkpoint(checkpoint_path, results)
892
1265
 
893
- _write_checkpoint(checkpoint_path, results)
1266
+ else:
1267
+
1268
+ # Use non-batch processing
1269
+ for im_file in tqdm(images_to_process):
1270
+
1271
+ image_count += 1
1272
+
1273
+ result = _process_image(im_file,
1274
+ detector,
1275
+ confidence_threshold,
1276
+ quiet=quiet,
1277
+ image_size=image_size,
1278
+ include_image_size=include_image_size,
1279
+ include_image_timestamp=include_image_timestamp,
1280
+ include_exif_data=include_exif_data,
1281
+ augment=augment)
1282
+ results.append(result)
1283
+
1284
+ # Write a checkpoint if necessary
1285
+ if (checkpoint_frequency != -1) and ((image_count % checkpoint_frequency) == 0):
1286
+ print('Writing a new checkpoint after having processed {} images since '
1287
+ 'last restart'.format(image_count))
1288
+ write_checkpoint(checkpoint_path, results)
1289
+
1290
+ # ...if the batch size is > 1
894
1291
 
895
1292
  else:
896
1293
 
@@ -910,7 +1307,7 @@ def load_and_run_detector_batch(model_file,
910
1307
  len(already_processed),n_images_all))
911
1308
 
912
1309
  # Divide images into chunks; we'll send one chunk to each worker process
913
- image_batches = list(_chunks_by_number_of_chunks(image_file_names, n_cores))
1310
+ image_chunks = list(_chunks_by_number_of_chunks(image_file_names, n_cores))
914
1311
 
915
1312
  pool = None
916
1313
  try:
@@ -922,15 +1319,15 @@ def load_and_run_detector_batch(model_file,
922
1319
 
923
1320
  checkpoint_queue = Manager().Queue()
924
1321
 
925
- # Pass the "results" array (which may already contain images loaded from an existing
926
- # checkpoint) to the checkpoint queue handler function, which will append results to
927
- # the list as they become available.
1322
+ # Pass the "results" array (which may already contain images loaded from an
1323
+ # existing checkpoint) to the checkpoint queue handler function, which will
1324
+ # append results to the list as they become available.
928
1325
  checkpoint_thread = Thread(target=_checkpoint_queue_handler,
929
1326
  args=(checkpoint_path, checkpoint_frequency,
930
1327
  checkpoint_queue, results), daemon=True)
931
1328
  checkpoint_thread.start()
932
1329
 
933
- pool.map(partial(process_images,
1330
+ pool.map(partial(_process_images,
934
1331
  detector=detector,
935
1332
  confidence_threshold=confidence_threshold,
936
1333
  use_image_queue=False,
@@ -942,7 +1339,7 @@ def load_and_run_detector_batch(model_file,
942
1339
  include_exif_data=include_exif_data,
943
1340
  augment=augment,
944
1341
  detector_options=detector_options),
945
- image_batches)
1342
+ image_chunks)
946
1343
 
947
1344
  checkpoint_queue.put(None)
948
1345
 
@@ -950,7 +1347,7 @@ def load_and_run_detector_batch(model_file,
950
1347
 
951
1348
  # Multiprocessing is enabled, but checkpointing is not
952
1349
 
953
- new_results = pool.map(partial(process_images,
1350
+ new_results = pool.map(partial(_process_images,
954
1351
  detector=detector,
955
1352
  confidence_threshold=confidence_threshold,
956
1353
  use_image_queue=False,
@@ -962,13 +1359,13 @@ def load_and_run_detector_batch(model_file,
962
1359
  include_exif_data=include_exif_data,
963
1360
  augment=augment,
964
1361
  detector_options=detector_options),
965
- image_batches)
1362
+ image_chunks)
966
1363
 
967
1364
  new_results = list(itertools.chain.from_iterable(new_results))
968
1365
 
969
1366
  # Append the results we just computed to "results", which is *usually* empty, but will
970
1367
  # be non-empty if we resumed from a checkpoint
971
- results += new_results
1368
+ results.extend(new_results)
972
1369
 
973
1370
  # ...if checkpointing is/isn't enabled
974
1371
 
@@ -1007,12 +1404,18 @@ def _checkpoint_queue_handler(checkpoint_path, checkpoint_frequency, checkpoint_
1007
1404
  print('Writing a new checkpoint after having processed {} images since '
1008
1405
  'last restart'.format(result_count))
1009
1406
 
1010
- _write_checkpoint(checkpoint_path, results)
1407
+ write_checkpoint(checkpoint_path, results)
1011
1408
 
1012
1409
 
1013
- def _write_checkpoint(checkpoint_path, results):
1410
+ def write_checkpoint(checkpoint_path, results):
1014
1411
  """
1015
- Writes the 'images' field in the dict 'results' to a json checkpoint file.
1412
+ Writes the object in [results] to a json checkpoint file, as a dict with the
1413
+ key "checkpoint". First backs up the checkpoint file if it exists, in case we
1414
+ crash while writing the file.
1415
+
1416
+ Args:
1417
+ checkpoint_path (str): the file to write the checkpoint to
1418
+ results (object): the object we should write
1016
1419
  """
1017
1420
 
1018
1421
  assert checkpoint_path is not None
@@ -1025,11 +1428,41 @@ def _write_checkpoint(checkpoint_path, results):
1025
1428
  shutil.copyfile(checkpoint_path,checkpoint_tmp_path)
1026
1429
 
1027
1430
  # Write the new checkpoint
1028
- ct_utils.write_json(checkpoint_path, {'images': results}, force_str=True)
1431
+ ct_utils.write_json(checkpoint_path, {'checkpoint': results}, force_str=True)
1029
1432
 
1030
1433
  # Remove the backup checkpoint if it exists
1031
1434
  if checkpoint_tmp_path is not None:
1032
- os.remove(checkpoint_tmp_path)
1435
+ try:
1436
+ os.remove(checkpoint_tmp_path)
1437
+ except Exception as e:
1438
+ print('Warning: error removing backup checkpoint file {}:\n{}'.format(
1439
+ checkpoint_tmp_path,str(e)))
1440
+
1441
+
1442
+ def load_checkpoint(checkpoint_path):
1443
+ """
1444
+ Loads results from a checkpoint file. A checkpoint file is always a dict
1445
+ with the key "checkpoint".
1446
+
1447
+ Args:
1448
+ checkpoint_path (str): the .json file to load
1449
+
1450
+ Returns:
1451
+ object: object retrieved from the checkpoint, typically a list of results
1452
+ """
1453
+
1454
+ print('Loading previous results from checkpoint file {}'.format(checkpoint_path))
1455
+
1456
+ with open(checkpoint_path, 'r') as f:
1457
+ checkpoint_data = json.load(f)
1458
+
1459
+ if 'checkpoint' not in checkpoint_data:
1460
+ raise ValueError('Checkpoint file {} is missing "checkpoint" field'.format(checkpoint_path))
1461
+
1462
+ results = checkpoint_data['checkpoint']
1463
+ print('Restored {} entries from the checkpoint {}'.format(len(results),checkpoint_path))
1464
+
1465
+ return results
1033
1466
 
1034
1467
 
1035
1468
  def get_image_datetime(image):
@@ -1066,7 +1499,7 @@ def write_results_to_file(results,
1066
1499
  """
1067
1500
  Writes list of detection results to JSON output file. Format matches:
1068
1501
 
1069
- https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#batch-processing-api-output-format
1502
+ https://lila.science/megadetector-output-format
1070
1503
 
1071
1504
  Args:
1072
1505
  results (list): list of dict, each dict represents detections on one image
@@ -1109,7 +1542,7 @@ def write_results_to_file(results,
1109
1542
 
1110
1543
  info = {
1111
1544
  'detection_completion_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
1112
- 'format_version': '1.4'
1545
+ 'format_version': '1.5'
1113
1546
  }
1114
1547
 
1115
1548
  if detector_file is not None:
@@ -1144,9 +1577,16 @@ def write_results_to_file(results,
1144
1577
 
1145
1578
  # Sort detections in descending order by confidence; not required by the format, but
1146
1579
  # convenient for consistency
1147
- for r in results:
1148
- if ('detections' in r) and (r['detections'] is not None):
1149
- r['detections'] = sort_list_of_dicts_by_key(r['detections'], 'conf', reverse=True)
1580
+ for im in results:
1581
+ if ('detections' in im) and (im['detections'] is not None):
1582
+ im['detections'] = sort_list_of_dicts_by_key(im['detections'], 'conf', reverse=True)
1583
+
1584
+ for im in results:
1585
+ if 'failure' in im:
1586
+ if 'detections' in im:
1587
+ assert im['detections'] is None, 'Illegal failure/detection combination'
1588
+ else:
1589
+ im['detections'] = None
1150
1590
 
1151
1591
  final_output = {
1152
1592
  'images': results,
@@ -1209,8 +1649,6 @@ if False:
1209
1649
  cmd += ' --output_relative_filenames'
1210
1650
  if include_max_conf:
1211
1651
  cmd += ' --include_max_conf'
1212
- if quiet:
1213
- cmd += ' --quiet'
1214
1652
  if image_size is not None:
1215
1653
  cmd += ' --image_size {}'.format(image_size)
1216
1654
  if use_image_queue:
@@ -1294,10 +1732,6 @@ def main(): # noqa
1294
1732
  '--include_max_conf',
1295
1733
  action='store_true',
1296
1734
  help='Include the "max_detection_conf" field in the output')
1297
- parser.add_argument(
1298
- '--quiet',
1299
- action='store_true',
1300
- help='Suppress per-image console output')
1301
1735
  parser.add_argument(
1302
1736
  '--verbose',
1303
1737
  action='store_true',
@@ -1414,6 +1848,17 @@ def main(): # noqa
1414
1848
  metavar='KEY=VALUE',
1415
1849
  default='',
1416
1850
  help='Detector-specific options, as a space-separated list of key-value pairs')
1851
+ parser.add_argument(
1852
+ '--batch_size',
1853
+ type=int,
1854
+ default=1,
1855
+ help='Batch size for GPU inference (default 1). CPU inference will ignore this and use batch_size=1.')
1856
+
1857
+ # This argument is deprecated, we always use what was formerly "quiet mode"
1858
+ parser.add_argument(
1859
+ '--quiet',
1860
+ action='store_true',
1861
+ help=argparse.SUPPRESS)
1417
1862
 
1418
1863
  if len(sys.argv[1:]) == 0:
1419
1864
  parser.print_help()
@@ -1476,7 +1921,7 @@ def main(): # noqa
1476
1921
  # Load the checkpoint if available
1477
1922
  #
1478
1923
  # File paths in the checkpoint are always absolute paths; conversion to relative paths
1479
- # happens below (if necessary).
1924
+ # (if requested) happens at the time results are exported at the end of a job.
1480
1925
  if args.resume_from_checkpoint is not None:
1481
1926
  if args.resume_from_checkpoint == 'auto':
1482
1927
  checkpoint_files = os.listdir(output_dir)
@@ -1494,16 +1939,7 @@ def main(): # noqa
1494
1939
  checkpoint_file = os.path.join(output_dir,checkpoint_file_relative)
1495
1940
  else:
1496
1941
  checkpoint_file = args.resume_from_checkpoint
1497
- assert os.path.exists(checkpoint_file), \
1498
- 'File at resume_from_checkpoint specified does not exist'
1499
- with open(checkpoint_file) as f:
1500
- print('Loading previous results from checkpoint file {}'.format(
1501
- checkpoint_file))
1502
- saved = json.load(f)
1503
- assert 'images' in saved, \
1504
- 'The checkpoint file does not have the correct fields; cannot be restored'
1505
- results = saved['images']
1506
- print('Restored {} entries from the checkpoint'.format(len(results)))
1942
+ results = load_checkpoint(checkpoint_file)
1507
1943
  else:
1508
1944
  results = []
1509
1945
 
@@ -1620,16 +2056,6 @@ def main(): # noqa
1620
2056
  f'Checkpoint path {checkpoint_path} already exists, delete or move it before ' + \
1621
2057
  're-using the same checkpoint path, or specify --allow_checkpoint_overwrite'
1622
2058
 
1623
-
1624
- # Confirm that we can write to the checkpoint path; this avoids issues where
1625
- # we crash after several thousand images.
1626
- #
1627
- # But actually, commenting this out for now... the scenario where we are resuming from a
1628
- # checkpoint, then immediately overwrite that checkpoint with empty data is higher-risk
1629
- # than the annoyance of crashing a few minutes after starting a job.
1630
- if False:
1631
- ct_utils.write_json(checkpoint_path, {'images': []}, indent=None)
1632
-
1633
2059
  print('The checkpoint file will be written to {}'.format(checkpoint_path))
1634
2060
 
1635
2061
  else:
@@ -1649,7 +2075,7 @@ def main(): # noqa
1649
2075
  results=results,
1650
2076
  n_cores=args.ncores,
1651
2077
  use_image_queue=args.use_image_queue,
1652
- quiet=args.quiet,
2078
+ quiet=True,
1653
2079
  image_size=args.image_size,
1654
2080
  class_mapping_filename=args.class_mapping_filename,
1655
2081
  include_image_size=args.include_image_size,
@@ -1660,7 +2086,8 @@ def main(): # noqa
1660
2086
  force_model_download=False,
1661
2087
  detector_options=detector_options,
1662
2088
  loader_workers=args.loader_workers,
1663
- preprocess_on_image_queue=args.preprocess_on_image_queue)
2089
+ preprocess_on_image_queue=args.preprocess_on_image_queue,
2090
+ batch_size=args.batch_size)
1664
2091
 
1665
2092
  elapsed = time.time() - start_time
1666
2093
  images_per_second = len(results) / elapsed