megadetector 10.0.9__py3-none-any.whl → 10.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (84)
  1. megadetector/data_management/animl_to_md.py +5 -2
  2. megadetector/data_management/cct_json_utils.py +4 -2
  3. megadetector/data_management/cct_to_md.py +5 -4
  4. megadetector/data_management/cct_to_wi.py +5 -1
  5. megadetector/data_management/coco_to_yolo.py +3 -2
  6. megadetector/data_management/databases/combine_coco_camera_traps_files.py +4 -4
  7. megadetector/data_management/databases/integrity_check_json_db.py +2 -2
  8. megadetector/data_management/databases/subset_json_db.py +0 -3
  9. megadetector/data_management/generate_crops_from_cct.py +6 -4
  10. megadetector/data_management/get_image_sizes.py +5 -35
  11. megadetector/data_management/labelme_to_coco.py +10 -6
  12. megadetector/data_management/labelme_to_yolo.py +19 -28
  13. megadetector/data_management/lila/create_lila_test_set.py +22 -2
  14. megadetector/data_management/lila/generate_lila_per_image_labels.py +7 -5
  15. megadetector/data_management/lila/lila_common.py +2 -2
  16. megadetector/data_management/lila/test_lila_metadata_urls.py +0 -1
  17. megadetector/data_management/ocr_tools.py +6 -10
  18. megadetector/data_management/read_exif.py +69 -13
  19. megadetector/data_management/remap_coco_categories.py +1 -1
  20. megadetector/data_management/remove_exif.py +10 -5
  21. megadetector/data_management/rename_images.py +20 -13
  22. megadetector/data_management/resize_coco_dataset.py +10 -4
  23. megadetector/data_management/speciesnet_to_md.py +3 -3
  24. megadetector/data_management/yolo_output_to_md_output.py +3 -1
  25. megadetector/data_management/yolo_to_coco.py +28 -19
  26. megadetector/detection/change_detection.py +26 -18
  27. megadetector/detection/process_video.py +1 -1
  28. megadetector/detection/pytorch_detector.py +5 -5
  29. megadetector/detection/run_detector.py +34 -10
  30. megadetector/detection/run_detector_batch.py +60 -42
  31. megadetector/detection/run_inference_with_yolov5_val.py +3 -1
  32. megadetector/detection/run_md_and_speciesnet.py +282 -110
  33. megadetector/detection/run_tiled_inference.py +7 -7
  34. megadetector/detection/tf_detector.py +4 -6
  35. megadetector/detection/video_utils.py +9 -6
  36. megadetector/postprocessing/add_max_conf.py +4 -4
  37. megadetector/postprocessing/categorize_detections_by_size.py +3 -2
  38. megadetector/postprocessing/classification_postprocessing.py +19 -21
  39. megadetector/postprocessing/combine_batch_outputs.py +3 -2
  40. megadetector/postprocessing/compare_batch_results.py +49 -27
  41. megadetector/postprocessing/convert_output_format.py +8 -6
  42. megadetector/postprocessing/create_crop_folder.py +13 -4
  43. megadetector/postprocessing/generate_csv_report.py +22 -8
  44. megadetector/postprocessing/load_api_results.py +8 -4
  45. megadetector/postprocessing/md_to_coco.py +2 -3
  46. megadetector/postprocessing/md_to_labelme.py +12 -8
  47. megadetector/postprocessing/md_to_wi.py +2 -1
  48. megadetector/postprocessing/merge_detections.py +4 -6
  49. megadetector/postprocessing/postprocess_batch_results.py +4 -3
  50. megadetector/postprocessing/remap_detection_categories.py +6 -3
  51. megadetector/postprocessing/render_detection_confusion_matrix.py +18 -10
  52. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  53. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +5 -3
  54. megadetector/postprocessing/separate_detections_into_folders.py +10 -4
  55. megadetector/postprocessing/subset_json_detector_output.py +1 -1
  56. megadetector/postprocessing/top_folders_to_bottom.py +22 -7
  57. megadetector/postprocessing/validate_batch_results.py +1 -1
  58. megadetector/taxonomy_mapping/map_new_lila_datasets.py +59 -3
  59. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +1 -1
  60. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +26 -17
  61. megadetector/taxonomy_mapping/species_lookup.py +51 -2
  62. megadetector/utils/ct_utils.py +9 -4
  63. megadetector/utils/directory_listing.py +3 -0
  64. megadetector/utils/extract_frames_from_video.py +4 -0
  65. megadetector/utils/gpu_test.py +6 -6
  66. megadetector/utils/md_tests.py +21 -21
  67. megadetector/utils/path_utils.py +171 -36
  68. megadetector/utils/split_locations_into_train_val.py +0 -4
  69. megadetector/utils/string_utils.py +21 -0
  70. megadetector/utils/url_utils.py +5 -3
  71. megadetector/utils/wi_platform_utils.py +168 -24
  72. megadetector/utils/wi_taxonomy_utils.py +38 -8
  73. megadetector/utils/write_html_image_list.py +1 -2
  74. megadetector/visualization/plot_utils.py +31 -19
  75. megadetector/visualization/render_images_with_thumbnails.py +3 -0
  76. megadetector/visualization/visualization_utils.py +18 -7
  77. megadetector/visualization/visualize_db.py +9 -26
  78. megadetector/visualization/visualize_detector_output.py +1 -0
  79. megadetector/visualization/visualize_video_output.py +14 -2
  80. {megadetector-10.0.9.dist-info → megadetector-10.0.11.dist-info}/METADATA +1 -1
  81. {megadetector-10.0.9.dist-info → megadetector-10.0.11.dist-info}/RECORD +84 -84
  82. {megadetector-10.0.9.dist-info → megadetector-10.0.11.dist-info}/WHEEL +0 -0
  83. {megadetector-10.0.9.dist-info → megadetector-10.0.11.dist-info}/licenses/LICENSE +0 -0
  84. {megadetector-10.0.9.dist-info → megadetector-10.0.11.dist-info}/top_level.txt +0 -0
