megadetector 10.0.1__py3-none-any.whl → 10.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -5,7 +5,7 @@ run_detector_batch.py
  Module to run MegaDetector on lots of images, writing the results
  to a file in the MegaDetector results format.
 
- https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#megadetector-batch-output-format
+ https://lila.science/megadetector-output-format
 
  This enables the results to be used in our post-processing pipeline; see postprocess_batch_results.py.
 
@@ -23,7 +23,7 @@ is not supported when using a GPU.
 
  The lack of GPU multiprocessing support might sound annoying, but in practice we
  run MegaDetector on gazillions of images across multiple GPUs using this script; we just use
- one GPU *per invocation of this script*. Dividing a big batch of images into one chunk
+ one GPU *per invocation of this script*. Dividing a list of images into one chunk
  per GPU happens outside of this script.
 
  Does not have a command-line option to bind the process to a particular GPU, but you can
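
For context, a minimal sketch of the per-GPU division described above (the helper below is illustrative, not part of this module; it assumes the script's usual positional arguments of detector file, image list, and output file):

import os
import subprocess

def launch_one_invocation_per_gpu(chunk_files, model_file, output_dir):
    """chunk_files[i] is a .json list of images for GPU i; chunking happens beforehand."""
    procs = []
    for gpu_id, chunk_file in enumerate(chunk_files):
        env = os.environ.copy()
        # Bind this invocation to a single GPU
        env['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        output_file = os.path.join(output_dir, 'results_gpu_{}.json'.format(gpu_id))
        cmd = ['python', 'run_detector_batch.py', model_file, chunk_file, output_file]
        procs.append(subprocess.Popen(cmd, env=env))
    for p in procs:
        p.wait()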
@@ -44,6 +44,7 @@ import sys
  import time
  import copy
  import shutil
+ import random
  import warnings
  import itertools
  import humanfriendly
@@ -106,6 +107,8 @@ exif_options = read_exif.ReadExifOptions()
  exif_options.processing_library = 'pil'
  exif_options.byte_handling = 'convert_to_string'
 
+ randomize_batch_order_during_testing = True
+
 
  #%% Support functions for multiprocessing
@@ -120,11 +123,22 @@ def _producer_func(q,
  """
  Producer function; only used when using the (optional) image queue.
 
- Reads up to images from disk and puts them on the blocking queue for
- processing. Each image is queued as a tuple of [filename,Image]. Sends
- "None" to the queue when finished.
+ Reads images from disk, optionally preprocesses them (depending on whether "preprocessor"
+ is None), then puts them on the blocking queue for processing. Each image is queued as a tuple of
+ [filename,Image]. Sends "None" to the queue when finished.
 
  The "detector" argument is only used for preprocessing.
+
+ Args:
+ q (Queue): multiprocessing queue to put loaded/preprocessed images into
+ image_files (list): list of image file paths to process
+ producer_id (int, optional): identifier for this producer worker (for logging)
+ preprocessor (str, optional): model file path/identifier for preprocessing, or None to skip preprocessing
+ detector_options (dict, optional): key/value pairs that are interpreted differently
+ by different detectors
+ verbose (bool, optional): enable additional debug output
+ image_size (int, optional): image size to use for preprocessing
+ augment (bool, optional): enable image augmentation during preprocessing
  """
 
  if verbose:
@@ -134,6 +148,8 @@ def _producer_func(q,
  if preprocessor is not None:
  assert isinstance(preprocessor,str)
  detector_options = deepcopy(detector_options)
+ # Tell the detector object it's being loaded as a preprocessor, so it
+ # shouldn't actually load model weights.
  detector_options['preprocess_only'] = True
  preprocessor = load_detector(preprocessor,
  detector_options=detector_options,
@@ -149,15 +165,10 @@ def _producer_func(q,
 
  if preprocessor is not None:
 
- image_info = preprocessor.generate_detections_one_image(
- image,
- im_file,
- detection_threshold=None,
- image_size=image_size,
- skip_image_resizing=False,
- augment=augment,
- preprocess_only=True,
- verbose=verbose)
+ image_info = preprocessor.preprocess_image(image,
+ image_id=im_file,
+ image_size=image_size,
+ verbose=verbose)
  if 'failure' in image_info:
  assert image_info['failure'] == run_detector.FAILURE_INFER
  raise
@@ -174,6 +185,8 @@
 
  q.put([im_file,image,producer_id])
 
+ # ...for each image
+
  # This is a signal to the consumer function that a worker has finished
  q.put(None)
 
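
The sentinel handshake used by the producer above (and counted by the consumer below) is worth seeing in isolation: each producer puts None on the shared queue when its work is exhausted, and the consumer keeps reading until it has seen one None per producer. A stripped-down sketch, independent of this module:

from multiprocessing import JoinableQueue

def toy_producer(q, items):
    for item in items:
        q.put(item)
    q.put(None)  # Sentinel: this worker is done

def toy_consumer(q, n_producers):
    n_finished = 0
    while True:
        item = q.get()
        q.task_done()
        if item is None:
            n_finished += 1
            if n_finished == n_producers:
                return  # Every producer has signaled completion
            continue
        # ...process item...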
@@ -196,13 +209,31 @@ def _consumer_func(q,
  augment=False,
  detector_options=None,
  preprocess_on_image_queue=default_preprocess_on_image_queue,
- n_total_images=None
+ n_total_images=None,
+ batch_size=1
  ):
  """
  Consumer function; only used when using the (optional) image queue.
 
  Pulls images from a blocking queue and processes them. Returns when "None" has
  been read from each loader's queue.
+
+ Args:
+ q (Queue): multiprocessing queue to pull images from
+ return_queue (Queue): queue to put final results into
+ model_file (str or detector object): model file path/identifier or pre-loaded detector
+ confidence_threshold (float): only detections above this threshold are returned
+ loader_workers (int): number of producer workers (used to know when all are finished)
+ image_size (int, optional): image size to use for inference
+ include_image_size (bool, optional): include image dimensions in output
+ include_image_timestamp (bool, optional): include image timestamps in output
+ include_exif_data (bool, optional): include EXIF data in output
+ augment (bool, optional): enable image augmentation
+ detector_options (dict, optional): key/value pairs that are interpreted differently
+ by different detectors
+ preprocess_on_image_queue (bool, optional): whether images are already preprocessed on the queue
+ n_total_images (int, optional): total number of images expected (for progress bar)
+ batch_size (int, optional): batch size for GPU inference
  """
 
  if verbose:
@@ -232,38 +263,78 @@ def _consumer_func(q,
  # TODO: in principle I should close this pbar
  pbar = tqdm(total=n_total_images)
 
+ # Batch processing state
+ if batch_size > 1:
+ current_batch_items = []
+
  while True:
 
  r = q.get()
 
  # Is this the last image in one of the producer queues?
  if r is None:
+
  n_queues_finished += 1
  q.task_done()
+
  if verbose:
  print('Consumer thread: {} of {} queues finished'.format(
  n_queues_finished,loader_workers))
+
+ # Was this the last worker to finish?
  if n_queues_finished == loader_workers:
+
+ # Do we have any leftover images?
+ if (batch_size > 1) and (len(current_batch_items) > 0):
+
+ # We should never have more than one batch of work left to do, so this loop
+ # is not strictly necessary; it's a bit of future-proofing.
+ leftover_batches = _group_into_batches(current_batch_items, batch_size)
+
+ if len(leftover_batches) > 1:
+ print('Warning: after all producer queues finished, '
+ '{} images were left for processing, which is more than '
+ 'the batch size of {}'.format(len(current_batch_items),batch_size))
+
+ for leftover_batch in leftover_batches:
+
+ batch_results = _process_batch(leftover_batch,
+ detector,
+ confidence_threshold,
+ quiet=True,
+ image_size=image_size,
+ include_image_size=include_image_size,
+ include_image_timestamp=include_image_timestamp,
+ include_exif_data=include_exif_data,
+ augment=augment)
+ results.extend(batch_results)
+
+ if pbar is not None:
+ pbar.update(len(leftover_batch))
+
+ n_images_processed += len(leftover_batch)
+
+ # ...for each batch we have left to process
+
  return_queue.put(results)
  return
+
  else:
+
  continue
- n_images_processed += 1
- im_file = r[0]
- image = r[1]
 
- """
- result['img_processed'] = img
- result['img_original'] = img_original
- result['target_shape'] = target_shape
- result['scaling_shape'] = scaling_shape
- result['letterbox_ratio'] = letterbox_ratio
- result['letterbox_pad'] = letterbox_pad
- """
+ # ...if we pulled the sentinel signal (None) telling us that a worker finished
 
- if pbar is not None:
- pbar.update(1)
+ # At this point, we have a real image (i.e., not a sentinel indicating that a worker finished)
+ #
+ # "r" is always a tuple of (filename,image,producer_id)
+ #
+ # Image can be a PIL image (if the loader wasn't doing preprocessing) or a dict with
+ # a preprocessed image and associated metadata.
+ im_file = r[0]
+ image = r[1]
 
+ # This block is sometimes useful for debugging, so I'm leaving it here, but if'd out
  if False:
  if verbose or ((n_images_processed % n_queue_print) == 1):
  elapsed = time.time() - start_time
@@ -273,29 +344,85 @@ def _consumer_func(q,
  im_file))
  sys.stdout.flush()
 
+ # Handle failed images immediately (don't batch them)
+ #
+ # Loader workers communicate failures by passing a string to
+ # the consumer, rather than an image.
  if isinstance(image,str):
- # This is how the producer function communicates read errors
+
  results.append({'file': im_file,
  'failure': image})
+ n_images_processed += 1
+
+ if pbar is not None:
+ pbar.update(1)
+
+ # This is a catastrophic internal failure; preprocessing workers should
+ # be passing the consumer dicts that represent processed images
  elif preprocess_on_image_queue and (not isinstance(image,dict)):
- print('Expected a dict, received an image of type {}'.format(type(image)))
- results.append({'file': im_file,
- 'failure': 'illegal image type'})
+
+ print('Expected a dict, received an image of type {}'.format(type(image)))
+ results.append({'file': im_file,
+ 'failure': 'illegal image type'})
+ n_images_processed += 1
+
+ if pbar is not None:
+ pbar.update(1)
 
  else:
- results.append(process_image(im_file=im_file,
- detector=detector,
- confidence_threshold=confidence_threshold,
- image=image,
- quiet=True,
- image_size=image_size,
- include_image_size=include_image_size,
- include_image_timestamp=include_image_timestamp,
- include_exif_data=include_exif_data,
- augment=augment,
- skip_image_resizing=preprocess_on_image_queue))
+
+ # At this point, "image" is either an image (if the producer workers are only
+ # doing loading) or a dict (if the producer workers are doing preprocessing)
+
+ if batch_size > 1:
+
+ # Add to current batch
+ current_batch_items.append([im_file, image, r[2]])
+
+ # Process batch when full
+ if len(current_batch_items) >= batch_size:
+ batch_results = _process_batch(current_batch_items,
+ detector,
+ confidence_threshold,
+ quiet=True,
+ image_size=image_size,
+ include_image_size=include_image_size,
+ include_image_timestamp=include_image_timestamp,
+ include_exif_data=include_exif_data,
+ augment=augment)
+ results.extend(batch_results)
+
+ if pbar is not None:
+ pbar.update(len(current_batch_items))
+
+ n_images_processed += len(current_batch_items)
+ current_batch_items = []
+ else:
+
+ # Process single image
+ result = _process_image(im_file=im_file,
+ detector=detector,
+ confidence_threshold=confidence_threshold,
+ image=image,
+ quiet=True,
+ image_size=image_size,
+ include_image_size=include_image_size,
+ include_image_timestamp=include_image_timestamp,
+ include_exif_data=include_exif_data,
+ augment=augment)
+ results.append(result)
+ n_images_processed += 1
+
+ if pbar is not None:
+ pbar.update(1)
+
+ # ...if we are/aren't doing batch processing
+
+ # ...whether we received a string (indicating failure) or an image from the loader worker
+
  if verbose:
  print('Processed image {}'.format(im_file)); sys.stdout.flush()
+
  q.task_done()
 
  # ...while True (consumer loop)
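
Distilled, the consumer's batching logic above is an accumulate-and-flush pattern: buffer incoming items until a full batch is available, run inference on the batch, and flush whatever remains once the producers finish. A minimal sketch under illustrative names:

def consume_in_batches(get_item, run_batch, batch_size):
    """get_item() returns the next work item, or None when input is exhausted;
    run_batch() processes a list of items."""
    buffer = []
    while True:
        item = get_item()
        if item is None:
            if len(buffer) > 0:
                run_batch(buffer)  # Flush the partial final batch
            return
        buffer.append(item)
        if len(buffer) >= batch_size:
            run_batch(buffer)
            buffer = []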
@@ -303,23 +430,22 @@
  # ...def _consumer_func(...)
 
 
- def run_detector_with_image_queue(image_files,
- model_file,
- confidence_threshold,
- quiet=False,
- image_size=None,
- include_image_size=False,
- include_image_timestamp=False,
- include_exif_data=False,
- augment=False,
- detector_options=None,
- loader_workers=default_loaders,
- preprocess_on_image_queue=default_preprocess_on_image_queue):
+ def _run_detector_with_image_queue(image_files,
+ model_file,
+ confidence_threshold,
+ quiet=False,
+ image_size=None,
+ include_image_size=False,
+ include_image_timestamp=False,
+ include_exif_data=False,
+ augment=False,
+ detector_options=None,
+ loader_workers=default_loaders,
+ preprocess_on_image_queue=default_preprocess_on_image_queue,
+ batch_size=1):
  """
- Driver function for the (optional) multiprocessing-based image queue; only used
- when --use_image_queue is specified. Starts a reader process to read images from disk, but
- processes images in the process from which this function is called (i.e., does not currently
- spawn a separate consumer process).
+ Driver function for the (optional) multiprocessing-based image queue. Spawns workers to read and
+ preprocess images, and runs the consumer function in the calling process.
 
  Args:
  image_files (str): list of absolute paths to images
@@ -339,6 +465,7 @@ def run_detector_with_image_queue(image_files,
  loader_workers (int, optional): number of loaders to use
  preprocess_on_image_queue (bool, optional): if the image queue is enabled, should it handle
  image loading and preprocessing (True), or just image loading (False)?
+ batch_size (int, optional): batch size for GPU processing
 
  Returns:
  list: list of dicts in the format returned by process_image()
@@ -408,7 +535,8 @@
  augment,
  detector_options,
  preprocess_on_image_queue,
- n_total_images))
+ n_total_images,
+ batch_size))
  else:
  consumer = Process(target=_consumer_func,args=(q,
  return_queue,
@@ -422,7 +550,8 @@
  augment,
  detector_options,
  preprocess_on_image_queue,
- n_total_images))
+ n_total_images,
+ batch_size))
  consumer.daemon = True
  consumer.start()
  else:
@@ -438,7 +567,8 @@
  augment,
  detector_options,
  preprocess_on_image_queue,
- n_total_images)
+ n_total_images,
+ batch_size)
 
  for i_producer,producer in enumerate(producers):
  producer.join()
@@ -461,7 +591,7 @@
 
  return results
 
- # ...def run_detector_with_image_queue(...)
+ # ...def _run_detector_with_image_queue(...)
 
 
  #%% Other support functions
@@ -481,9 +611,191 @@ def _chunks_by_number_of_chunks(ls, n):
  yield ls[i::n]
 
 
+ #%% Batch processing helper functions
+
+ def _group_into_batches(items, batch_size):
+ """
+ Group items into batches.
+
+ Args:
+ items (list): items to group into batches
+ batch_size (int): size of each batch
+
+ Returns:
+ list: list of batches, where each batch is a list of items
+ """
+
+ if batch_size <= 0:
+ raise ValueError('Batch size must be positive')
+
+ batches = []
+ for i_item in range(0, len(items), batch_size):
+ batch = items[i_item:i_item + batch_size]
+ batches.append(batch)
+
+ return batches
+
+
+ def _process_batch(image_items_batch,
+ detector,
+ confidence_threshold,
+ quiet=False,
+ image_size=None,
+ include_image_size=False,
+ include_image_timestamp=False,
+ include_exif_data=False,
+ augment=False):
+ """
+ Process a batch of images using generate_detections_one_batch(). Does not necessarily return
+ results in the same order in which they were supplied; in particular, images that fail preprocessing
+ will be returned out of order.
+
+ Args:
+ image_items_batch (list): list of image file paths (strings) or list of tuples [filename, image, producer_id]
+ detector: loaded detector object
+ confidence_threshold (float): confidence threshold for detections
+ quiet (bool, optional): suppress per-image output
+ image_size (int, optional): image size override
+ include_image_size (bool, optional): include image dimensions in results
+ include_image_timestamp (bool, optional): include image timestamps in results
+ include_exif_data (bool, optional): include EXIF data in results
+ augment (bool, optional): whether to use image augmentation
+
+ Returns:
+ list of dict: list of results for each image in the batch
+ """
+
+ if verbose:
+ print('_process_batch called with {} items'.format(len(image_items_batch)))
+
+ # This will be the set of items we send for inference; it may be
+ # smaller than the input list (image_items_batch) if some images
+ # fail to load. [valid_images] will be either a list of PIL Image
+ # objects or a list of dicts containing preprocessed images.
+ valid_images = []
+ valid_image_filenames = []
+
+ batch_results = []
+
+ for i_image, item in enumerate(image_items_batch):
+
+ # Handle both filename strings and tuples
+ if isinstance(item, str):
+ im_file = item
+ try:
+ image = vis_utils.load_image(im_file)
+ except Exception as e:
+ print('Image {} cannot be loaded: {}'.format(im_file,str(e)))
+ failed_result = {
+ 'file': im_file,
+ 'failure': run_detector.FAILURE_IMAGE_OPEN
+ }
+ batch_results.append(failed_result)
+ continue
+ else:
+ assert len(item) == 3
+ im_file, image, producer_id = item
+
+ valid_images.append(image)
+ valid_image_filenames.append(im_file)
+
+ # ...for each image in the batch
+
+ assert len(valid_images) == len(valid_image_filenames)
+
+ if verbose:
+ print('_process_batch found {} valid items in batch'.format(len(valid_images)))
+
+ valid_batch_results = []
+
+ # Process the batch if we have any valid images
+ if len(valid_images) > 0:
+
+ try:
+
+ batch_detections = \
+ detector.generate_detections_one_batch(valid_images, valid_image_filenames, verbose=verbose)
+
+ assert len(batch_detections) == len(valid_images)
+
+ # Apply confidence threshold and add metadata
+ for i_valid_image,image_result in enumerate(batch_detections):
+
+ assert valid_image_filenames[i_valid_image] == image_result['file']
+
+ if 'failure' not in image_result:
+
+ # Apply confidence threshold
+ image_result['detections'] = \
+ [det for det in image_result['detections'] if det['conf'] >= confidence_threshold]
+
+ if include_image_size or include_image_timestamp or include_exif_data:
+
+ image = valid_images[i_valid_image]
+
+ # If this was preprocessed by the producer thread, pull out the PIL version
+ if isinstance(image,dict):
+ image = image['img_original_pil']
+
+ if include_image_size:
+
+ image_result['width'] = image.width
+ image_result['height'] = image.height
+
+ if include_image_timestamp:
+
+ image_result['datetime'] = get_image_datetime(image)
+
+ if include_exif_data:
+
+ image_result['exif_metadata'] = read_exif.read_pil_exif(image,exif_options)
+
+ # ...if we need to store metadata
+
+ # ...if this image succeeded
+
+ # Failures here should be very rare; there's almost no reason an image would fail
+ # within a batch once it's been loaded
+ else:
+
+ print('Warning: within-batch processing failure for image {}'.format(
+ image_result['file']))
+
+ # Add to the list of results for the batch whether or not it succeeded
+ valid_batch_results.append(image_result)
+
+ # ...for each image in this batch
+
+ except Exception as e:
+
+ print('Batch processing failure for {} images: {}'.format(len(valid_images),str(e)))
+
+ # Throw out any successful results for this batch; this should almost never happen
+ valid_batch_results = []
+
+ for image_id in valid_image_filenames:
+ r = {'file':image_id,'failure': run_detector.FAILURE_INFER}
+ valid_batch_results.append(r)
+
+ # ...try/except
+
+ assert len(valid_batch_results) == len(valid_images)
+
+ # ...if we have valid images in this batch
+
+ batch_results.extend(valid_batch_results)
+
+ if verbose:
+ print('_process_batch returning results for {} items'.format(len(batch_results)))
+
+ return batch_results
+
+ # ...def _process_batch(...)
+
+
  #%% Image processing functions
 
- def process_images(im_files,
+ def _process_images(im_files,
  detector,
  confidence_threshold,
  use_image_queue=False,
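
Two different grouping helpers now coexist in this file, and the distinction matters: _group_into_batches produces contiguous fixed-size batches (for GPU inference), while _chunks_by_number_of_chunks deals items round-robin into a fixed number of chunks (one per worker process). A quick illustration:

items = [0, 1, 2, 3, 4, 5, 6]

# Contiguous batches of size 3, as in _group_into_batches:
# [[0, 1, 2], [3, 4, 5], [6]]
batches = [items[i:i + 3] for i in range(0, len(items), 3)]

# Round-robin split into 3 chunks, as in _chunks_by_number_of_chunks:
# [[0, 3, 6], [1, 4], [2, 5]]
chunks = [items[i::3] for i in range(3)]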
@@ -498,7 +810,8 @@ def process_images(im_files,
  loader_workers=default_loaders,
  preprocess_on_image_queue=default_preprocess_on_image_queue):
  """
- Runs a detector (typically MegaDetector) over a list of image files on a single thread.
+ Runs a detector (typically MegaDetector) over a list of image files, possibly using multiple
+ image loading workers, but not using multiple inference workers.
 
  Args:
  im_files (list): paths to image files
@@ -523,7 +836,7 @@
 
  Returns:
  list: list of dicts, in which each dict represents detections on one image,
- see the 'images' key in https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#batch-processing-api-output-format
+ see the 'images' key in https://lila.science/megadetector-output-format
  """
 
  if isinstance(detector, str):
@@ -533,14 +846,14 @@
  detector_options=detector_options,
  verbose=verbose)
  elapsed = time.time() - start_time
- print('Loaded model (batch level) in {}'.format(humanfriendly.format_timespan(elapsed)))
+ print('Loaded model (process_images) in {}'.format(humanfriendly.format_timespan(elapsed)))
 
  if detector_options is None:
  detector_options = {}
 
  if use_image_queue:
 
- run_detector_with_image_queue(im_files,
+ _run_detector_with_image_queue(im_files,
  detector,
  confidence_threshold,
  quiet=quiet,
@@ -557,7 +870,7 @@
 
  results = []
  for im_file in im_files:
- result = process_image(im_file,
+ result = _process_image(im_file,
  detector,
  confidence_threshold,
  quiet=quiet,
@@ -573,20 +886,21 @@
 
  return results
 
- # ...def process_images(...)
+ # ...if we are/aren't using the image queue
 
+ # ...def _process_images(...)
 
- def process_image(im_file,
- detector,
- confidence_threshold,
- image=None,
- quiet=False,
- image_size=None,
- include_image_size=False,
- include_image_timestamp=False,
- include_exif_data=False,
- skip_image_resizing=False,
- augment=False):
+
+ def _process_image(im_file,
+ detector,
+ confidence_threshold,
+ image=None,
+ quiet=False,
+ image_size=None,
+ include_image_size=False,
+ include_image_timestamp=False,
+ include_exif_data=False,
+ augment=False):
  """
  Runs a detector (typically MegaDetector) on a single image file.
 
@@ -595,8 +909,8 @@ def process_image(im_file,
  detector (detector object): loaded model, this can no longer be a string by the time
  you get this far down the pipeline
  confidence_threshold (float): only detections above this threshold are returned
- image (Image, optional): previously-loaded image, if available, used when a worker
- thread is handling image loads
+ image (Image or dict, optional): previously-loaded image, if available, used when a worker
+ thread is handling image loading (and possibly preprocessing)
  quiet (bool, optional): suppress per-image printouts
  image_size (int, optional): image size to use for inference, only mess with this
  if (a) you're using a model other than MegaDetector or (b) you know what you're
@@ -604,13 +918,11 @@
  include_image_size (bool, optional): should we include image size in the output for each image?
  include_image_timestamp (bool, optional): should we include image timestamps in the output for each image?
  include_exif_data (bool, optional): should we include EXIF data in the output for each image?
- skip_image_resizing (bool, optional): whether to skip internal image resizing and rely on external resizing
  augment (bool, optional): enable image augmentation
 
  Returns:
  dict: dict representing detections on one image,
- see the 'images' key in
- https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#batch-processing-api-output-format
+ see the 'images' key in https://lila.science/megadetector-output-format
  """
 
  if not quiet:
@@ -635,8 +947,9 @@
  im_file,
  detection_threshold=confidence_threshold,
  image_size=image_size,
- skip_image_resizing=skip_image_resizing,
- augment=augment)
+ augment=augment,
+ verbose=verbose)
+
  except Exception as e:
  if not quiet:
  print('Image {} cannot be processed. Exception: {}'.format(im_file, e))
@@ -646,6 +959,7 @@
  }
  return result
 
+ # If this image has already been preprocessed
  if isinstance(image,dict):
  image = image['img_original_pil']
 
@@ -661,15 +975,19 @@
 
  return result
 
- # ...def process_image(...)
+ # ...def _process_image(...)
 
 
  def _load_custom_class_mapping(class_mapping_filename):
  """
- This is an experimental hack to allow the use of non-MD YOLOv5 models through
- the same infrastructure; it disables the code that enforces MDv5-like class lists.
+ Allows the use of non-MD models; disables the code that enforces MD-like class lists.
+
+ Args:
+ class_mapping_filename (str): .json file that maps int-strings to strings, or a YOLOv5
+ dataset.yaml file.
 
- Should be a .json file that maps int-strings to strings, or a YOLOv5 dataset.yaml file.
+ Returns:
+ dict: maps class IDs (int-strings) to class names
  """
 
  if class_mapping_filename is None:
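
As a concrete illustration of the mapping this function expects, a hypothetical custom_classes.json might contain {"0": "animal", "1": "person", "2": "vehicle"}, i.e. int-strings to class names:

import json

# Hypothetical file name; any .json file with this shape works
with open('custom_classes.json', 'r') as f:
    class_mapping = json.load(f)

# Keys are integer class IDs serialized as strings
assert all(k.isdigit() for k in class_mapping.keys())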
@@ -712,7 +1030,8 @@ def load_and_run_detector_batch(model_file,
  force_model_download=False,
  detector_options=None,
  loader_workers=default_loaders,
- preprocess_on_image_queue=default_preprocess_on_image_queue):
+ preprocess_on_image_queue=default_preprocess_on_image_queue,
+ batch_size=1):
  """
  Load a model file and run it on a list of images.
 
@@ -748,6 +1067,7 @@
  loader_workers (int, optional): number of loaders to use, only relevant when use_image_queue is True
  preprocess_on_image_queue (bool, optional): if the image queue is enabled, should it handle
  image loading and preprocessing (True), or just image loading (False)?
+ batch_size (int, optional): batch size for GPU processing, automatically set to 1 for CPU processing
 
  Returns:
  results: list of dicts; each dict represents detections on one image
@@ -815,9 +1135,11 @@
  force_download=force_model_download,
  verbose=verbose)
 
- print('GPU available: {}'.format(is_gpu_available(model_file)))
+ gpu_available = is_gpu_available(model_file)
+
+ print('GPU available: {}'.format(gpu_available))
 
- if (n_cores > 1) and is_gpu_available(model_file):
+ if (n_cores > 1) and gpu_available:
 
  print('Warning: multiple cores requested, but a GPU is available; parallelization across ' + \
  'GPUs is not currently supported, defaulting to one GPU')
@@ -836,18 +1158,22 @@
  assert len(results) == 0, \
  'Using an image queue with results loaded from a checkpoint is not currently supported'
  assert n_cores <= 1
- results = run_detector_with_image_queue(image_file_names,
- model_file,
- confidence_threshold,
- quiet,
- image_size=image_size,
- include_image_size=include_image_size,
- include_image_timestamp=include_image_timestamp,
- include_exif_data=include_exif_data,
- augment=augment,
- detector_options=detector_options,
- loader_workers=loader_workers,
- preprocess_on_image_queue=preprocess_on_image_queue)
+
+ # Image queue now supports batch processing
+
+ results = _run_detector_with_image_queue(image_file_names,
+ model_file,
+ confidence_threshold,
+ quiet,
+ image_size=image_size,
+ include_image_size=include_image_size,
+ include_image_timestamp=include_image_timestamp,
+ include_exif_data=include_exif_data,
+ augment=augment,
+ detector_options=detector_options,
+ loader_workers=loader_workers,
+ preprocess_on_image_queue=preprocess_on_image_queue,
+ batch_size=batch_size)
 
  elif n_cores <= 1:
 
@@ -859,38 +1185,81 @@
  elapsed = time.time() - start_time
  print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
 
- # This is only used for console reporting, so it's OK that it doesn't
- # include images we might have loaded from a previous checkpoint
- count = 0
+ if (batch_size > 1) and (not gpu_available):
+ print('Batch size of {} requested, but no GPU is available, using batch size 1'.format(
+ batch_size))
+ batch_size = 1
 
- for im_file in tqdm(image_file_names):
+ # Filter out already processed images
+ images_to_process = [im_file for im_file in image_file_names
+ if im_file not in already_processed]
 
- # Will not add additional entries not in the starter checkpoint
- if im_file in already_processed:
- if not quiet:
- print('Bypassing image {}'.format(im_file))
- continue
+ if len(images_to_process) != len(image_file_names):
+ print('Bypassing {} images that have already been processed'.format(
+ len(image_file_names) - len(images_to_process)))
 
- count += 1
+ image_count = 0
 
- result = process_image(im_file,
- detector,
- confidence_threshold,
- quiet=quiet,
- image_size=image_size,
- include_image_size=include_image_size,
- include_image_timestamp=include_image_timestamp,
- include_exif_data=include_exif_data,
- augment=augment)
- results.append(result)
+ if (batch_size > 1):
 
- # Write a checkpoint if necessary
- if (checkpoint_frequency != -1) and ((count % checkpoint_frequency) == 0):
+ # During testing, randomize the order of images_to_process to help detect
+ # non-deterministic batching issues
+ if randomize_batch_order_during_testing and ('PYTEST_CURRENT_TEST' in os.environ):
+ print('PyTest detected: randomizing batch order')
+ random.seed(int(time.time()))
+ debug_seed = random.randint(0, 2**31 - 1)
+ print('Debug seed: {}'.format(debug_seed))
+ random.seed(debug_seed)
+ random.shuffle(images_to_process)
 
- print('Writing a new checkpoint after having processed {} images since '
- 'last restart'.format(count))
+ # Use batch processing
+ image_batches = _group_into_batches(images_to_process, batch_size)
 
- _write_checkpoint(checkpoint_path, results)
+ for batch in tqdm(image_batches):
+ batch_results = _process_batch(batch,
+ detector,
+ confidence_threshold,
+ quiet=quiet,
+ image_size=image_size,
+ include_image_size=include_image_size,
+ include_image_timestamp=include_image_timestamp,
+ include_exif_data=include_exif_data,
+ augment=augment)
+
+ results.extend(batch_results)
+ image_count += len(batch)
+
+ # Write a checkpoint if necessary
+ if (checkpoint_frequency != -1) and ((image_count % checkpoint_frequency) == 0):
+ print('Writing a new checkpoint after having processed {} images since '
+ 'last restart'.format(image_count))
+ _write_checkpoint(checkpoint_path, results)
+
+ else:
+
+ # Use non-batch processing
+ for im_file in tqdm(images_to_process):
+
+ image_count += 1
+
+ result = _process_image(im_file,
+ detector,
+ confidence_threshold,
+ quiet=quiet,
+ image_size=image_size,
+ include_image_size=include_image_size,
+ include_image_timestamp=include_image_timestamp,
+ include_exif_data=include_exif_data,
+ augment=augment)
+ results.append(result)
+
+ # Write a checkpoint if necessary
+ if (checkpoint_frequency != -1) and ((image_count % checkpoint_frequency) == 0):
+ print('Writing a new checkpoint after having processed {} images since '
+ 'last restart'.format(image_count))
+ _write_checkpoint(checkpoint_path, results)
+
+ # ...if the batch size is > 1
 
  else:
 
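
The seed dance in the test-only shuffle above is a standard reproducibility trick: seed from the clock, draw a fresh seed, print it, then re-seed with it, so a failing shuffled run can be replayed exactly by hard-coding the printed seed. In isolation:

import random
import time

random.seed(int(time.time()))           # Fresh entropy on every run
debug_seed = random.randint(0, 2**31 - 1)
print('Debug seed: {}'.format(debug_seed))
random.seed(debug_seed)                 # From here on, the run is reproducible...
items = list(range(10))
random.shuffle(items)                   # ...by re-running with random.seed(debug_seed)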
@@ -910,7 +1279,7 @@
  len(already_processed),n_images_all))
 
  # Divide images into chunks; we'll send one chunk to each worker process
- image_batches = list(_chunks_by_number_of_chunks(image_file_names, n_cores))
+ image_chunks = list(_chunks_by_number_of_chunks(image_file_names, n_cores))
 
  pool = None
  try:
@@ -930,7 +1299,7 @@
  checkpoint_queue, results), daemon=True)
  checkpoint_thread.start()
 
- pool.map(partial(process_images,
+ pool.map(partial(_process_images,
  detector=detector,
  confidence_threshold=confidence_threshold,
  use_image_queue=False,
@@ -942,7 +1311,7 @@
  include_exif_data=include_exif_data,
  augment=augment,
  detector_options=detector_options),
- image_batches)
+ image_chunks)
 
  checkpoint_queue.put(None)
 
@@ -950,7 +1319,7 @@
 
  # Multiprocessing is enabled, but checkpointing is not
 
- new_results = pool.map(partial(process_images,
+ new_results = pool.map(partial(_process_images,
  detector=detector,
  confidence_threshold=confidence_threshold,
  use_image_queue=False,
@@ -962,7 +1331,7 @@
  include_exif_data=include_exif_data,
  augment=augment,
  detector_options=detector_options),
- image_batches)
+ image_chunks)
 
  new_results = list(itertools.chain.from_iterable(new_results))
 
@@ -1066,7 +1435,7 @@ def write_results_to_file(results,
  """
  Writes list of detection results to JSON output file. Format matches:
 
- https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#batch-processing-api-output-format
+ https://lila.science/megadetector-output-format
 
  Args:
  results (list): list of dict, each dict represents detections on one image
@@ -1109,7 +1478,7 @@
 
  info = {
  'detection_completion_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
- 'format_version': '1.4'
+ 'format_version': '1.5'
  }
 
  if detector_file is not None:
@@ -1144,9 +1513,16 @@
 
  # Sort detections in descending order by confidence; not required by the format, but
  # convenient for consistency
- for r in results:
- if ('detections' in r) and (r['detections'] is not None):
- r['detections'] = sort_list_of_dicts_by_key(r['detections'], 'conf', reverse=True)
+ for im in results:
+ if ('detections' in im) and (im['detections'] is not None):
+ im['detections'] = sort_list_of_dicts_by_key(im['detections'], 'conf', reverse=True)
+
+ for im in results:
+ if 'failure' in im:
+ if 'detections' in im:
+ assert im['detections'] is None, 'Illegal failure/detection combination'
+ else:
+ im['detections'] = None
 
  final_output = {
  'images': results,
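
The normalization above guarantees that every record in the output's 'images' list carries a 'detections' key: a (possibly empty) list on success, or None alongside a 'failure' string. A sketch of the two cases, with illustrative field values (see the format link above for the authoritative spec):

example_images = [
    {
        'file': 'camera01/IMG_0001.JPG',
        'detections': [
            # Normalized [x_min, y_min, width, height], sorted by descending confidence
            {'category': '1', 'conf': 0.92, 'bbox': [0.11, 0.24, 0.30, 0.42]}
        ]
    },
    {
        'file': 'camera01/IMG_0002.JPG',
        'failure': 'Failure image access',  # Illustrative failure string
        'detections': None
    }
]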
@@ -1414,6 +1790,11 @@ def main(): # noqa
  metavar='KEY=VALUE',
  default='',
  help='Detector-specific options, as a space-separated list of key-value pairs')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=1,
+ help='Batch size for GPU inference (default 1). CPU inference will ignore this and use batch_size=1.')
 
  if len(sys.argv[1:]) == 0:
  parser.print_help()
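
Assuming the script's usual positional arguments (detector file, image folder or list, output file), batch inference would be enabled with something like "python run_detector_batch.py MDV5A my_images output.json --batch_size 8". The same option flows through the Python API; a hedged sketch, with keyword names taken from the signatures visible in this diff:

from megadetector.detection.run_detector_batch import \
    load_and_run_detector_batch, write_results_to_file

# batch_size is silently forced to 1 when no GPU is available
results = load_and_run_detector_batch(model_file='MDV5A',
                                      image_file_names=['img1.jpg', 'img2.jpg'],
                                      confidence_threshold=0.005,
                                      batch_size=8)

write_results_to_file(results, 'output.json', detector_file='MDV5A')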
@@ -1660,7 +2041,8 @@ def main(): # noqa
  force_model_download=False,
  detector_options=detector_options,
  loader_workers=args.loader_workers,
- preprocess_on_image_queue=args.preprocess_on_image_queue)
+ preprocess_on_image_queue=args.preprocess_on_image_queue,
+ batch_size=args.batch_size)
 
  elapsed = time.time() - start_time
  images_per_second = len(results) / elapsed