megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (147)
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,968 @@
+ """
+
+ wi_platform_utils.py
+
+ Utility functions for working with the Wildlife Insights platform, specifically:
+
+ * Retrieving images based on .csv downloads
+ * Pushing results to the ProcessCVResponse() API (requires an API key)
+
+ """
+
+ #%% Imports
+
+ import os
+ import requests
+
+ import pandas as pd
+ import numpy as np
+
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ from multiprocessing.pool import Pool, ThreadPool
+ from functools import partial
+
+ from megadetector.utils.path_utils import insert_before_extension
+ from megadetector.utils.path_utils import path_join
+
+ from megadetector.utils.ct_utils import split_list_into_n_chunks
+ from megadetector.utils.ct_utils import invert_dictionary
+ from megadetector.utils.ct_utils import compare_values_nan_equal
+
+ from megadetector.utils.string_utils import is_int
+
+ from megadetector.utils.wi_taxonomy_utils import is_valid_prediction_string
+ from megadetector.utils.wi_taxonomy_utils import no_cv_result_prediction_string
+ from megadetector.utils.wi_taxonomy_utils import blank_prediction_string
+
+ from megadetector.detection.run_detector import DEFAULT_DETECTOR_LABEL_MAP
+
+ # Only used when pushing results directly to the platform via the API; any detections we want
+ # to show in the UI should have at least this confidence value.
+ min_md_output_confidence = 0.25
+
+ md_category_id_to_name = DEFAULT_DETECTOR_LABEL_MAP
+ md_category_name_to_id = invert_dictionary(md_category_id_to_name)
+
+ # Fields expected to be present in a valid WI result
+ wi_result_fields = ['wi_taxon_id','class','order','family','genus','species','common_name']
+
+
+ #%% Functions for managing WI downloads
+
+ def read_sequences_from_download_bundle(download_folder):
+     """
+     Reads sequences.csv from [download_folder], returning a list of dicts. This is a
+     thin wrapper around pd.read_csv; it's just here for future-proofing.
+
+     Args:
+         download_folder (str): a folder containing exactly one file called sequences.csv,
+             typically representing a Wildlife Insights download bundle.
+
+     Returns:
+         list of dict: a direct conversion of the .csv file to a list of dicts
+     """
+
+     print('Reading sequences from {}'.format(download_folder))
+
+     sequence_list_files = os.listdir(download_folder)
+     sequence_list_files = \
+         [fn for fn in sequence_list_files if fn == 'sequences.csv']
+     assert len(sequence_list_files) == 1, \
+         'Could not find sequences.csv in {}'.format(download_folder)
+
+     sequence_list_file = path_join(download_folder,sequence_list_files[0])
+
+     df = pd.read_csv(sequence_list_file)
+     sequence_records = df.to_dict('records')
+     return sequence_records
+
+
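A minimal usage sketch, assuming a WI download bundle has been extracted to a hypothetical folder wi-download:

    from megadetector.utils.wi_platform_utils import read_sequences_from_download_bundle

    # 'wi-download' stands in for an extracted WI download bundle folder
    sequence_records = read_sequences_from_download_bundle('wi-download')
    print('Read {} sequence records'.format(len(sequence_records)))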
+ def read_images_from_download_bundle(download_folder):
+     """
+     Reads all images.csv files from [download_folder], returning a dict that maps each image ID
+     to a list of dicts describing that image. It's a list of dicts rather than a single dict
+     because images may appear more than once.
+
+     Args:
+         download_folder (str): a folder containing one or more images.csv files, typically
+             representing a Wildlife Insights download bundle.
+
+     Returns:
+         dict: Maps image GUIDs to dicts with at least the following fields:
+             * project_id (int)
+             * deployment_id (str)
+             * image_id (str, should match the key)
+             * filename (str, the filename without path at the time of upload)
+             * location (str, starting with gs://)
+
+             May also contain classification fields: wi_taxon_id (str), species, etc.
+     """
+
+     print('Reading images from {}'.format(download_folder))
+
+     ##%% Find lists of images
+
+     image_list_files = os.listdir(download_folder)
+     image_list_files = \
+         [fn for fn in image_list_files if fn.startswith('images_') and fn.endswith('.csv')]
+     image_list_files = \
+         [path_join(download_folder,fn) for fn in image_list_files]
+     print('Found {} image list files'.format(len(image_list_files)))
+
+
+     ##%% Read lists of images by deployment
+
+     image_id_to_image_records = defaultdict(list)
+
+     # image_list_file = image_list_files[0]
+     for image_list_file in image_list_files:
+
+         print('Reading images from list file {}'.format(
+             os.path.basename(image_list_file)))
+
+         df = pd.read_csv(image_list_file,low_memory=False)
+
+         # i_row = 0; row = df.iloc[i_row]
+         for i_row,row in tqdm(df.iterrows(),total=len(df)):
+
+             row_dict = row.to_dict()
+             image_id = row_dict['image_id']
+             image_id_to_image_records[image_id].append(row_dict)
+
+         # ...for each image
+
+     # ...for each list file
+
+     deployment_ids = set()
+     for image_id in image_id_to_image_records:
+         image_records = image_id_to_image_records[image_id]
+         for image_record in image_records:
+             deployment_ids.add(image_record['deployment_id'])
+
+     print('Found {} rows in {} deployments'.format(
+         len(image_id_to_image_records),
+         len(deployment_ids)))
+
+     return image_id_to_image_records
+
+
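A minimal usage sketch, again with a hypothetical bundle folder; each value in the returned dict is a list, since an image can appear in more than one images_*.csv row:

    from megadetector.utils.wi_platform_utils import read_images_from_download_bundle

    image_id_to_records = read_images_from_download_bundle('wi-download')

    # Each image ID maps to one or more record dicts
    first_image_id = next(iter(image_id_to_records))
    print(image_id_to_records[first_image_id][0]['location'])  # a gs:// URL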
+ def find_images_in_identify_tab(download_folder_with_identify,download_folder_excluding_identify):
+     """
+     Based on extracted download packages with and without the "exclude images in 'identify' tab"
+     checkbox checked, figure out which images are in the identify tab. Returns a list of dicts
+     (one per image).
+
+     Args:
+         download_folder_with_identify (str): the folder containing the download bundle that
+             includes images from the "identify" tab
+         download_folder_excluding_identify (str): the folder containing the download bundle that
+             excludes images from the "identify" tab
+
+     Returns:
+         list of dict: list of image records that are present in the identify tab
+     """
+
+     ##%% Read data (~30 seconds)
+
+     image_id_to_image_records_with_identify = \
+         read_images_from_download_bundle(download_folder_with_identify)
+     image_id_to_image_records_excluding_identify = \
+         read_images_from_download_bundle(download_folder_excluding_identify)
+
+
+     ##%% Find images that have not been identified
+
+     all_image_ids_with_identify = set(image_id_to_image_records_with_identify.keys())
+     all_image_ids_excluding_identify = set(image_id_to_image_records_excluding_identify.keys())
+
+     image_ids_in_identify_tab = all_image_ids_with_identify.difference(all_image_ids_excluding_identify)
+
+     assert len(image_ids_in_identify_tab) == \
+         len(all_image_ids_with_identify) - len(all_image_ids_excluding_identify)
+
+     print('Found {} images with identify, {} in identify tab, {} excluding'.format(
+         len(all_image_ids_with_identify),
+         len(image_ids_in_identify_tab),
+         len(all_image_ids_excluding_identify)))
+
+     image_records_in_identify_tab = []
+     deployment_ids_for_downloaded_images = set()
+
+     for image_id in image_ids_in_identify_tab:
+         image_records_this_image = image_id_to_image_records_with_identify[image_id]
+         assert len(image_records_this_image) > 0
+         image_records_in_identify_tab.extend(image_records_this_image)
+         for image_record in image_records_this_image:
+             deployment_ids_for_downloaded_images.add(image_record['deployment_id'])
+
+     print('Found {} records for {} unique images in {} deployments'.format(
+         len(image_records_in_identify_tab),
+         len(image_ids_in_identify_tab),
+         len(deployment_ids_for_downloaded_images)))
+
+     return image_records_in_identify_tab
+
+ # ...def find_images_in_identify_tab(...)
+
+
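A usage sketch, assuming two hypothetical bundle folders exported with and without the "exclude images in 'identify' tab" box checked:

    from megadetector.utils.wi_platform_utils import find_images_in_identify_tab

    identify_records = find_images_in_identify_tab('bundle-with-identify',
                                                   'bundle-excluding-identify')
    print('{} image records are awaiting identification'.format(len(identify_records)))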
+ def write_prefix_download_command(image_records,
+                                   download_dir_base,
+                                   force_download=False,
+                                   download_command_file=None):
+     """
+     Write a .sh script to download all images (using gcloud) from the longest common URL
+     prefix in the images represented in [image_records].
+
+     Args:
+         image_records (list of dict): list of dicts with at least the field 'location'.
+             Can also be a dict whose values are lists of record dicts.
+         download_dir_base (str): local destination folder
+         force_download (bool, optional): overwrite existing files
+         download_command_file (str, optional): path of the .sh script we should write, defaults
+             to "download_wi_images_with_prefix.sh" in the destination folder.
+     """
+
+     ##%% Input validation
+
+     # If a dict is provided, assume it maps image GUIDs to lists of records, flatten to a list
+     if isinstance(image_records,dict):
+         all_image_records = []
+         for k in image_records:
+             records_this_image = image_records[k]
+             all_image_records.extend(records_this_image)
+         image_records = all_image_records
+
+     assert isinstance(image_records,list), \
+         'Illegal image record list format {}'.format(type(image_records))
+     assert isinstance(image_records[0],dict), \
+         'Illegal image record format {}'.format(type(image_records[0]))
+
+     urls = [r['location'] for r in image_records]
+
+     # "urls" is a list of URLs starting with gs://. Find the highest-level folder
+     # that is common to all URLs in the list. For example, if the list is:
+     #
+     # gs://a/b/c
+     # gs://a/b/d
+     #
+     # The result should be:
+     #
+     # gs://a/b
+     common_prefix = os.path.commonprefix(urls)
+
+     # Remove the gs:// prefix if it's still there
+     if common_prefix.startswith('gs://'):
+         common_prefix = common_prefix[len('gs://'):]
+
+     # Ensure the common prefix ends with a '/' if it's not empty
+     if (len(common_prefix) > 0) and (not common_prefix.endswith('/')):
+         common_prefix = os.path.dirname(common_prefix) + '/'
+
+     print('Longest common prefix: {}'.format(common_prefix))
+
+     if download_command_file is None:
+         download_command_file = \
+             path_join(download_dir_base,'download_wi_images_with_prefix.sh')
+
+     os.makedirs(download_dir_base,exist_ok=True)
+
+     with open(download_command_file,'w',newline='\n') as f:
+         # The --no-clobber flag prevents overwriting existing files
+         # The -r flag is for recursive download
+         # The gs:// prefix is added back for the gcloud command
+         no_clobber_string = ''
+         if not force_download:
+             no_clobber_string = '--no-clobber'
+
+         cmd = 'gcloud storage cp -r {} "gs://{}" "{}"'.format(
+             no_clobber_string,common_prefix,download_dir_base)
+         print('Writing download command:\n{}'.format(cmd))
+         f.write(cmd + '\n')
+
+     print('Download script written to {}'.format(download_command_file))
+
+ # ...def write_prefix_download_command(...)
+
+
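A sketch of the prefix logic above, with two hypothetical URLs; os.path.commonprefix() works character by character, which is why the function trims the raw prefix back to the last complete folder:

    import os

    urls = ['gs://bucket/deployment/100/a.JPG',
            'gs://bucket/deployment/101/b.JPG']

    # The character-level prefix ends mid-token: 'gs://bucket/deployment/10'
    common_prefix = os.path.commonprefix(urls)

    # Trim back to the last complete folder, as the function above does
    if not common_prefix.endswith('/'):
        common_prefix = os.path.dirname(common_prefix) + '/'

    print(common_prefix)  # gs://bucket/deployment/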
+ def write_download_commands(image_records,
+                             download_dir_base,
+                             force_download=False,
+                             n_download_workers=25,
+                             download_command_file_base=None,
+                             image_flattening='deployment'):
+     """
+     Given a list of dicts with at least the field 'location' (a gs:// URL), prepare a set of "gcloud
+     storage" commands to download images, and write those to a series of .sh scripts, along with one
+     .sh script that runs all the others and blocks.
+
+     gcloud commands will use relative paths.
+
+     Args:
+         image_records (list of dict): list of dicts with at least the field 'location'.
+             Can also be a dict whose values are lists of record dicts.
+         download_dir_base (str): local destination folder
+         force_download (bool, optional): include download commands even if the target file exists
+         n_download_workers (int, optional): number of scripts to write (that's our hacky way
+             of controlling parallelization)
+         download_command_file_base (str, optional): path of the .sh script we should write, defaults
+             to "download_wi_images.sh" in the destination folder. Individual worker scripts will
+             have a number added, e.g. download_wi_images_00.sh.
+         image_flattening (str, optional): if 'none', relative paths will be preserved,
+             representing the entire URL for each image. Can be 'guid' (just download to
+             [GUID].JPG) or 'deployment' (download to [deployment]/[GUID].JPG).
+     """
+
+     ##%% Input validation
+
+     # If a dict is provided, assume it maps image GUIDs to lists of records, flatten to a list
+     if isinstance(image_records,dict):
+         all_image_records = []
+         for k in image_records:
+             records_this_image = image_records[k]
+             all_image_records.extend(records_this_image)
+         image_records = all_image_records
+
+     assert isinstance(image_records,list), \
+         'Illegal image record list format {}'.format(type(image_records))
+     assert isinstance(image_records[0],dict), \
+         'Illegal image record format {}'.format(type(image_records[0]))
+
+
+     ##%% Map URLs to relative paths
+
+     # URLs look like:
+     #
+     # gs://145625555_2004881_2323_name__main/deployment/2241000/prod/directUpload/5fda0ddd-511e-46ca-95c1-302b3c71f8ea.JPG
+     if image_flattening is None:
+         image_flattening = 'none'
+     image_flattening = image_flattening.lower().strip()
+
+     assert image_flattening in ('none','guid','deployment'), \
+         'Illegal image flattening strategy {}'.format(image_flattening)
+
+     url_to_relative_path = {}
+
+     for image_record in image_records:
+
+         url = image_record['location']
+         assert url.startswith('gs://'), 'Illegal URL {}'.format(url)
+
+         relative_path = None
+
+         if image_flattening == 'none':
+             relative_path = url.replace('gs://','')
+         elif image_flattening == 'guid':
+             relative_path = url.split('/')[-1]
+         else:
+             assert image_flattening == 'deployment'
+             tokens = url.split('/')
+             found_deployment_id = False
+             for i_token,token in enumerate(tokens):
+                 if token == 'deployment':
+                     assert i_token < (len(tokens)-1)
+                     deployment_id_string = tokens[i_token + 1]
+                     deployment_id_string = deployment_id_string.replace('_thumb','')
+                     assert is_int(deployment_id_string), \
+                         'Illegal deployment ID {}'.format(deployment_id_string)
+                     image_id = url.split('/')[-1]
+                     relative_path = deployment_id_string + '/' + image_id
+                     found_deployment_id = True
+                     break
+             assert found_deployment_id, \
+                 'Could not find deployment ID in record {}'.format(str(image_record))
+
+         assert relative_path is not None
+
+         if url in url_to_relative_path:
+             assert url_to_relative_path[url] == relative_path, \
+                 'URL path mapping error'
+         else:
+             url_to_relative_path[url] = relative_path
+
+     # ...for each image record
+
+
+     ##%% Make list of gcloud storage commands
+
+     if download_command_file_base is None:
+         download_command_file_base = path_join(download_dir_base,'download_wi_images.sh')
+
+     commands = []
+     skipped_urls = []
+     downloaded_urls = set()
+
+     # image_record = image_records[0]
+     for image_record in tqdm(image_records):
+
+         url = image_record['location']
+         if url in downloaded_urls:
+             continue
+         # Mark this URL as handled, so duplicate records don't generate duplicate commands
+         downloaded_urls.add(url)
+
+         assert url.startswith('gs://'), 'Illegal URL {}'.format(url)
+
+         relative_path = url_to_relative_path[url]
+         abs_path = path_join(download_dir_base,relative_path)
+
+         # Optionally skip files that already exist
+         if (not force_download) and (os.path.isfile(abs_path)):
+             skipped_urls.append(url)
+             continue
+
+         # command = 'gsutil cp "{}" "./{}"'.format(url,relative_path)
+         command = 'gcloud storage cp --no-clobber "{}" "./{}"'.format(url,relative_path)
+         commands.append(command)
+
+     print('Generated {} commands for {} image records'.format(
+         len(commands),len(image_records)))
+
+     print('Skipped {} URLs'.format(len(skipped_urls)))
+
+
+     ##%% Write those commands out to n .sh files
+
+     commands_by_script = split_list_into_n_chunks(commands,n_download_workers)
+
+     local_download_commands = []
+
+     output_dir = os.path.dirname(download_command_file_base)
+     os.makedirs(output_dir,exist_ok=True)
+
+     # Write out the download script for each chunk
+     # i_script = 0
+     for i_script in range(0,n_download_workers):
+         download_command_file = insert_before_extension(download_command_file_base,str(i_script).zfill(2))
+         local_download_commands.append(os.path.basename(download_command_file))
+         with open(download_command_file,'w',newline='\n') as f:
+             for command in commands_by_script[i_script]:
+                 f.write(command + '\n')
+
+     # Write out the main download script
+     with open(download_command_file_base,'w',newline='\n') as f:
+         for local_download_command in local_download_commands:
+             f.write('./' + local_download_command + ' &\n')
+         f.write('wait\n')
+         f.write('echo done\n')
+
+ # ...def write_download_commands(...)
+
+
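A usage sketch combining the two functions above (the paths are hypothetical):

    from megadetector.utils.wi_platform_utils import (read_images_from_download_bundle,
                                                      write_download_commands)

    records = read_images_from_download_bundle('wi-download')

    # Writes download_wi_images.sh plus numbered worker scripts
    # (download_wi_images_00.sh, etc.) to /data/wi-images
    write_download_commands(records,
                            '/data/wi-images',
                            n_download_workers=4,
                            image_flattening='deployment')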
+ #%% Functions and constants related to pushing results to the DB
+
+ # Sample payload for validation
+ sample_update_payload = {
+
+     "predictions": [
+         {
+             "project_id": "1234",
+             "ignore_data_file_checks": True,
+             "prediction": "f1856211-cfb7-4a5b-9158-c0f72fd09ee6;;;;;;blank",
+             "prediction_score": 0.81218224763870239,
+             "classifications": {
+                 "classes": [
+                     "f1856211-cfb7-4a5b-9158-c0f72fd09ee6;;;;;;blank",
+                     "b1352069-a39c-4a84-a949-60044271c0c1;aves;;;;;bird",
+                     "90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent",
+                     "f2d233e3-80e3-433d-9687-e29ecc7a467a;mammalia;;;;;mammal",
+                     "ac068717-6079-4aec-a5ab-99e8d14da40b;mammalia;rodentia;sciuridae;dremomys;rufigenis;red-cheeked squirrel"
+                 ],
+                 "scores": [
+                     0.81218224763870239,
+                     0.1096673980355263,
+                     0.02707692421972752,
+                     0.00771023565903306,
+                     0.0049269795417785636
+                 ]
+             },
+             "detections": [
+                 {
+                     "category": "1",
+                     "label": "animal",
+                     "conf": 0.181,
+                     "bbox": [
+                         0.02421,
+                         0.35823999999999989,
+                         0.051560000000000009,
+                         0.070826666666666746
+                     ]
+                 }
+             ],
+             "model_version": "3.1.2",
+             "prediction_source": "manual_update",
+             "data_file_id": "2ea1d2b2-7f84-43f9-af1f-8be0e69c7015"
+         }
+     ]
+ }
+
+ process_cv_response_url = 'https://placeholder'
+
+
+ def prepare_data_update_auth_headers(auth_token_file):
+     """
+     Read the authorization token from a text file and prepare http headers.
+
+     Args:
+         auth_token_file (str): a single-line text file containing a write-enabled
+             API token.
+
+     Returns:
+         dict: http headers, with fields 'Authorization' and 'Content-Type'
+     """
+
+     with open(auth_token_file,'r') as f:
+         # Strip whitespace so a trailing newline in the token file doesn't
+         # end up in the header value
+         auth_token = f.read().strip()
+
+     headers = {
+         'Authorization': 'Bearer ' + auth_token,
+         'Content-Type': 'application/json'
+     }
+
+     return headers
+
+
+ def push_results_for_images(payload,
+                             headers,
+                             url=process_cv_response_url,
+                             verbose=False):
+     """
+     Push results for one or more images represented in [payload] to the
+     process_cv_response API, to write to the WI DB.
+
+     Args:
+         payload (dict): payload to upload to the API
+         headers (dict): authorization headers, see prepare_data_update_auth_headers
+         url (str, optional): API URL
+         verbose (bool, optional): enable additional debug output
+
+     Returns:
+         int: response status code
+     """
+
+     if verbose:
+         print('Sending header {} to URL {}'.format(
+             headers,url))
+
+     response = requests.post(url, headers=headers, json=payload)
+
+     # Check the response status code
+     if response.status_code in (200,201):
+         if verbose:
+             print('Successfully pushed results for {} images'.format(len(payload['predictions'])))
+             print(response.headers)
+             print(str(response))
+     else:
+         print(f'Error: {response.status_code} {response.text}')
+
+     return response.status_code
+
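A usage sketch, assuming a hypothetical token file and endpoint URL (process_cv_response_url above is only a placeholder):

    from megadetector.utils.wi_platform_utils import (prepare_data_update_auth_headers,
                                                      push_results_for_images,
                                                      sample_update_payload)

    # 'token.txt' stands in for a single-line file containing a write-enabled API key
    headers = prepare_data_update_auth_headers('token.txt')
    status_code = push_results_for_images(sample_update_payload,
                                          headers=headers,
                                          url='https://example.org/process_cv_response',
                                          verbose=True)
    print('Push returned status {}'.format(status_code))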
+
+ def parallel_push_results_for_images(payloads,
+                                      headers,
+                                      url=process_cv_response_url,
+                                      verbose=False,
+                                      pool_type='thread',
+                                      n_workers=10):
+     """
+     Push results for the list of payloads in [payloads] to the process_cv_response API,
+     parallelized over multiple workers.
+
+     Args:
+         payloads (list of dict): payloads to upload to the API
+         headers (dict): authorization headers, see prepare_data_update_auth_headers
+         url (str, optional): API URL
+         verbose (bool, optional): enable additional debug output
+         pool_type (str, optional): 'thread' or 'process'
+         n_workers (int, optional): number of parallel workers
+
+     Returns:
+         list of int: list of http response codes, one per payload
+     """
+
+     if n_workers == 1:
+
+         results = []
+         for payload in payloads:
+             results.append(push_results_for_images(payload,
+                                                    headers=headers,
+                                                    url=url,
+                                                    verbose=verbose))
+         return results
+
+     else:
+
+         assert pool_type in ('thread','process')
+
+         # Initialize to None so the "finally" block is safe even if pool creation fails
+         pool = None
+
+         try:
+             if pool_type == 'thread':
+                 pool_string = 'thread'
+                 pool = ThreadPool(n_workers)
+             else:
+                 pool_string = 'process'
+                 pool = Pool(n_workers)
+
+             print('Created a {} pool of {} workers'.format(
+                 pool_string,n_workers))
+
+             results = list(tqdm(pool.imap(
+                 partial(push_results_for_images,headers=headers,url=url,verbose=verbose),payloads),
+                 total=len(payloads)))
+         finally:
+             if pool is not None:
+                 pool.close()
+                 pool.join()
+                 print('Pool closed and joined for WI result uploads')
+
+         assert len(results) == len(payloads)
+         return results
+
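Continuing the sketch above for the parallel case, where payloads is a hypothetical list of single-image payload dicts (e.g. built with the generate_*_payload helpers below):

    status_codes = parallel_push_results_for_images(payloads,
                                                    headers=headers,
                                                    url='https://example.org/process_cv_response',
                                                    pool_type='thread',
                                                    n_workers=10)
    n_failures = len([c for c in status_codes if c not in (200,201)])
    print('{} of {} uploads failed'.format(n_failures,len(status_codes)))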
+
+ def generate_payload_with_replacement_detections(wi_result,
+                                                  detections,
+                                                  prediction_score=0.9,
+                                                  model_version='3.1.2',
+                                                  prediction_source='manual_update'):
+     """
+     Generate a payload for a single image that keeps the classifications from
+     [wi_result], but replaces the detections with the MD-formatted list [detections].
+
+     Args:
+         wi_result (dict): dict representing a WI prediction result, with at least the
+             fields in the constant wi_result_fields, plus 'image_id' and 'project_id'
+         detections (list): list of MD-formatted detection dicts (with fields 'conf' and 'category')
+         prediction_score (float, optional): confidence value to use for the combined prediction
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     payload_detections = []
+
+     # detection = detections[0]
+     for detection in detections:
+         detection_out = detection.copy()
+         detection_out['label'] = md_category_id_to_name[detection['category']]
+         if detection_out['conf'] < min_md_output_confidence:
+             detection_out['conf'] = min_md_output_confidence
+         payload_detections.append(detection_out)
+
+     prediction_string = wi_result_to_prediction_string(wi_result)
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = prediction_string
+     prediction['prediction_score'] = prediction_score
+
+     classifications = {}
+     classifications['classes'] = [prediction_string]
+     classifications['scores'] = [prediction_score]
+
+     prediction['classifications'] = classifications
+     prediction['detections'] = payload_detections
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = wi_result['image_id']
+     prediction['project_id'] = str(wi_result['project_id'])
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
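A usage sketch with hypothetical values; the taxon ID reuses the "rodent" entry from sample_update_payload above, and the detection is MD-formatted ('1' = animal):

    import numpy as np
    from megadetector.utils.wi_platform_utils import generate_payload_with_replacement_detections

    wi_result = {'wi_taxon_id': '90d950db-2106-4bd9-a4c1-777604c3eada',
                 'class': 'mammalia', 'order': 'rodentia', 'family': np.nan,
                 'genus': np.nan, 'species': np.nan, 'common_name': 'rodent',
                 'image_id': '2ea1d2b2-7f84-43f9-af1f-8be0e69c7015', 'project_id': 1234}
    detections = [{'category': '1', 'conf': 0.85, 'bbox': [0.1, 0.2, 0.3, 0.4]}]

    payload = generate_payload_with_replacement_detections(wi_result, detections)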
+
+ def generate_blank_prediction_payload(data_file_id,
+                                       project_id,
+                                       blank_confidence=0.9,
+                                       model_version='3.1.2',
+                                       prediction_source='manual_update'):
+     """
+     Generate a payload that will set a single image to the blank classification, with
+     no detections. Suitable for upload via push_results_for_images.
+
+     Args:
+         data_file_id (str): unique identifier for this image used in the WI DB
+         project_id (int): WI project ID
+         blank_confidence (float, optional): confidence value to associate with this
+             prediction
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = blank_prediction_string
+     prediction['prediction_score'] = blank_confidence
+     prediction['classifications'] = {}
+     prediction['classifications']['classes'] = [blank_prediction_string]
+     prediction['classifications']['scores'] = [blank_confidence]
+     prediction['detections'] = []
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = data_file_id
+     prediction['project_id'] = project_id
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
+
+ def generate_no_cv_result_payload(data_file_id,
+                                   project_id,
+                                   no_cv_confidence=0.9,
+                                   model_version='3.1.2',
+                                   prediction_source='manual_update'):
+     """
+     Generate a payload that will set a single image to the no-CV-result classification, with
+     no detections. Suitable for uploading via push_results_for_images.
+
+     Args:
+         data_file_id (str): unique identifier for this image used in the WI DB
+         project_id (int): WI project ID
+         no_cv_confidence (float, optional): confidence value to associate with this
+             prediction
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = no_cv_result_prediction_string
+     prediction['prediction_score'] = no_cv_confidence
+     prediction['classifications'] = {}
+     prediction['classifications']['classes'] = [no_cv_result_prediction_string]
+     prediction['classifications']['scores'] = [no_cv_confidence]
+     prediction['detections'] = []
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = data_file_id
+     prediction['project_id'] = project_id
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
+
+ def generate_payload_for_prediction_string(data_file_id,
+                                            project_id,
+                                            prediction_string,
+                                            prediction_confidence=0.8,
+                                            detections=None,
+                                            model_version='3.1.2',
+                                            prediction_source='manual_update'):
+     """
+     Generate a payload that will set a single image to a particular prediction, optionally
+     including detections. Suitable for uploading via push_results_for_images.
+
+     Args:
+         data_file_id (str): unique identifier for this image used in the WI DB
+         project_id (int): WI project ID
+         prediction_string (str): WI-formatted prediction string to include in the payload
+         prediction_confidence (float, optional): confidence value to associate with this
+             prediction
+         detections (list, optional): list of MD-formatted detection dicts, with fields
+             'category' and 'conf'
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     assert is_valid_prediction_string(prediction_string), \
+         'Invalid prediction string: {}'.format(prediction_string)
+
+     payload_detections = []
+
+     if detections is not None:
+         # detection = detections[0]
+         for detection in detections:
+             detection_out = detection.copy()
+             detection_out['label'] = md_category_id_to_name[detection['category']]
+             if detection_out['conf'] < min_md_output_confidence:
+                 detection_out['conf'] = min_md_output_confidence
+             payload_detections.append(detection_out)
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = prediction_string
+     prediction['prediction_score'] = prediction_confidence
+     prediction['classifications'] = {}
+     prediction['classifications']['classes'] = [prediction_string]
+     prediction['classifications']['scores'] = [prediction_confidence]
+     prediction['detections'] = payload_detections
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = data_file_id
+     prediction['project_id'] = project_id
+
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
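A usage sketch that builds and validates a "blank" payload (the data_file_id is hypothetical; the prediction string is the blank example from sample_update_payload above):

    from megadetector.utils.wi_platform_utils import (generate_payload_for_prediction_string,
                                                      validate_payload)

    payload = generate_payload_for_prediction_string(
        data_file_id='2ea1d2b2-7f84-43f9-af1f-8be0e69c7015',
        project_id='1234',
        prediction_string='f1856211-cfb7-4a5b-9158-c0f72fd09ee6;;;;;;blank')
    validate_payload(payload)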
+
+ def validate_payload(payload):
+     """
+     Verifies that the dict [payload] is compatible with the ProcessCVResponse() API. Throws an
+     error if [payload] is invalid.
+
+     Args:
+         payload (dict): payload in the format expected by push_results_for_images.
+
+     Returns:
+         bool: successful validation; this is just future-proofing, currently never returns False
+     """
+
+     assert isinstance(payload,dict)
+     assert len(payload.keys()) == 1 and 'predictions' in payload
+
+     # prediction = payload['predictions'][0]
+     for prediction in payload['predictions']:
+
+         assert 'project_id' in prediction
+         if not isinstance(prediction['project_id'],int):
+             _ = int(prediction['project_id'])
+         assert 'ignore_data_file_checks' in prediction and \
+             isinstance(prediction['ignore_data_file_checks'],bool)
+         assert 'prediction' in prediction and \
+             isinstance(prediction['prediction'],str) and \
+             len(prediction['prediction'].split(';')) == 7
+         assert 'prediction_score' in prediction and \
+             isinstance(prediction['prediction_score'],float)
+         assert 'model_version' in prediction and \
+             isinstance(prediction['model_version'],str)
+         assert 'data_file_id' in prediction and \
+             isinstance(prediction['data_file_id'],str) and \
+             len(prediction['data_file_id']) == 36
+         assert 'classifications' in prediction and \
+             isinstance(prediction['classifications'],dict)
+         classifications = prediction['classifications']
+         assert 'classes' in classifications and isinstance(classifications['classes'],list)
+         assert 'scores' in classifications and isinstance(classifications['scores'],list)
+         assert len(classifications['classes']) == len(classifications['scores'])
+         for c in classifications['classes']:
+             assert is_valid_prediction_string(c)
+         for score in classifications['scores']:
+             assert isinstance(score,float) and score >= 0 and score <= 1.0
+         assert 'detections' in prediction and isinstance(prediction['detections'],list)
+
+         for detection in prediction['detections']:
+
+             assert isinstance(detection,dict)
+             assert 'category' in detection and detection['category'] in ('1','2','3')
+             assert 'label' in detection and detection['label'] in ('animal','person','vehicle')
+             assert 'conf' in detection and \
+                 isinstance(detection['conf'],float) and \
+                 detection['conf'] >= 0 and detection['conf'] <= 1.0
+             assert 'bbox' in detection and \
+                 isinstance(detection['bbox'],list) and \
+                 len(detection['bbox']) == 4
+
+         # ...for each detection
+
+     # ...for each prediction
+
+     return True
+
+ # ...def validate_payload(...)
+
+
+ #%% Functions for working with WI results (from the API or from download bundles)
+
+ def wi_result_to_prediction_string(r):
+     """
+     Convert the dict [r] - typically loaded from a row in a downloaded .csv file - to
+     a valid prediction string, e.g.:
+
+     1f689929-883d-4dae-958c-3d57ab5b6c16;;;;;;animal
+     90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent
+
+     Args:
+         r (dict): dict containing WI prediction information, with at least the fields
+             specified in wi_result_fields.
+
+     Returns:
+         str: the result in [r], as a semicolon-delimited prediction string
+     """
+
+     values = []
+     for field in wi_result_fields:
+         if isinstance(r[field],str):
+             values.append(r[field].lower())
+         else:
+             assert isinstance(r[field],float) and np.isnan(r[field])
+             values.append('')
+     s = ';'.join(values)
+     assert is_valid_prediction_string(s)
+     return s
+
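A sketch of the conversion, using a hypothetical row in which family/genus/species are NaN; string fields are lowercased and NaN fields become empty slots:

    import numpy as np
    from megadetector.utils.wi_platform_utils import wi_result_to_prediction_string

    row = {'wi_taxon_id': '90d950db-2106-4bd9-a4c1-777604c3eada',
           'class': 'Mammalia', 'order': 'Rodentia', 'family': np.nan,
           'genus': np.nan, 'species': np.nan, 'common_name': 'Rodent'}
    print(wi_result_to_prediction_string(row))
    # -> 90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent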
+
+ def record_is_unidentified(record):
+     """
+     A record is considered "unidentified" if the "identified_by" field is either NaN or
+     "computer vision".
+
+     Args:
+         record (dict): dict representing a WI result loaded from a .csv file, with at least the
+             field "identified_by"
+
+     Returns:
+         bool: True if the "identified_by" field is either NaN or a string indicating that this
+             record has not yet been human-reviewed.
+     """
+
+     identified_by = record['identified_by']
+     assert isinstance(identified_by,float) or isinstance(identified_by,str)
+     if isinstance(identified_by,float):
+         assert np.isnan(identified_by)
+         return True
+     else:
+         return identified_by == 'Computer vision'
+
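A quick sketch of the three cases (the reviewer name is hypothetical):

    import numpy as np
    from megadetector.utils.wi_platform_utils import record_is_unidentified

    record_is_unidentified({'identified_by': np.nan})             # True
    record_is_unidentified({'identified_by': 'Computer vision'})  # True
    record_is_unidentified({'identified_by': 'Jane Reviewer'})    # False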
+
+ def record_lists_are_identical(records_0,records_1,verbose=False):
+     """
+     Takes two lists of records in the form returned by read_images_from_download_bundle and
+     determines whether they are the same.
+
+     Args:
+         records_0 (list of dict): the first list of records to compare
+         records_1 (list of dict): the second list of records to compare
+         verbose (bool, optional): enable additional debug output
+
+     Returns:
+         bool: True if the two lists are identical
+     """
+
+     if len(records_0) != len(records_1):
+         return False
+
+     # i_record = 0; record_0 = records_0[i_record]
+     for i_record,record_0 in enumerate(records_0):
+         record_1 = records_1[i_record]
+         assert set(record_0.keys()) == set(record_1.keys())
+         for k in record_0.keys():
+             if not compare_values_nan_equal(record_0[k],record_1[k]):
+                 if verbose:
+                     print('Image ID: {} ({})\nRecord 0/{}: {}\nRecord 1/{}: {}'.format(
+                         record_0['image_id'],record_1['image_id'],
+                         k,record_0[k],k,record_1[k]))
+                 return False
+
+     return True
+
+
+ #%% Validate constants
+
+ # This is executed at the time this module gets imported.
+
+ blank_payload = generate_blank_prediction_payload('70ede9c6-d056-4dd1-9a0b-3098d8113e0e','1234')
+ validate_payload(sample_update_payload)
+ validate_payload(blank_payload)
+