@@ -24,10 +24,14 @@ from multiprocessing.pool import Pool, ThreadPool
 from functools import partial
 
 from megadetector.utils.path_utils import insert_before_extension
+from megadetector.utils.path_utils import path_join
+
 from megadetector.utils.ct_utils import split_list_into_n_chunks
 from megadetector.utils.ct_utils import invert_dictionary
 from megadetector.utils.ct_utils import compare_values_nan_equal
 
+from megadetector.utils.string_utils import is_int
+
 from megadetector.utils.wi_taxonomy_utils import is_valid_prediction_string
 from megadetector.utils.wi_taxonomy_utils import no_cv_result_prediction_string
 from megadetector.utils.wi_taxonomy_utils import blank_prediction_string
@@ -68,7 +72,7 @@ def read_sequences_from_download_bundle(download_folder):
     assert len(sequence_list_files) == 1, \
         'Could not find sequences.csv in {}'.format(download_folder)
 
-    sequence_list_file = os.path.join(download_folder,sequence_list_files[0])
+    sequence_list_file = path_join(download_folder,sequence_list_files[0])
 
     df = pd.read_csv(sequence_list_file)
     sequence_records = df.to_dict('records')
@@ -104,7 +108,7 @@ def read_images_from_download_bundle(download_folder):
     image_list_files = \
         [fn for fn in image_list_files if fn.startswith('images_') and fn.endswith('.csv')]
     image_list_files = \
-        [os.path.join(download_folder,fn) for fn in image_list_files]
+        [path_join(download_folder,fn) for fn in image_list_files]
     print('Found {} image list files'.format(len(image_list_files)))
 
 
@@ -118,7 +122,7 @@ def read_images_from_download_bundle(download_folder):
         print('Reading images from list file {}'.format(
             os.path.basename(image_list_file)))
 
-        df = pd.read_csv(image_list_file)
+        df = pd.read_csv(image_list_file,low_memory=False)
 
         # i_row = 0; row = df.iloc[i_row]
         for i_row,row in tqdm(df.iterrows(),total=len(df)):
