megadetector 10.0.4__py3-none-any.whl → 10.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of megadetector might be problematic.
- megadetector/data_management/cct_json_utils.py +1 -0
- megadetector/data_management/speciesnet_to_md.py +2 -2
- megadetector/detection/process_video.py +15 -6
- megadetector/detection/video_utils.py +132 -21
- megadetector/postprocessing/classification_postprocessing.py +26 -10
- megadetector/postprocessing/generate_csv_report.py +1 -1
- megadetector/postprocessing/load_api_results.py +1 -1
- megadetector/postprocessing/md_to_wi.py +1 -1
- megadetector/postprocessing/postprocess_batch_results.py +1 -1
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1 -1
- megadetector/utils/ct_utils.py +18 -0
- megadetector/utils/wi_platform_utils.py +824 -0
- megadetector/utils/wi_taxonomy_utils.py +1711 -0
- megadetector/visualization/visualize_detector_output.py +1 -1
- megadetector/visualization/visualize_video_output.py +1 -1
- {megadetector-10.0.4.dist-info → megadetector-10.0.6.dist-info}/METADATA +2 -2
- {megadetector-10.0.4.dist-info → megadetector-10.0.6.dist-info}/RECORD +20 -19
- megadetector/utils/wi_utils.py +0 -2674
- {megadetector-10.0.4.dist-info → megadetector-10.0.6.dist-info}/WHEEL +0 -0
- {megadetector-10.0.4.dist-info → megadetector-10.0.6.dist-info}/licenses/LICENSE +0 -0
- {megadetector-10.0.4.dist-info → megadetector-10.0.6.dist-info}/top_level.txt +0 -0
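The substantive change in this release is a refactoring: megadetector/utils/wi_utils.py (2,674 lines) has been removed, with its contents split across two new modules, wi_platform_utils.py (Wildlife Insights platform and API helpers, diffed below) and wi_taxonomy_utils.py (taxonomy and prediction-string helpers). A minimal migration sketch for downstream code, assuming the helpers kept their names when they moved; only the wi_taxonomy_utils and wi_platform_utils names visible in the diff below are confirmed:

    # Hypothetical migration sketch for code written against 10.0.4
    # (assumes these helpers previously lived in megadetector.utils.wi_utils)
    #
    # 10.0.4:
    # from megadetector.utils.wi_utils import is_valid_prediction_string
    #
    # 10.0.6: taxonomy helpers now come from wi_taxonomy_utils...
    from megadetector.utils.wi_taxonomy_utils import is_valid_prediction_string
    # ...and platform/download/API helpers come from wi_platform_utils
    from megadetector.utils.wi_platform_utils import read_images_from_download_bundle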
megadetector/utils/wi_platform_utils.py

@@ -0,0 +1,824 @@
+"""
+
+wi_platform_utils.py
+
+Utility functions for working with the Wildlife Insights platform, specifically:
+
+* Retrieving images based on .csv downloads
+* Pushing results to the ProcessCVResponse() API (requires an API key)
+
+"""
+
+#%% Imports
+
+import os
+import requests
+
+import pandas as pd
+import numpy as np
+
+from tqdm import tqdm
+from collections import defaultdict
+
+from multiprocessing.pool import Pool, ThreadPool
+from functools import partial
+
+from megadetector.utils.path_utils import insert_before_extension
+from megadetector.utils.ct_utils import split_list_into_n_chunks
+from megadetector.utils.ct_utils import invert_dictionary
+from megadetector.utils.ct_utils import compare_values_nan_equal
+
+from megadetector.utils.wi_taxonomy_utils import is_valid_prediction_string
+from megadetector.utils.wi_taxonomy_utils import no_cv_result_prediction_string
+from megadetector.utils.wi_taxonomy_utils import blank_prediction_string
+
+from megadetector.detection.run_detector import DEFAULT_DETECTOR_LABEL_MAP
+
+# Only used when pushing results directly to the platform via the API; any detections we want
+# to show in the UI should have at least this confidence value.
+min_md_output_confidence = 0.25
+
+md_category_id_to_name = DEFAULT_DETECTOR_LABEL_MAP
+md_category_name_to_id = invert_dictionary(md_category_id_to_name)
+
+# Fields expected to be present in a valid WI result
+wi_result_fields = ['wi_taxon_id','class','order','family','genus','species','common_name']
+
+
+#%% Functions for managing WI downloads
+
+def read_sequences_from_download_bundle(download_folder):
+    """
+    Reads sequences.csv from [download_folder], returning a list of dicts. This is a
+    thin wrapper around pd.read_csv; it's just here for future-proofing.
+
+    Args:
+        download_folder (str): a folder containing exactly one file called sequences.csv, typically
+            representing a Wildlife Insights download bundle.
+
+    Returns:
+        list of dict: a direct conversion of the .csv file to a list of dicts
+    """
+
+    print('Reading sequences from {}'.format(download_folder))
+
+    sequence_list_files = os.listdir(download_folder)
+    sequence_list_files = \
+        [fn for fn in sequence_list_files if fn == 'sequences.csv']
+    assert len(sequence_list_files) == 1, \
+        'Could not find sequences.csv in {}'.format(download_folder)
+
+    sequence_list_file = os.path.join(download_folder,sequence_list_files[0])
+
+    df = pd.read_csv(sequence_list_file)
+    sequence_records = df.to_dict('records')
+    return sequence_records
+
+
+def read_images_from_download_bundle(download_folder):
+    """
+    Reads all images.csv files from [download_folder], returns a dict mapping image IDs
+    to a list of dicts that describe each image. It's a list of dicts rather than a single dict
+    because images may appear more than once.
+
+    Args:
+        download_folder (str): a folder containing one or more images.csv files, typically
+            representing a Wildlife Insights download bundle.
+
+    Returns:
+        dict: Maps image GUIDs to dicts with at least the following fields:
+            * project_id (int)
+            * deployment_id (str)
+            * image_id (str, should match the key)
+            * filename (str, the filename without path at the time of upload)
+            * location (str, starting with gs://)
+
+            May also contain classification fields: wi_taxon_id (str), species, etc.
+    """
+
+    print('Reading images from {}'.format(download_folder))
+
+    ##%% Find lists of images
+
+    image_list_files = os.listdir(download_folder)
+    image_list_files = \
+        [fn for fn in image_list_files if fn.startswith('images_') and fn.endswith('.csv')]
+    image_list_files = \
+        [os.path.join(download_folder,fn) for fn in image_list_files]
+    print('Found {} image list files'.format(len(image_list_files)))
+
+
+    ##%% Read lists of images by deployment
+
+    image_id_to_image_records = defaultdict(list)
+
+    # image_list_file = image_list_files[0]
+    for image_list_file in image_list_files:
+
+        print('Reading images from list file {}'.format(
+            os.path.basename(image_list_file)))
+
+        df = pd.read_csv(image_list_file)
+
+        # i_row = 0; row = df.iloc[i_row]
+        for i_row,row in tqdm(df.iterrows(),total=len(df)):
+
+            row_dict = row.to_dict()
+            image_id = row_dict['image_id']
+            image_id_to_image_records[image_id].append(row_dict)
+
+        # ...for each image
+
+    # ...for each list file
+
+    deployment_ids = set()
+    for image_id in image_id_to_image_records:
+        image_records = image_id_to_image_records[image_id]
+        for image_record in image_records:
+            deployment_ids.add(image_record['deployment_id'])
+
+    print('Found {} rows in {} deployments'.format(
+        len(image_id_to_image_records),
+        len(deployment_ids)))
+
+    return image_id_to_image_records
+
+
+def find_images_in_identify_tab(download_folder_with_identify,download_folder_excluding_identify):
+    """
+    Based on extracted download packages with and without the "exclude images in 'identify' tab"
+    checkbox checked, figure out which images are in the identify tab. Returns a list of dicts (one
+    per image).
+
+    Args:
+        download_folder_with_identify (str): the folder containing the download bundle that
+            includes images from the "identify" tab
+        download_folder_excluding_identify (str): the folder containing the download bundle that
+            excludes images from the "identify" tab
+
+    Returns:
+        list of dict: list of image records that are present in the identify tab
+    """
+
+    ##%% Read data (~30 seconds)
+
+    image_id_to_image_records_with_identify = \
+        read_images_from_download_bundle(download_folder_with_identify)
+    image_id_to_image_records_excluding_identify = \
+        read_images_from_download_bundle(download_folder_excluding_identify)
+
+
+    ##%% Find images that have not been identified
+
+    all_image_ids_with_identify = set(image_id_to_image_records_with_identify.keys())
+    all_image_ids_excluding_identify = set(image_id_to_image_records_excluding_identify.keys())
+
+    image_ids_in_identify_tab = all_image_ids_with_identify.difference(all_image_ids_excluding_identify)
+
+    assert len(image_ids_in_identify_tab) == \
+        len(all_image_ids_with_identify) - len(all_image_ids_excluding_identify)
+
+    print('Found {} images with identify, {} in identify tab, {} excluding'.format(
+        len(all_image_ids_with_identify),
+        len(image_ids_in_identify_tab),
+        len(all_image_ids_excluding_identify)))
+
+    image_records_in_identify_tab = []
+    deployment_ids_for_downloaded_images = set()
+
+    for image_id in image_ids_in_identify_tab:
+        image_records_this_image = image_id_to_image_records_with_identify[image_id]
+        assert len(image_records_this_image) > 0
+        image_records_in_identify_tab.extend(image_records_this_image)
+        for image_record in image_records_this_image:
+            deployment_ids_for_downloaded_images.add(image_record['deployment_id'])
+
+    print('Found {} records for {} unique images in {} deployments'.format(
+        len(image_records_in_identify_tab),
+        len(image_ids_in_identify_tab),
+        len(deployment_ids_for_downloaded_images)))
+
+    return image_records_in_identify_tab
+
+# ...def find_images_in_identify_tab(...)
+
+
+def write_download_commands(image_records_to_download,
+                            download_dir_base,
+                            force_download=False,
+                            n_download_workers=25,
+                            download_command_file_base=None):
+    """
+    Given a list of dicts with at least the field 'location' (a gs:// URL), prepare a set of "gcloud
+    storage" commands to download images, and write those to a series of .sh scripts, along with one
+    .sh script that runs all the others and blocks.
+
+    gcloud commands will use relative paths.
+
+    image_records_to_download can also be a dict mapping IDs to lists of records.
+
+    Args:
+        image_records_to_download (list of dict): list of dicts with at least the field 'location'
+        download_dir_base (str): local destination folder
+        force_download (bool, optional): include gcloud commands even if the target file exists
+        n_download_workers (int, optional): number of scripts to write (that's our hacky way
+            of controlling parallelization)
+        download_command_file_base (str, optional): path of the .sh script we should write, defaults
+            to "download_wi_images.sh" in the destination folder. Individual worker scripts will
+            have a number added, e.g. download_wi_images_00.sh.
+    """
+
+    if isinstance(image_records_to_download,dict):
+
+        all_image_records = []
+        for k in image_records_to_download:
+            records_this_image = image_records_to_download[k]
+            all_image_records.extend(records_this_image)
+        return write_download_commands(all_image_records,
+                                       download_dir_base=download_dir_base,
+                                       force_download=force_download,
+                                       n_download_workers=n_download_workers,
+                                       download_command_file_base=download_command_file_base)
+
+    ##%% Make list of gcloud storage commands
+
+    if download_command_file_base is None:
+        download_command_file_base = os.path.join(download_dir_base,'download_wi_images.sh')
+
+    commands = []
+    skipped_urls = []
+    downloaded_urls = set()
+
+    # image_record = image_records_to_download[0]
+    for image_record in tqdm(image_records_to_download):
+
+        url = image_record['location']
+        if url in downloaded_urls:
+            continue
+
+        assert url.startswith('gs://')
+
+        relative_path = url.replace('gs://','')
+        abs_path = os.path.join(download_dir_base,relative_path)
+
+        # Skip files that already exist
+        if (not force_download) and (os.path.isfile(abs_path)):
+            skipped_urls.append(url)
+            continue
+
+        # command = 'gsutil cp "{}" "./{}"'.format(url,relative_path)
+        command = 'gcloud storage cp --no-clobber "{}" "./{}"'.format(url,relative_path)
+        commands.append(command)
+
+    print('Generated {} commands for {} image records'.format(
+        len(commands),len(image_records_to_download)))
+
+    print('Skipped {} URLs'.format(len(skipped_urls)))
+
+
+    ##%% Write those commands out to n .sh files
+
+    commands_by_script = split_list_into_n_chunks(commands,n_download_workers)
+
+    local_download_commands = []
+
+    output_dir = os.path.dirname(download_command_file_base)
+    os.makedirs(output_dir,exist_ok=True)
+
+    # Write out the download script for each chunk
+    # i_script = 0
+    for i_script in range(0,n_download_workers):
+        download_command_file = insert_before_extension(download_command_file_base,str(i_script).zfill(2))
+        local_download_commands.append(os.path.basename(download_command_file))
+        with open(download_command_file,'w',newline='\n') as f:
+            for command in commands_by_script[i_script]:
+                f.write(command + '\n')
+
+    # Write out the main download script
+    with open(download_command_file_base,'w',newline='\n') as f:
+        for local_download_command in local_download_commands:
+            f.write('./' + local_download_command + ' &\n')
+        f.write('wait\n')
+        f.write('echo done\n')
+
+# ...def write_download_commands(...)
+
+
+#%% Functions and constants related to pushing results to the DB
+
+# Sample payload for validation
+sample_update_payload = {
+
+    "predictions": [
+        {
+            "project_id": "1234",
+            "ignore_data_file_checks": True,
+            "prediction": "f1856211-cfb7-4a5b-9158-c0f72fd09ee6;;;;;;blank",
+            "prediction_score": 0.81218224763870239,
+            "classifications": {
+                "classes": [
+                    "f1856211-cfb7-4a5b-9158-c0f72fd09ee6;;;;;;blank",
+                    "b1352069-a39c-4a84-a949-60044271c0c1;aves;;;;;bird",
+                    "90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent",
+                    "f2d233e3-80e3-433d-9687-e29ecc7a467a;mammalia;;;;;mammal",
+                    "ac068717-6079-4aec-a5ab-99e8d14da40b;mammalia;rodentia;sciuridae;dremomys;rufigenis;red-cheeked squirrel"
+                ],
+                "scores": [
+                    0.81218224763870239,
+                    0.1096673980355263,
+                    0.02707692421972752,
+                    0.00771023565903306,
+                    0.0049269795417785636
+                ]
+            },
+            "detections": [
+                {
+                    "category": "1",
+                    "label": "animal",
+                    "conf": 0.181,
+                    "bbox": [
+                        0.02421,
+                        0.35823999999999989,
+                        0.051560000000000009,
+                        0.070826666666666746
+                    ]
+                }
+            ],
+            "model_version": "3.1.2",
+            "prediction_source": "manual_update",
+            "data_file_id": "2ea1d2b2-7f84-43f9-af1f-8be0e69c7015"
+        }
+    ]
+}
+
+process_cv_response_url = 'https://placeholder'
+
+
+def prepare_data_update_auth_headers(auth_token_file):
+    """
+    Read the authorization token from a text file and prepare http headers.
+
+    Args:
+        auth_token_file (str): a single-line text file containing a write-enabled
+            API token.
+
+    Returns:
+        dict: http headers, with fields 'Authorization' and 'Content-Type'
+    """
+
+    with open(auth_token_file,'r') as f:
+        auth_token = f.read()
+
+    headers = {
+        'Authorization': 'Bearer ' + auth_token,
+        'Content-Type': 'application/json'
+    }
+
+    return headers
+
+
+def push_results_for_images(payload,
+                            headers,
+                            url=process_cv_response_url,
+                            verbose=False):
+    """
+    Push results for one or more images represented in [payload] to the
+    process_cv_response API, to write to the WI DB.
+
+    Args:
+        payload (dict): payload to upload to the API
+        headers (dict): authorization headers, see prepare_data_update_auth_headers
+        url (str, optional): API URL
+        verbose (bool, optional): enable additional debug output
+
+    Returns:
+        int: response status code
+    """
+
+    if verbose:
+        print('Sending header {} to URL {}'.format(
+            headers,url))
+
+    response = requests.post(url, headers=headers, json=payload)
+
+    # Check the response status code
+    if response.status_code in (200,201):
+        if verbose:
+            print('Successfully pushed results for {} images'.format(len(payload['predictions'])))
+            print(response.headers)
+            print(str(response))
+    else:
+        print(f'Error: {response.status_code} {response.text}')
+
+    return response.status_code
+
+
+def parallel_push_results_for_images(payloads,
+                                     headers,
+                                     url=process_cv_response_url,
+                                     verbose=False,
+                                     pool_type='thread',
+                                     n_workers=10):
+    """
+    Push results for the list of payloads in [payloads] to the process_cv_response API,
+    parallelized over multiple workers.
+
+    Args:
+        payloads (list of dict): payloads to upload to the API
+        headers (dict): authorization headers, see prepare_data_update_auth_headers
+        url (str, optional): API URL
+        verbose (bool, optional): enable additional debug output
+        pool_type (str, optional): 'thread' or 'process'
+        n_workers (int, optional): number of parallel workers
+
+    Returns:
+        list of int: list of http response codes, one per payload
+    """
+
+    if n_workers == 1:
+
+        results = []
+        for payload in payloads:
+            results.append(push_results_for_images(payload,
+                                                   headers=headers,
+                                                   url=url,
+                                                   verbose=verbose))
+        return results
+
+    else:
+
+        assert pool_type in ('thread','process')
+
+        try:
+            if pool_type == 'thread':
+                pool_string = 'thread'
+                pool = ThreadPool(n_workers)
+            else:
+                pool_string = 'process'
+                pool = Pool(n_workers)
+
+            print('Created a {} pool of {} workers'.format(
+                pool_string,n_workers))
+
+            results = list(tqdm(pool.imap(
+                partial(push_results_for_images,headers=headers,url=url,verbose=verbose),payloads),
+                total=len(payloads)))
+        finally:
+            pool.close()
+            pool.join()
+            print("Pool closed and joined for WI result uploads")
+
+        assert len(results) == len(payloads)
+        return results
+
+
+def generate_payload_with_replacement_detections(wi_result,
+                                                 detections,
+                                                 prediction_score=0.9,
+                                                 model_version='3.1.2',
+                                                 prediction_source='manual_update'):
+    """
+    Generate a payload for a single image that keeps the classifications from
+    [wi_result], but replaces the detections with the MD-formatted list [detections].
+
+    Args:
+        wi_result (dict): dict representing a WI prediction result, with at least the
+            fields in the constant wi_result_fields
+        detections (list): list of MD-formatted detection dicts (with fields 'conf' and 'category')
+        prediction_score (float, optional): confidence value to use for the combined prediction
+        model_version (str, optional): model version string to include in the payload
+        prediction_source (str, optional): prediction source string to include in the payload
+
+    Returns:
+        dict: dictionary suitable for uploading via push_results_for_images
+    """
+
+    payload_detections = []
+
+    # detection = detections[0]
+    for detection in detections:
+        detection_out = detection.copy()
+        detection_out['label'] = md_category_id_to_name[detection['category']]
+        if detection_out['conf'] < min_md_output_confidence:
+            detection_out['conf'] = min_md_output_confidence
+        payload_detections.append(detection_out)
+
+    prediction_string = wi_result_to_prediction_string(wi_result)
+
+    prediction = {}
+    prediction['ignore_data_file_checks'] = True
+    prediction['prediction'] = prediction_string
+    prediction['prediction_score'] = prediction_score
+
+    classifications = {}
+    classifications['classes'] = [prediction_string]
+    classifications['scores'] = [prediction_score]
+
+    prediction['classifications'] = classifications
+    prediction['detections'] = payload_detections
+    prediction['model_version'] = model_version
+    prediction['prediction_source'] = prediction_source
+    prediction['data_file_id'] = wi_result['image_id']
+    prediction['project_id'] = str(wi_result['project_id'])
+    payload = {}
+    payload['predictions'] = [prediction]
+
+    return payload
+
+
+def generate_blank_prediction_payload(data_file_id,
+                                      project_id,
+                                      blank_confidence=0.9,
+                                      model_version='3.1.2',
+                                      prediction_source='manual_update'):
+    """
+    Generate a payload that will set a single image to the blank classification, with
+    no detections. Suitable for upload via push_results_for_images.
+
+    Args:
+        data_file_id (str): unique identifier for this image used in the WI DB
+        project_id (int): WI project ID
+        blank_confidence (float, optional): confidence value to associate with this
+            prediction
+        model_version (str, optional): model version string to include in the payload
+        prediction_source (str, optional): prediction source string to include in the payload
+
+    Returns:
+        dict: dictionary suitable for uploading via push_results_for_images
+    """
+
+    prediction = {}
+    prediction['ignore_data_file_checks'] = True
+    prediction['prediction'] = blank_prediction_string
+    prediction['prediction_score'] = blank_confidence
+    prediction['classifications'] = {}
+    prediction['classifications']['classes'] = [blank_prediction_string]
+    prediction['classifications']['scores'] = [blank_confidence]
+    prediction['detections'] = []
+    prediction['model_version'] = model_version
+    prediction['prediction_source'] = prediction_source
+    prediction['data_file_id'] = data_file_id
+    prediction['project_id'] = project_id
+    payload = {}
+    payload['predictions'] = [prediction]
+
+    return payload
+
+
+def generate_no_cv_result_payload(data_file_id,
+                                  project_id,
+                                  no_cv_confidence=0.9,
+                                  model_version='3.1.2',
+                                  prediction_source='manual_update'):
+    """
+    Generate a payload that will set a single image to the "no CV result" classification, with
+    no detections. Suitable for uploading via push_results_for_images.
+
+    Args:
+        data_file_id (str): unique identifier for this image used in the WI DB
+        project_id (int): WI project ID
+        no_cv_confidence (float, optional): confidence value to associate with this
+            prediction
+        model_version (str, optional): model version string to include in the payload
+        prediction_source (str, optional): prediction source string to include in the payload
+
+    Returns:
+        dict: dictionary suitable for uploading via push_results_for_images
+    """
+
+    prediction = {}
+    prediction['ignore_data_file_checks'] = True
+    prediction['prediction'] = no_cv_result_prediction_string
+    prediction['prediction_score'] = no_cv_confidence
+    prediction['classifications'] = {}
+    prediction['classifications']['classes'] = [no_cv_result_prediction_string]
+    prediction['classifications']['scores'] = [no_cv_confidence]
+    prediction['detections'] = []
+    prediction['model_version'] = model_version
+    prediction['prediction_source'] = prediction_source
+    prediction['data_file_id'] = data_file_id
+    prediction['project_id'] = project_id
+    payload = {}
+    payload['predictions'] = [prediction]
+
+    return payload
+
+
+def generate_payload_for_prediction_string(data_file_id,
+                                           project_id,
+                                           prediction_string,
+                                           prediction_confidence=0.8,
+                                           detections=None,
+                                           model_version='3.1.2',
+                                           prediction_source='manual_update'):
+    """
+    Generate a payload that will set a single image to a particular prediction, optionally
+    including detections. Suitable for uploading via push_results_for_images.
+
+    Args:
+        data_file_id (str): unique identifier for this image used in the WI DB
+        project_id (int): WI project ID
+        prediction_string (str): WI-formatted prediction string to include in the payload
+        prediction_confidence (float, optional): confidence value to associate with this
+            prediction
+        detections (list, optional): list of MD-formatted detection dicts, with fields
+            'category' and 'conf'
+        model_version (str, optional): model version string to include in the payload
+        prediction_source (str, optional): prediction source string to include in the payload
+
+
+    Returns:
+        dict: dictionary suitable for uploading via push_results_for_images
+    """
+
+    assert is_valid_prediction_string(prediction_string), \
+        'Invalid prediction string: {}'.format(prediction_string)
+
+    payload_detections = []
+
+    if detections is not None:
+        # detection = detections[0]
+        for detection in detections:
+            detection_out = detection.copy()
+            detection_out['label'] = md_category_id_to_name[detection['category']]
+            if detection_out['conf'] < min_md_output_confidence:
+                detection_out['conf'] = min_md_output_confidence
+            payload_detections.append(detection_out)
+
+    prediction = {}
+    prediction['ignore_data_file_checks'] = True
+    prediction['prediction'] = prediction_string
+    prediction['prediction_score'] = prediction_confidence
+    prediction['classifications'] = {}
+    prediction['classifications']['classes'] = [prediction_string]
+    prediction['classifications']['scores'] = [prediction_confidence]
+    prediction['detections'] = payload_detections
+    prediction['model_version'] = model_version
+    prediction['prediction_source'] = prediction_source
+    prediction['data_file_id'] = data_file_id
+    prediction['project_id'] = project_id
+
+    payload = {}
+    payload['predictions'] = [prediction]
+
+    return payload
+
+
+def validate_payload(payload):
+    """
+    Verifies that the dict [payload] is compatible with the ProcessCVResponse() API. Throws an
+    error if [payload] is invalid.
+
+    Args:
+        payload (dict): payload in the format expected by push_results_for_images.
+
+    Returns:
+        bool: successful validation; this is just future-proofing, currently never returns False
+    """
+
+    assert isinstance(payload,dict)
+    assert len(payload.keys()) == 1 and 'predictions' in payload
+
+    # prediction = payload['predictions'][0]
+    for prediction in payload['predictions']:
+
+        assert 'project_id' in prediction
+        if not isinstance(prediction['project_id'],int):
+            _ = int(prediction['project_id'])
+        assert 'ignore_data_file_checks' in prediction and \
+            isinstance(prediction['ignore_data_file_checks'],bool)
+        assert 'prediction' in prediction and \
+            isinstance(prediction['prediction'],str) and \
+            len(prediction['prediction'].split(';')) == 7
+        assert 'prediction_score' in prediction and \
+            isinstance(prediction['prediction_score'],float)
+        assert 'model_version' in prediction and \
+            isinstance(prediction['model_version'],str)
+        assert 'data_file_id' in prediction and \
+            isinstance(prediction['data_file_id'],str) and \
+            len(prediction['data_file_id']) == 36
+        assert 'classifications' in prediction and \
+            isinstance(prediction['classifications'],dict)
+        classifications = prediction['classifications']
+        assert 'classes' in classifications and isinstance(classifications['classes'],list)
+        assert 'scores' in classifications and isinstance(classifications['scores'],list)
+        assert len(classifications['classes']) == len(classifications['scores'])
+        for c in classifications['classes']:
+            assert is_valid_prediction_string(c)
+        for score in classifications['scores']:
+            assert isinstance(score,float) and score >= 0 and score <= 1.0
+        assert 'detections' in prediction and isinstance(prediction['detections'],list)
+
+        for detection in prediction['detections']:
+
+            assert isinstance(detection,dict)
+            assert 'category' in detection and detection['category'] in ('1','2','3')
+            assert 'label' in detection and detection['label'] in ('animal','person','vehicle')
+            assert 'conf' in detection and \
+                isinstance(detection['conf'],float) and \
+                detection['conf'] >= 0 and detection['conf'] <= 1.0
+            assert 'bbox' in detection and \
+                isinstance(detection['bbox'],list) and \
+                len(detection['bbox']) == 4
+
+        # ...for each detection
+
+    # ...for each prediction
+
+    return True
+
+# ...def validate_payload(...)
+
+
+#%% Functions for working with WI results (from the API or from download bundles)
+
+def wi_result_to_prediction_string(r):
+    """
+    Convert the dict [r] - typically loaded from a row in a downloaded .csv file - to
+    a valid prediction string, e.g.:
+
+    1f689929-883d-4dae-958c-3d57ab5b6c16;;;;;;animal
+    90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent
+
+    Args:
+        r (dict): dict containing WI prediction information, with at least the fields
+            specified in wi_result_fields.
+
+    Returns:
+        str: the result in [r], as a semicolon-delimited prediction string
+    """
+
+    values = []
+    for field in wi_result_fields:
+        if isinstance(r[field],str):
+            values.append(r[field].lower())
+        else:
+            assert isinstance(r[field],float) and np.isnan(r[field])
+            values.append('')
+    s = ';'.join(values)
+    assert is_valid_prediction_string(s)
+    return s
+
+
+def record_is_unidentified(record):
+    """
+    A record is considered "unidentified" if the "identified_by" field is either NaN or "Computer vision".
+
+    Args:
+        record (dict): dict representing a WI result loaded from a .csv file, with at least the
+            field "identified_by"
+
+    Returns:
+        bool: True if the "identified_by" field is either NaN or a string indicating that this
+            record has not yet been human-reviewed.
+    """
+
+    identified_by = record['identified_by']
+    assert isinstance(identified_by,float) or isinstance(identified_by,str)
+    if isinstance(identified_by,float):
+        assert np.isnan(identified_by)
+        return True
+    else:
+        return identified_by == 'Computer vision'
+
+
+def record_lists_are_identical(records_0,records_1,verbose=False):
+    """
+    Takes two lists of records in the form returned by read_images_from_download_bundle and
+    determines whether they are the same.
+
+    Args:
+        records_0 (list of dict): the first list of records to compare
+        records_1 (list of dict): the second list of records to compare
+        verbose (bool, optional): enable additional debug output
+
+    Returns:
+        bool: True if the two lists are identical
+    """
+
+    if len(records_0) != len(records_1):
+        return False
+
+    # i_record = 0; record_0 = records_0[i_record]
+    for i_record,record_0 in enumerate(records_0):
+        record_1 = records_1[i_record]
+        assert set(record_0.keys()) == set(record_1.keys())
+        for k in record_0.keys():
+            if not compare_values_nan_equal(record_0[k],record_1[k]):
+                if verbose:
+                    print('Image ID: {} ({})\nRecord 0/{}: {}\nRecord 1/{}: {}'.format(
+                        record_0['image_id'],record_1['image_id'],
+                        k,record_0[k],k,record_1[k]))
+                return False
+
+    return True
+
+
+#%% Validate constants
+
+# This is executed at the time this module gets imported.
+
+blank_payload = generate_blank_prediction_payload('70ede9c6-d056-4dd1-9a0b-3098d8113e0e','1234')
+validate_payload(sample_update_payload)
+validate_payload(blank_payload)
+
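For orientation, here is a minimal usage sketch of the download-management half of the new module, based only on the signatures and docstrings in the diff above; the local paths are hypothetical placeholders:

    from megadetector.utils.wi_platform_utils import (
        find_images_in_identify_tab, write_download_commands)

    # Hypothetical paths to two extracted WI download bundles: one exported with
    # images from the "identify" tab included, one with them excluded
    records = find_images_in_identify_tab('/data/wi/with_identify',
                                          '/data/wi/excluding_identify')

    # Writes download_wi_images.sh plus 25 numbered worker scripts that fetch
    # the images into /data/wi/images via "gcloud storage cp --no-clobber"
    write_download_commands(records, '/data/wi/images')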