@@ -203,11 +207,91 @@ def find_images_in_identify_tab(download_folder_with_identify,download_folder_ex
 # ...def find_images_in_identify_tab(...)
 
 
-def write_download_commands(image_records_to_download,
+def write_prefix_download_command(image_records,
+                                  download_dir_base,
+                                  force_download=False,
+                                  download_command_file=None):
+    """
+    Write a .sh script to download all images (using gcloud) from the longest common URL
+    prefix in the images represented in [image_records].
+
+    Args:
+        image_records (list of dict): list of dicts with at least the field 'location'.
+            Can also be a dict whose values are lists of record dicts.
+        download_dir_base (str): local destination folder
+        force_download (bool, optional): overwrite existing files
+        download_command_file (str, optional): path of the .sh script we should write, defaults
+            to "download_wi_images_with_prefix.sh" in the destination folder.
+    """
+
+    ##%% Input validation
+
+    # If a dict is provided, assume it maps image GUIDs to lists of records, flatten to a list
+    if isinstance(image_records,dict):
+        all_image_records = []
+        for k in image_records:
+            records_this_image = image_records[k]
+            all_image_records.extend(records_this_image)
+        image_records = all_image_records
+
+    assert isinstance(image_records,list), \
+        'Illegal image record list format {}'.format(type(image_records))
+    assert isinstance(image_records[0],dict), \
+        'Illegal image record format {}'.format(type(image_records[0]))
+
+    urls = [r['location'] for r in image_records]
+
+    # "urls" is a list of URLs starting with gs://. Find the highest-level folder
+    # that is common to all URLs in the list. For example, if the list is:
+    #
+    # gs://a/b/c
+    # gs://a/b/d
+    #
+    # The result should be:
+    #
+    # gs://a/b
+    common_prefix = os.path.commonprefix(urls)
+
+    # Remove the gs:// prefix if it's still there
+    if common_prefix.startswith('gs://'):
+        common_prefix = common_prefix[len('gs://'):]
+
+    # Ensure the common prefix ends with a '/' if it's not empty
+    if (len(common_prefix) > 0) and (not common_prefix.endswith('/')):
+        common_prefix = os.path.dirname(common_prefix) + '/'
+
+    print('Longest common prefix: {}'.format(common_prefix))
+
+    if download_command_file is None:
+        download_command_file = \
+            path_join(download_dir_base,'download_wi_images_with_prefix.sh')
+
+    os.makedirs(download_dir_base,exist_ok=True)
+
+    with open(download_command_file,'w',newline='\n') as f:
+        # The --no-clobber flag prevents overwriting existing files
+        # The -r flag is for recursive download
+        # The gs:// prefix is added back for the gcloud command
+        no_clobber_string = ''
+        if not force_download:
+            no_clobber_string = '--no-clobber'
+
+        cmd = 'gcloud storage cp -r {} "gs://{}" "{}"'.format(
+            no_clobber_string,common_prefix,download_dir_base)
+        print('Writing download command:\n{}'.format(cmd))
+        f.write(cmd + '\n')
+
+    print('Download script written to {}'.format(download_command_file))
+
+# ...def write_prefix_download_command(...)
+
+
+def write_download_commands(image_records,
                             download_dir_base,
                             force_download=False,
                             n_download_workers=25,
-                            download_command_file_base=None):
+                            download_command_file_base=None,
+                            image_flattening='deployment'):
     """
     Given a list of dicts with at least the field 'location' (a gs:// URL), prepare a set of "gcloud
     storage" commands to download images, and write those to a series of .sh scripts, along with one
@@ -215,10 +299,9 @@ def write_download_commands(image_records_to_download,
 
     gcloud commands will use relative paths.
 
-    image_records_to_download can also be a dict mapping IDs to lists of records.
-
     Args:
-        image_records_to_download (list of dict): list of dicts with at least the field 'location'
+        image_records (list of dict): list of dicts with at least the field 'location'.
+            Can also be a dict whose values are lists of record dicts.
         download_dir_base (str): local destination folder
         force_download (bool, optional): include gs commands even if the target file exists
         n_download_workers (int, optional): number of scripts to write (that's our hacky way
@@ -226,42 +309,103 @@ def write_download_commands(image_records_to_download,
         download_command_file_base (str, optional): path of the .sh script we should write, defaults
             to "download_wi_images.sh" in the destination folder. Individual worker scripts will
             have a number added, e.g. download_wi_images_00.sh.
+        image_flattening (str, optional): if 'none', relative paths will be preserved
+            representing the entire URL for each image. Can be 'guid' (just download to
+            [GUID].JPG) or 'deployment' (download to [deployment]/[GUID].JPG).
     """
 
-    if isinstance(image_records_to_download,dict):
+    ##%% Input validation
 
+    # If a dict is provided, assume it maps image GUIDs to lists of records, flatten to a list
+    if isinstance(image_records,dict):
         all_image_records = []
-        for k in image_records_to_download:
-            records_this_image = image_records_to_download[k]
+        for k in image_records:
+            records_this_image = image_records[k]
             all_image_records.extend(records_this_image)
-        return write_download_commands(all_image_records,
-                                       download_dir_base=download_dir_base,
-                                       force_download=force_download,
-                                       n_download_workers=n_download_workers,
-                                       download_command_file_base=download_command_file_base)
+        image_records = all_image_records
+
+    assert isinstance(image_records,list), \
+        'Illegal image record list format {}'.format(type(image_records))
+    assert isinstance(image_records[0],dict), \
+        'Illegal image record format {}'.format(type(image_records[0]))
+
+
+    ##%% Map URLs to relative paths
+
+    # URLs look like:
+    #
+    # gs://145625555_2004881_2323_name__main/deployment/2241000/prod/directUpload/5fda0ddd-511e-46ca-95c1-302b3c71f8ea.JPG
+    if image_flattening is None:
+        image_flattening = 'none'
+    image_flattening = image_flattening.lower().strip()
+
+    assert image_flattening in ('none','guid','deployment'), \
+        'Illegal image flattening strategy {}'.format(image_flattening)
+
+    url_to_relative_path = {}
+
+    for image_record in image_records:
+
+        url = image_record['location']
+        assert url.startswith('gs://'), 'Illegal URL {}'.format(url)
+
+        relative_path = None
+
+        if image_flattening == 'none':
+            relative_path = url.replace('gs://','')
+        elif image_flattening == 'guid':
+            relative_path = url.split('/')[-1]
+        else:
+            assert image_flattening == 'deployment'
+            tokens = url.split('/')
+            found_deployment_id = False
+            for i_token,token in enumerate(tokens):
+                if token == 'deployment':
+                    assert i_token < (len(tokens)-1)
+                    deployment_id_string = tokens[i_token + 1]
+                    deployment_id_string = deployment_id_string.replace('_thumb','')
+                    assert is_int(deployment_id_string), \
+                        'Illegal deployment ID {}'.format(deployment_id_string)
+                    image_id = url.split('/')[-1]
+                    relative_path = deployment_id_string + '/' + image_id
+                    found_deployment_id = True
+                    break
+            assert found_deployment_id, \
+                'Could not find deployment ID in record {}'.format(str(image_record))
+
+        assert relative_path is not None
+
+        if url in url_to_relative_path:
+            assert url_to_relative_path[url] == relative_path, \
+                'URL path mapping error'
+        else:
+            url_to_relative_path[url] = relative_path
+
+    # ...for each image record
+
 
     ##%% Make list of gcloud storage commands
 
     if download_command_file_base is None:
-        download_command_file_base = os.path.join(download_dir_base,'download_wi_images.sh')
+        download_command_file_base = path_join(download_dir_base,'download_wi_images.sh')
 
     commands = []
     skipped_urls = []
     downloaded_urls = set()
 
-    # image_record = image_records_to_download[0]
-    for image_record in tqdm(image_records_to_download):
+    # image_record = image_records[0]
+    for image_record in tqdm(image_records):
 
         url = image_record['location']
         if url in downloaded_urls:
             continue
 
-        assert url.startswith('gs://')
+        assert url.startswith('gs://'), 'Illegal URL {}'.format(url)
 
-        relative_path = url.replace('gs://','')
-        abs_path = os.path.join(download_dir_base,relative_path)
+        relative_path = url_to_relative_path[url]
+        abs_path = path_join(download_dir_base,relative_path)
 
-        # Skip files that already exist
+        # Optionally skip files that already exist
         if (not force_download) and (os.path.isfile(abs_path)):
             skipped_urls.append(url)
             continue
@@ -271,7 +415,7 @@ def write_download_commands(image_records_to_download,
         commands.append(command)
 
     print('Generated {} commands for {} image records'.format(
-        len(commands),len(image_records_to_download)))
+        len(commands),len(image_records)))
 
     print('Skipped {} URLs'.format(len(skipped_urls)))
 
@@ -311,7 +311,8 @@ def taxonomy_info_to_taxonomy_string(taxonomy_info, include_taxon_id_and_common_
 def generate_whole_image_detections_for_classifications(classifications_json_file,
                                                          detections_json_file,
                                                          ensemble_json_file=None,
-                                                         ignore_blank_classifications=True):
+                                                         ignore_blank_classifications=True,
+                                                         verbose=True):
     """
     Given a set of classification results in SpeciesNet format that were likely run on
     already-cropped images, generate a file of [fake] detections in SpeciesNet format in which each
@@ -324,6 +325,7 @@ def generate_whole_image_detections_for_classifications(classifications_json_fil
             and classfications
         ignore_blank_classifications (bool, optional): use non-top classifications when
             the top classification is "blank" or "no CV result"
+        verbose (bool, optional): enable additional debug output
 
     Returns:
         dict: the contents of [detections_json_file]
@@ -336,16 +338,37 @@ def generate_whole_image_detections_for_classifications(classifications_json_fil
     output_predictions = []
     ensemble_predictions = []
 
-    # prediction = predictions[0]
-    for prediction in predictions:
+    # i_prediction = 0; prediction = predictions[i_prediction]
+    for i_prediction,prediction in enumerate(predictions):
 
         output_prediction = {}
         output_prediction['filepath'] = prediction['filepath']
         i_score = 0
+
         if ignore_blank_classifications:
+
             while (prediction['classifications']['classes'][i_score] in \
                 (blank_prediction_string,no_cv_result_prediction_string)):
+
                 i_score += 1
+                if (i_score >= len(prediction['classifications']['classes'])):
+
+                    if verbose:
+
+                        print('Ignoring blank classifications, but ' + \
+                              'image {} has no non-blank values'.format(
+                              i_prediction))
+
+                    # Just use the first one
+                    i_score = 0
+                    break
+
+                # ...if we passed the last prediction
+
+            # ...iterate over classes within this prediction
+
+        # ...if we're supposed to ignore blank classifications
+
         top_classification = prediction['classifications']['classes'][i_score]
         top_classification_score = prediction['classifications']['scores'][i_score]
         if is_animal_classification(top_classification):
@@ -450,8 +473,8 @@ def generate_md_results_from_predictions_json(predictions_json_file,
 
     # Round floating-point values (confidence scores, coordinates) to a
     # reasonable number of decimal places
-    if max_decimals is not None and max_decimals > 0:
-        round_floats_in_nested_dict(predictions)
+    if (max_decimals is not None) and (max_decimals > 0):
+        round_floats_in_nested_dict(predictions, decimal_places=max_decimals)
 
     predictions = predictions['predictions']
     assert isinstance(predictions,list)
@@ -714,7 +737,9 @@ def generate_predictions_json_from_md_results(md_results_file,
 
     # ...for each image
 
-    os.makedirs(os.path.dirname(predictions_json_file),exist_ok=True)
+    output_dir = os.path.dirname(predictions_json_file)
+    if len(output_dir) > 0:
+        os.makedirs(output_dir,exist_ok=True)
     with open(predictions_json_file,'w') as f:
         json.dump(output_dict,f,indent=1)
 
@@ -754,6 +779,7 @@ def generate_instances_json_from_folder(folder,
 
     assert os.path.isdir(folder)
 
+    print('Enumerating images in {}'.format(folder))
     image_files_abs = find_images(folder,recursive=True,return_relative_paths=False)
 
     if tokens_to_ignore is not None:
@@ -787,7 +813,9 @@ def generate_instances_json_from_folder(folder,
     to_return = {'instances':instances}
 
     if output_file is not None:
-        os.makedirs(os.path.dirname(output_file),exist_ok=True)
+        output_dir = os.path.dirname(output_file)
+        if len(output_dir) > 0:
+            os.makedirs(output_dir,exist_ok=True)
         with open(output_file,'w') as f:
             json.dump(to_return,f,indent=1)
 
@@ -869,7 +897,9 @@ def merge_prediction_json_files(input_prediction_files,output_prediction_file):
 
     output_dict = {'predictions':predictions}
 
-    os.makedirs(os.path.dirname(output_prediction_file),exist_ok=True)
+    output_dir = os.path.dirname(output_prediction_file)
+    if len(output_dir) > 0:
+        os.makedirs(output_dir,exist_ok=True)
     with open(output_prediction_file,'w') as f:
         json.dump(output_dict,f,indent=1)
 
@@ -110,7 +110,6 @@ def write_html_image_list(filename=None,images=None,options=None):
         if 'linkTarget' not in image_info:
             image_info['linkTarget'] = ''
         if 'textStyle' not in image_info:
-            text_style = options['defaultTextStyle']
             image_info['textStyle'] = options['defaultTextStyle']
         images[i_image] = image_info
 
@@ -185,7 +184,7 @@ def write_html_image_list(filename=None,images=None,options=None):
     if len(options['pageTitle']) > 0:
         title_string = '<title>{}</title>'.format(options['pageTitle'])
 
-    f_html.write('<html>{}<body>\n'.format(title_string))
+    f_html.write('<html><head>{}</head><body>\n'.format(title_string))
 
     f_html.write(options['headerHtml'])
 
@@ -126,19 +126,16 @@ def plot_precision_recall_curve(precisions,
     ax.step(recalls, precisions, color='b', alpha=0.2, where='post')
     ax.fill_between(recalls, precisions, alpha=0.2, color='b', step='post')
 
-    try:
-        ax.set(x_label='Recall', y_label='Precision', title=title)
-        ax.set(x_lim=xlim, y_lim=ylim)
-    #
-    except Exception:
-        ax.set_xlabel('Recall')
-        ax.set_ylabel('Precision')
-        ax.set_title(title)
-        ax.set_xlim(xlim[0],xlim[1])
-        ax.set_ylim(ylim[0],ylim[1])
+    ax.set_xlabel('Recall')
+    ax.set_ylabel('Precision')
+    ax.set_title(title)
+    ax.set_xlim(xlim[0],xlim[1])
+    ax.set_ylim(ylim[0],ylim[1])
 
     return fig
 
+# ...def plot_precision_recall_curve(...)
+
 
 def plot_stacked_bar_chart(data,
                            series_labels=None,
@@ -174,17 +171,21 @@ def plot_stacked_bar_chart(data,
 
     # stacked bar charts are made with each segment starting from a y position
     cumulative_size = np.zeros(num_columns)
-    for i, row_data in enumerate(data):
-        ax.bar(ind, row_data, bottom=cumulative_size, label=series_labels[i],
-               color=colors[i])
+    for i_row, row_data in enumerate(data):
+        if series_labels is None:
+            label = 'series_{}'.format(str(i_row).zfill(2))
+        else:
+            label = series_labels[i_row]
+        ax.bar(ind, row_data, bottom=cumulative_size, label=label,
+               color=colors[i_row])
         cumulative_size += row_data
 
-    if col_labels and len(col_labels) < 25:
+    if (col_labels is not None) and (len(col_labels) < 25):
         ax.set_xticks(ind)
         ax.set_xticklabels(col_labels, rotation=90)
-    elif col_labels:
+    elif (col_labels is not None):
         ax.set_xticks(list(range(0, len(col_labels), 20)))
-        ax.set_xticklabels(col_labels, rotation=90)
+        ax.set_xticklabels(col_labels[::20], rotation=90)
 
     if x_label is not None:
         ax.set_xlabel(x_label)
@@ -202,6 +203,8 @@ def plot_stacked_bar_chart(data,
 
     return fig
 
+# ...def plot_stacked_bar_chart(...)
+
 
 def calibration_ece(true_scores, pred_scores, num_bins):
     r"""
@@ -245,10 +248,17 @@ def calibration_ece(true_scores, pred_scores, num_bins):
     ece = np.abs(accs - confs) @ weights
     return accs, confs, ece
 
+# ...def calibration_ece(...)
+
 
-def plot_calibration_curve(true_scores, pred_scores, num_bins,
-                           name='calibration', plot_perf=True, plot_hist=True,
-                           ax=None, **fig_kwargs):
+def plot_calibration_curve(true_scores,
+                           pred_scores,
+                           num_bins,
+                           name='calibration',
+                           plot_perf=True,
+                           plot_hist=True,
+                           ax=None,
+                           **fig_kwargs):
     """
     Plots a calibration curve.
 
@@ -295,3 +305,5 @@ def plot_calibration_curve(true_scores, pred_scores, num_bins,
     fig.legend(loc='upper left', bbox_to_anchor=(0.15, 0.85))
 
     return ax.figure
+
+# ...def plot_calibration_curve(...)
@@ -185,6 +185,9 @@ def render_images_with_thumbnails(
     # ...for each crop
 
     # Write output image to disk
+    parent_dir = os.path.dirname(output_image_filename)
+    if len(parent_dir) > 0:
+        os.makedirs(parent_dir,exist_ok=True)
     output_image.save(output_image_filename)
 
 # ...def render_images_with_thumbnails(...)
@@ -1272,7 +1272,7 @@ def gray_scale_fraction(image,crop_size=(0.1,0.1)):
     if crop_size[0] > 0 or crop_size[1] > 0:
 
         assert (crop_size[0] + crop_size[1]) < 1.0, \
-            print('Illegal crop size: {}'.format(str(crop_size)))
+            'Illegal crop size: {}'.format(str(crop_size))
 
         top_crop_pixels = int(image.height * crop_size[0])
         bottom_crop_pixels = int(image.height * crop_size[1])
@@ -1391,7 +1391,9 @@ def _resize_absolute_image(input_output_files,
         status = 'error'
         error = str(e)
 
-    return {'input_fn':input_fn_abs,'output_fn':output_fn_abs,status:'status',
+    return {'input_fn':input_fn_abs,
+            'output_fn':output_fn_abs,
+            'status':status,
             'error':error}
 
 # ..._resize_absolute_image(...)
@@ -1460,6 +1462,7 @@ def resize_images(input_file_to_output_file,
     pool = None
 
     try:
+
         if pool_type == 'thread':
             pool = ThreadPool(n_workers); poolstring = 'threads'
         else:
@@ -1477,10 +1480,13 @@ def resize_images(input_file_to_output_file,
                     quality=quality)
 
         results = list(tqdm(pool.imap(p, input_output_file_pairs),total=len(input_output_file_pairs)))
+
     finally:
-        pool.close()
-        pool.join()
-        print("Pool closed and joined for image resizing")
+
+        if pool is not None:
+            pool.close()
+            pool.join()
+            print('Pool closed and joined for image resizing')
 
     return results
 
@@ -1680,8 +1686,13 @@ def parallel_get_image_sizes(filenames,
     else:
         pool = Pool(n_workers)
 
-    results = list(tqdm(pool.imap(
-        partial(get_image_size,verbose=verbose),filenames), total=len(filenames)))
+    try:
+        results = list(tqdm(pool.imap(
+            partial(get_image_size,verbose=verbose),filenames), total=len(filenames)))
+    finally:
+        pool.close()
+        pool.join()
+        print('Pool closed and joined for image size retrieval')
 
     assert len(filenames) == len(results), 'Internal error in parallel_get_image_sizes'
 
@@ -102,10 +102,6 @@ class DbVizOptions:
         #: :meta private:
         self.multiple_categories_tag = '*multiple*'
 
-        #: We sometimes flatten image directories by replacing a path separator with
-        #: another character. Leave blank for the typical case where this isn't necessary.
-        self.pathsep_replacement = '' # '~'
-
         #: Parallelize rendering across multiple workers
         self.parallelize_rendering = False
 
@@ -141,24 +137,12 @@ class DbVizOptions:
         self.confidence_threshold = None
 
 
-#%% Helper functions
-
-def _image_filename_to_path(image_file_name, image_base_dir, pathsep_replacement=''):
-    """
-    Translates the file name in an image entry in the json database to a path, possibly doing
-    some manipulation of path separators.
-    """
-
-    if len(pathsep_replacement) > 0:
-        image_file_name = os.path.normpath(image_file_name).replace(os.pathsep,pathsep_replacement)
-    return os.path.join(image_base_dir, image_file_name)
-
-
 #%% Core functions
 
 def visualize_db(db_path, output_dir, image_base_dir, options=None):
-    Writes images and html to output_dir to visualize the annotations in a .json file.
+    Writes images and html to output_dir to visualize the images and annotations in a
+    COCO-formatted .json file.
 
     Args:
         db_path (str or dict): the .json filename to load, or a previously-loaded database
@@ -176,9 +160,11 @@ def visualize_db(db_path, output_dir, image_base_dir, options=None):
 
     # Consistency checking for fields with specific format requirements
 
-    # This should be a list, but if someone specifies a string, do a reasonable thing
+    # These should be a lists, but if someone specifies a string, do a reasonable thing
     if isinstance(options.extra_image_fields_to_print,str):
         options.extra_image_fields_to_print = [options.extra_image_fields_to_print]
+    if isinstance(options.extra_annotation_fields_to_print,str):
+        options.extra_annotation_fields_to_print = [options.extra_annotation_fields_to_print]
 
     if not options.parallelize_rendering_with_threads:
         print('Warning: process-based parallelization is not yet supported by visualize_db')
@@ -196,7 +182,7 @@ def visualize_db(db_path, output_dir, image_base_dir, options=None):
         assert(os.path.isfile(db_path))
         print('Loading database from {}...'.format(db_path))
         image_db = json.load(open(db_path))
-        print('...done')
+        print('...done, loaded {} images'.format(len(image_db['images'])))
     elif isinstance(db_path,dict):
         print('Using previously-loaded DB')
         image_db = db_path
@@ -312,8 +298,7 @@ def visualize_db(db_path, output_dir, image_base_dir, options=None):
         if image_base_dir.startswith('http'):
             img_path = image_base_dir + img_relative_path
         else:
-            img_path = os.path.join(image_base_dir,
-                                    _image_filename_to_path(img_relative_path, image_base_dir))
+            img_path = os.path.join(image_base_dir,img_relative_path).replace('\\','/')
 
         annos_i = df_anno.loc[df_anno['image_id'] == img_id, :] # all annotations on this image
 
@@ -407,7 +392,8 @@ def visualize_db(db_path, output_dir, image_base_dir, options=None):
         img_id_string = str(img_id).lower()
         file_name = '{}_gt.jpg'.format(os.path.splitext(img_id_string)[0])
 
-        # Replace characters that muck up image links
+        # Replace characters that muck up image links, including flattening file
+        # separators.
         illegal_characters = ['/','\\',':','\t','#',' ','%']
         for c in illegal_characters:
             file_name = file_name.replace(c,'~')
@@ -625,9 +611,6 @@ def main():
                         help='Only include images with bounding boxes (defaults to false)')
     parser.add_argument('--random_seed', action='store', type=int, default=None,
                         help='Random seed for image selection')
-    parser.add_argument('--pathsep_replacement', action='store', type=str, default='',
-                        help='Replace path separators in relative filenames with another ' + \
-                             'character (frequently ~)')
 
     if len(sys.argv[1:]) == 0:
         parser.print_help()
@@ -428,6 +428,7 @@ def main(): # noqa
                               category_names_to_blur=category_names_to_blur)
 
     if (args.html_output_file is not None) and args.open_html_output_file:
+        print('Opening output file {}'.format(args.html_output_file))
         open_file(args.html_output_file)
 
 if __name__ == '__main__':