megadetector 10.0.3__py3-none-any.whl → 10.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (32)
  1. megadetector/data_management/animl_to_md.py +158 -0
  2. megadetector/data_management/cct_json_utils.py +1 -0
  3. megadetector/data_management/speciesnet_to_md.py +2 -2
  4. megadetector/data_management/zamba_to_md.py +188 -0
  5. megadetector/detection/process_video.py +52 -40
  6. megadetector/detection/pytorch_detector.py +24 -34
  7. megadetector/detection/run_detector_batch.py +138 -93
  8. megadetector/detection/run_md_and_speciesnet.py +22 -4
  9. megadetector/detection/video_utils.py +5 -4
  10. megadetector/postprocessing/classification_postprocessing.py +26 -10
  11. megadetector/postprocessing/combine_batch_outputs.py +2 -0
  12. megadetector/postprocessing/generate_csv_report.py +1 -1
  13. megadetector/postprocessing/load_api_results.py +1 -1
  14. megadetector/postprocessing/md_to_wi.py +1 -1
  15. megadetector/postprocessing/postprocess_batch_results.py +1 -1
  16. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1 -1
  17. megadetector/postprocessing/separate_detections_into_folders.py +1 -1
  18. megadetector/postprocessing/subset_json_detector_output.py +1 -3
  19. megadetector/utils/ct_utils.py +71 -0
  20. megadetector/utils/md_tests.py +8 -7
  21. megadetector/utils/path_utils.py +4 -15
  22. megadetector/utils/wi_platform_utils.py +824 -0
  23. megadetector/utils/wi_taxonomy_utils.py +1711 -0
  24. megadetector/visualization/visualization_utils.py +1 -1
  25. megadetector/visualization/visualize_detector_output.py +7 -5
  26. megadetector/visualization/visualize_video_output.py +1 -1
  27. {megadetector-10.0.3.dist-info → megadetector-10.0.5.dist-info}/METADATA +2 -2
  28. {megadetector-10.0.3.dist-info → megadetector-10.0.5.dist-info}/RECORD +31 -28
  29. megadetector/utils/wi_utils.py +0 -2674
  30. {megadetector-10.0.3.dist-info → megadetector-10.0.5.dist-info}/WHEEL +0 -0
  31. {megadetector-10.0.3.dist-info → megadetector-10.0.5.dist-info}/licenses/LICENSE +0 -0
  32. {megadetector-10.0.3.dist-info → megadetector-10.0.5.dist-info}/top_level.txt +0 -0
megadetector/utils/wi_platform_utils.py
@@ -0,0 +1,824 @@
+ """
+
+ wi_platform_utils.py
+
+ Utility functions for working with the Wildlife Insights platform, specifically:
+
+ * Retrieving images based on .csv downloads
+ * Pushing results to the ProcessCVResponse() API (requires an API key)
+
+ """
+
+ #%% Imports
+
+ import os
+ import requests
+
+ import pandas as pd
+ import numpy as np
+
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ from multiprocessing.pool import Pool, ThreadPool
+ from functools import partial
+
+ from megadetector.utils.path_utils import insert_before_extension
+ from megadetector.utils.ct_utils import split_list_into_n_chunks
+ from megadetector.utils.ct_utils import invert_dictionary
+ from megadetector.utils.ct_utils import compare_values_nan_equal
+
+ from megadetector.utils.wi_taxonomy_utils import is_valid_prediction_string
+ from megadetector.utils.wi_taxonomy_utils import no_cv_result_prediction_string
+ from megadetector.utils.wi_taxonomy_utils import blank_prediction_string
+
+ from megadetector.detection.run_detector import DEFAULT_DETECTOR_LABEL_MAP
+
+ # Only used when pushing results directly to the platform via the API; any detections we want
+ # to show in the UI should have at least this confidence value.
+ min_md_output_confidence = 0.25
+
+ md_category_id_to_name = DEFAULT_DETECTOR_LABEL_MAP
+ md_category_name_to_id = invert_dictionary(md_category_id_to_name)
+
+ # Fields expected to be present in a valid WI result
+ wi_result_fields = ['wi_taxon_id','class','order','family','genus','species','common_name']
+
+
+ #%% Functions for managing WI downloads
+
+ def read_sequences_from_download_bundle(download_folder):
+     """
+     Reads sequences.csv from [download_folder], returning a list of dicts. This is a
+     thin wrapper around pd.read_csv; it's just here for future-proofing.
+
+     Args:
+         download_folder (str): a folder containing exactly one file called sequences.csv, typically
+             representing a Wildlife Insights download bundle.
+
+     Returns:
+         list of dict: a direct conversion of the .csv file to a list of dicts
+     """
+
+     print('Reading sequences from {}'.format(download_folder))
+
+     sequence_list_files = os.listdir(download_folder)
+     sequence_list_files = \
+         [fn for fn in sequence_list_files if fn == 'sequences.csv']
+     assert len(sequence_list_files) == 1, \
+         'Could not find sequences.csv in {}'.format(download_folder)
+
+     sequence_list_file = os.path.join(download_folder,sequence_list_files[0])
+
+     df = pd.read_csv(sequence_list_file)
+     sequence_records = df.to_dict('records')
+     return sequence_records
+
+
+ def read_images_from_download_bundle(download_folder):
+     """
+     Reads all images_*.csv files from [download_folder], returns a dict mapping image IDs
+     to a list of dicts that describe each image. It's a list of dicts rather than a single dict
+     because images may appear more than once.
+
+     Args:
+         download_folder (str): a folder containing one or more images_*.csv files, typically
+             representing a Wildlife Insights download bundle.
+
+     Returns:
+         dict: Maps image GUIDs to dicts with at least the following fields:
+             * project_id (int)
+             * deployment_id (str)
+             * image_id (str, should match the key)
+             * filename (str, the filename without path at the time of upload)
+             * location (str, starting with gs://)
+
+             May also contain classification fields: wi_taxon_id (str), species, etc.
+     """
+
+     print('Reading images from {}'.format(download_folder))
+
+     ##%% Find lists of images
+
+     image_list_files = os.listdir(download_folder)
+     image_list_files = \
+         [fn for fn in image_list_files if fn.startswith('images_') and fn.endswith('.csv')]
+     image_list_files = \
+         [os.path.join(download_folder,fn) for fn in image_list_files]
+     print('Found {} image list files'.format(len(image_list_files)))
+
+
+     ##%% Read lists of images by deployment
+
+     image_id_to_image_records = defaultdict(list)
+
+     # image_list_file = image_list_files[0]
+     for image_list_file in image_list_files:
+
+         print('Reading images from list file {}'.format(
+             os.path.basename(image_list_file)))
+
+         df = pd.read_csv(image_list_file)
+
+         # i_row = 0; row = df.iloc[i_row]
+         for i_row,row in tqdm(df.iterrows(),total=len(df)):
+
+             row_dict = row.to_dict()
+             image_id = row_dict['image_id']
+             image_id_to_image_records[image_id].append(row_dict)
+
+         # ...for each image
+
+     # ...for each list file
+
+     deployment_ids = set()
+     for image_id in image_id_to_image_records:
+         image_records = image_id_to_image_records[image_id]
+         for image_record in image_records:
+             deployment_ids.add(image_record['deployment_id'])
+
+     print('Found {} rows in {} deployments'.format(
+         len(image_id_to_image_records),
+         len(deployment_ids)))
+
+     return image_id_to_image_records
+
+
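For context, a minimal usage sketch of the function above (the bundle folder name is hypothetical):

    # Read every images_*.csv file from an extracted download bundle
    image_id_to_records = read_images_from_download_bundle('/data/wi-download')

    # The same image can appear in more than one row, e.g. across multiple
    # images_*.csv files
    duplicate_image_ids = [image_id for image_id,records in
                           image_id_to_records.items() if len(records) > 1]
    print('{} of {} images have multiple records'.format(
        len(duplicate_image_ids),len(image_id_to_records)))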
+ def find_images_in_identify_tab(download_folder_with_identify,download_folder_excluding_identify):
+     """
+     Based on extracted download packages with and without the "exclude images in 'identify' tab"
+     checkbox checked, figure out which images are in the identify tab. Returns a list of dicts (one
+     per image).
+
+     Args:
+         download_folder_with_identify (str): the folder containing the download bundle that
+             includes images from the "identify" tab
+         download_folder_excluding_identify (str): the folder containing the download bundle that
+             excludes images from the "identify" tab
+
+     Returns:
+         list of dict: list of image records that are present in the identify tab
+     """
+
+     ##%% Read data (~30 seconds)
+
+     image_id_to_image_records_with_identify = \
+         read_images_from_download_bundle(download_folder_with_identify)
+     image_id_to_image_records_excluding_identify = \
+         read_images_from_download_bundle(download_folder_excluding_identify)
+
+
+     ##%% Find images that have not been identified
+
+     all_image_ids_with_identify = set(image_id_to_image_records_with_identify.keys())
+     all_image_ids_excluding_identify = set(image_id_to_image_records_excluding_identify.keys())
+
+     image_ids_in_identify_tab = all_image_ids_with_identify.difference(all_image_ids_excluding_identify)
+
+     assert len(image_ids_in_identify_tab) == \
+         len(all_image_ids_with_identify) - len(all_image_ids_excluding_identify)
+
+     print('Found {} images with identify, {} in identify tab, {} excluding'.format(
+         len(all_image_ids_with_identify),
+         len(image_ids_in_identify_tab),
+         len(all_image_ids_excluding_identify)))
+
+     image_records_in_identify_tab = []
+     deployment_ids_for_downloaded_images = set()
+
+     for image_id in image_ids_in_identify_tab:
+         image_records_this_image = image_id_to_image_records_with_identify[image_id]
+         assert len(image_records_this_image) > 0
+         image_records_in_identify_tab.extend(image_records_this_image)
+         for image_record in image_records_this_image:
+             deployment_ids_for_downloaded_images.add(image_record['deployment_id'])
+
+     print('Found {} records for {} unique images in {} deployments'.format(
+         len(image_records_in_identify_tab),
+         len(image_ids_in_identify_tab),
+         len(deployment_ids_for_downloaded_images)))
+
+     return image_records_in_identify_tab
+
+ # ...def find_images_in_identify_tab(...)
+
+
+ def write_download_commands(image_records_to_download,
+                             download_dir_base,
+                             force_download=False,
+                             n_download_workers=25,
+                             download_command_file_base=None):
+     """
+     Given a list of dicts with at least the field 'location' (a gs:// URL), prepare a set of "gcloud
+     storage" commands to download images, and write those to a series of .sh scripts, along with one
+     .sh script that runs all the others and blocks.
+
+     gcloud commands will use relative paths.
+
+     image_records_to_download can also be a dict mapping IDs to lists of records.
+
+     Args:
+         image_records_to_download (list of dict): list of dicts with at least the field 'location'
+         download_dir_base (str): local destination folder
+         force_download (bool, optional): include gs commands even if the target file exists
+         n_download_workers (int, optional): number of scripts to write (that's our hacky way
+             of controlling parallelization)
+         download_command_file_base (str, optional): path of the .sh script we should write, defaults
+             to "download_wi_images.sh" in the destination folder. Individual worker scripts will
+             have a number added, e.g. download_wi_images_00.sh.
+     """
+
+     if isinstance(image_records_to_download,dict):
+
+         all_image_records = []
+         for k in image_records_to_download:
+             records_this_image = image_records_to_download[k]
+             all_image_records.extend(records_this_image)
+         return write_download_commands(all_image_records,
+                                        download_dir_base=download_dir_base,
+                                        force_download=force_download,
+                                        n_download_workers=n_download_workers,
+                                        download_command_file_base=download_command_file_base)
+
+     ##%% Make list of gcloud storage commands
+
+     if download_command_file_base is None:
+         download_command_file_base = os.path.join(download_dir_base,'download_wi_images.sh')
+
+     commands = []
+     skipped_urls = []
+     downloaded_urls = set()
+
+     # image_record = image_records_to_download[0]
+     for image_record in tqdm(image_records_to_download):
+
+         url = image_record['location']
+         if url in downloaded_urls:
+             continue
+
+         assert url.startswith('gs://')
+
+         # Track each URL as soon as we've seen it, so duplicate records only
+         # generate one download command
+         downloaded_urls.add(url)
+
+         relative_path = url.replace('gs://','')
+         abs_path = os.path.join(download_dir_base,relative_path)
+
+         # Skip files that already exist
+         if (not force_download) and (os.path.isfile(abs_path)):
+             skipped_urls.append(url)
+             continue
+
+         # command = 'gsutil cp "{}" "./{}"'.format(url,relative_path)
+         command = 'gcloud storage cp --no-clobber "{}" "./{}"'.format(url,relative_path)
+         commands.append(command)
+
+     print('Generated {} commands for {} image records'.format(
+         len(commands),len(image_records_to_download)))
+
+     print('Skipped {} URLs'.format(len(skipped_urls)))
+
+
+     ##%% Write those commands out to n .sh files
+
+     commands_by_script = split_list_into_n_chunks(commands,n_download_workers)
+
+     local_download_commands = []
+
+     output_dir = os.path.dirname(download_command_file_base)
+     os.makedirs(output_dir,exist_ok=True)
+
+     # Write out the download script for each chunk
+     # i_script = 0
+     for i_script in range(0,n_download_workers):
+         download_command_file = insert_before_extension(download_command_file_base,str(i_script).zfill(2))
+         local_download_commands.append(os.path.basename(download_command_file))
+         with open(download_command_file,'w',newline='\n') as f:
+             for command in commands_by_script[i_script]:
+                 f.write(command + '\n')
+
+     # Write out the main download script
+     with open(download_command_file_base,'w',newline='\n') as f:
+         for local_download_command in local_download_commands:
+             f.write('./' + local_download_command + ' &\n')
+         f.write('wait\n')
+         f.write('echo done\n')
+
+ # ...def write_download_commands(...)
+
+
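For context, a sketch of how the two functions above combine (all folder names here are hypothetical): extract two download bundles, one including and one excluding the "identify" tab, then generate download scripts for just the identify-tab images.

    identify_records = find_images_in_identify_tab(
        '/data/wi-download-with-identify',
        '/data/wi-download-excluding-identify')

    # Writes download_wi_images.sh plus one numbered worker script per chunk;
    # running the main script downloads the images in parallel
    write_download_commands(identify_records,
                            download_dir_base='/data/wi-images',
                            n_download_workers=8)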
+ #%% Functions and constants related to pushing results to the DB
+
+ # Sample payload for validation
+ sample_update_payload = {
+
+     "predictions": [
+         {
+             "project_id": "1234",
+             "ignore_data_file_checks": True,
+             "prediction": "f1856211-cfb7-4a5b-9158-c0f72fd09ee6;;;;;;blank",
+             "prediction_score": 0.81218224763870239,
+             "classifications": {
+                 "classes": [
+                     "f1856211-cfb7-4a5b-9158-c0f72fd09ee6;;;;;;blank",
+                     "b1352069-a39c-4a84-a949-60044271c0c1;aves;;;;;bird",
+                     "90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent",
+                     "f2d233e3-80e3-433d-9687-e29ecc7a467a;mammalia;;;;;mammal",
+                     "ac068717-6079-4aec-a5ab-99e8d14da40b;mammalia;rodentia;sciuridae;dremomys;rufigenis;red-cheeked squirrel"
+                 ],
+                 "scores": [
+                     0.81218224763870239,
+                     0.1096673980355263,
+                     0.02707692421972752,
+                     0.00771023565903306,
+                     0.0049269795417785636
+                 ]
+             },
+             "detections": [
+                 {
+                     "category": "1",
+                     "label": "animal",
+                     "conf": 0.181,
+                     "bbox": [
+                         0.02421,
+                         0.35823999999999989,
+                         0.051560000000000009,
+                         0.070826666666666746
+                     ]
+                 }
+             ],
+             "model_version": "3.1.2",
+             "prediction_source": "manual_update",
+             "data_file_id": "2ea1d2b2-7f84-43f9-af1f-8be0e69c7015"
+         }
+     ]
+ }
+
+ process_cv_response_url = 'https://placeholder'
+
+
+ def prepare_data_update_auth_headers(auth_token_file):
+     """
+     Read the authorization token from a text file and prepare http headers.
+
+     Args:
+         auth_token_file (str): a single-line text file containing a write-enabled
+             API token.
+
+     Returns:
+         dict: http headers, with fields 'Authorization' and 'Content-Type'
+     """
+
+     with open(auth_token_file,'r') as f:
+         auth_token = f.read()
+
+     headers = {
+         'Authorization': 'Bearer ' + auth_token,
+         'Content-Type': 'application/json'
+     }
+
+     return headers
+
+
+ def push_results_for_images(payload,
+                             headers,
+                             url=process_cv_response_url,
+                             verbose=False):
+     """
+     Push results for one or more images represented in [payload] to the
+     process_cv_response API, to write to the WI DB.
+
+     Args:
+         payload (dict): payload to upload to the API
+         headers (dict): authorization headers, see prepare_data_update_auth_headers
+         url (str, optional): API URL
+         verbose (bool, optional): enable additional debug output
+
+     Returns:
+         int: response status code
+     """
+
+     if verbose:
+         print('Sending header {} to URL {}'.format(
+             headers,url))
+
+     response = requests.post(url, headers=headers, json=payload)
+
+     # Check the response status code
+     if response.status_code in (200,201):
+         if verbose:
+             print('Successfully pushed results for {} images'.format(len(payload['predictions'])))
+             print(response.headers)
+             print(str(response))
+     else:
+         print(f'Error: {response.status_code} {response.text}')
+
+     return response.status_code
+
+
+ def parallel_push_results_for_images(payloads,
+                                      headers,
+                                      url=process_cv_response_url,
+                                      verbose=False,
+                                      pool_type='thread',
+                                      n_workers=10):
+     """
+     Push results for the list of payloads in [payloads] to the process_cv_response API,
+     parallelized over multiple workers.
+
+     Args:
+         payloads (list of dict): payloads to upload to the API
+         headers (dict): authorization headers, see prepare_data_update_auth_headers
+         url (str, optional): API URL
+         verbose (bool, optional): enable additional debug output
+         pool_type (str, optional): 'thread' or 'process'
+         n_workers (int, optional): number of parallel workers
+
+     Returns:
+         list of int: list of http response codes, one per payload
+     """
+
+     if n_workers == 1:
+
+         results = []
+         for payload in payloads:
+             results.append(push_results_for_images(payload,
+                                                    headers=headers,
+                                                    url=url,
+                                                    verbose=verbose))
+         return results
+
+     else:
+
+         assert pool_type in ('thread','process')
+
+         # Initialize to None so the cleanup in the finally block is safe even
+         # if pool creation fails
+         pool = None
+
+         try:
+             if pool_type == 'thread':
+                 pool_string = 'thread'
+                 pool = ThreadPool(n_workers)
+             else:
+                 pool_string = 'process'
+                 pool = Pool(n_workers)
+
+             print('Created a {} pool of {} workers'.format(
+                 pool_string,n_workers))
+
+             results = list(tqdm(pool.imap(
+                 partial(push_results_for_images,headers=headers,url=url,verbose=verbose),payloads),
+                 total=len(payloads)))
+         finally:
+             if pool is not None:
+                 pool.close()
+                 pool.join()
+                 print("Pool closed and joined for WI result uploads")
+
+         assert len(results) == len(payloads)
+         return results
+
+
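For context, a hedged sketch of the upload path. The token file and endpoint are hypothetical; note that process_cv_response_url ships as 'https://placeholder' in this release, so a real URL has to be supplied. generate_blank_prediction_payload is defined later in this file.

    headers = prepare_data_update_auth_headers('/home/user/wi_api_token.txt')

    payloads = [generate_blank_prediction_payload(
        '70ede9c6-d056-4dd1-9a0b-3098d8113e0e','1234')]

    status_codes = parallel_push_results_for_images(
        payloads,
        headers=headers,
        url='https://example.org/process-cv-response',
        pool_type='thread',
        n_workers=5)
    assert all(code in (200,201) for code in status_codes)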
+ def generate_payload_with_replacement_detections(wi_result,
+                                                  detections,
+                                                  prediction_score=0.9,
+                                                  model_version='3.1.2',
+                                                  prediction_source='manual_update'):
+     """
+     Generate a payload for a single image that keeps the classifications from
+     [wi_result], but replaces the detections with the MD-formatted list [detections].
+
+     Args:
+         wi_result (dict): dict representing a WI prediction result, with at least the
+             fields in the constant wi_result_fields
+         detections (list): list of MD-formatted detection dicts (with fields ['conf'] and ['category'])
+         prediction_score (float, optional): confidence value to use for the combined prediction
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     payload_detections = []
+
+     # detection = detections[0]
+     for detection in detections:
+         detection_out = detection.copy()
+         detection_out['label'] = md_category_id_to_name[detection['category']]
+         if detection_out['conf'] < min_md_output_confidence:
+             detection_out['conf'] = min_md_output_confidence
+         payload_detections.append(detection_out)
+
+     prediction_string = wi_result_to_prediction_string(wi_result)
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = prediction_string
+     prediction['prediction_score'] = prediction_score
+
+     classifications = {}
+     classifications['classes'] = [prediction_string]
+     classifications['scores'] = [prediction_score]
+
+     prediction['classifications'] = classifications
+     prediction['detections'] = payload_detections
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = wi_result['image_id']
+     prediction['project_id'] = str(wi_result['project_id'])
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
+
+ def generate_blank_prediction_payload(data_file_id,
+                                       project_id,
+                                       blank_confidence=0.9,
+                                       model_version='3.1.2',
+                                       prediction_source='manual_update'):
+     """
+     Generate a payload that will set a single image to the blank classification, with
+     no detections. Suitable for upload via push_results_for_images.
+
+     Args:
+         data_file_id (str): unique identifier for this image used in the WI DB
+         project_id (int): WI project ID
+         blank_confidence (float, optional): confidence value to associate with this
+             prediction
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = blank_prediction_string
+     prediction['prediction_score'] = blank_confidence
+     prediction['classifications'] = {}
+     prediction['classifications']['classes'] = [blank_prediction_string]
+     prediction['classifications']['scores'] = [blank_confidence]
+     prediction['detections'] = []
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = data_file_id
+     prediction['project_id'] = project_id
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
+
+ def generate_no_cv_result_payload(data_file_id,
+                                   project_id,
+                                   no_cv_confidence=0.9,
+                                   model_version='3.1.2',
+                                   prediction_source='manual_update'):
+     """
+     Generate a payload that will set a single image to the "no CV result" classification,
+     with no detections. Suitable for uploading via push_results_for_images.
+
+     Args:
+         data_file_id (str): unique identifier for this image used in the WI DB
+         project_id (int): WI project ID
+         no_cv_confidence (float, optional): confidence value to associate with this
+             prediction
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = no_cv_result_prediction_string
+     prediction['prediction_score'] = no_cv_confidence
+     prediction['classifications'] = {}
+     prediction['classifications']['classes'] = [no_cv_result_prediction_string]
+     prediction['classifications']['scores'] = [no_cv_confidence]
+     prediction['detections'] = []
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = data_file_id
+     prediction['project_id'] = project_id
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
+
+ def generate_payload_for_prediction_string(data_file_id,
+                                            project_id,
+                                            prediction_string,
+                                            prediction_confidence=0.8,
+                                            detections=None,
+                                            model_version='3.1.2',
+                                            prediction_source='manual_update'):
+     """
+     Generate a payload that will set a single image to a particular prediction, optionally
+     including detections. Suitable for uploading via push_results_for_images.
+
+     Args:
+         data_file_id (str): unique identifier for this image used in the WI DB
+         project_id (int): WI project ID
+         prediction_string (str): WI-formatted prediction string to include in the payload
+         prediction_confidence (float, optional): confidence value to associate with this
+             prediction
+         detections (list, optional): list of MD-formatted detection dicts, with fields
+             ['category'] and ['conf']
+         model_version (str, optional): model version string to include in the payload
+         prediction_source (str, optional): prediction source string to include in the payload
+
+     Returns:
+         dict: dictionary suitable for uploading via push_results_for_images
+     """
+
+     assert is_valid_prediction_string(prediction_string), \
+         'Invalid prediction string: {}'.format(prediction_string)
+
+     payload_detections = []
+
+     if detections is not None:
+         # detection = detections[0]
+         for detection in detections:
+             detection_out = detection.copy()
+             detection_out['label'] = md_category_id_to_name[detection['category']]
+             if detection_out['conf'] < min_md_output_confidence:
+                 detection_out['conf'] = min_md_output_confidence
+             payload_detections.append(detection_out)
+
+     prediction = {}
+     prediction['ignore_data_file_checks'] = True
+     prediction['prediction'] = prediction_string
+     prediction['prediction_score'] = prediction_confidence
+     prediction['classifications'] = {}
+     prediction['classifications']['classes'] = [prediction_string]
+     prediction['classifications']['scores'] = [prediction_confidence]
+     prediction['detections'] = payload_detections
+     prediction['model_version'] = model_version
+     prediction['prediction_source'] = prediction_source
+     prediction['data_file_id'] = data_file_id
+     prediction['project_id'] = project_id
+
+     payload = {}
+     payload['predictions'] = [prediction]
+
+     return payload
+
+
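For context, a minimal sketch: build a payload for one image and validate it locally (see validate_payload below) before pushing. The data_file_id is hypothetical; the prediction string is the "bird" class from sample_update_payload above.

    payload = generate_payload_for_prediction_string(
        data_file_id='2ea1d2b2-7f84-43f9-af1f-8be0e69c7015',
        project_id='1234',
        prediction_string='b1352069-a39c-4a84-a949-60044271c0c1;aves;;;;;bird',
        prediction_confidence=0.85)
    assert validate_payload(payload)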
+ def validate_payload(payload):
+     """
+     Verifies that the dict [payload] is compatible with the ProcessCVResponse() API. Throws an
+     error if [payload] is invalid.
+
+     Args:
+         payload (dict): payload in the format expected by push_results_for_images.
+
+     Returns:
+         bool: successful validation; this is just future-proofing, currently never returns False
+     """
+
+     assert isinstance(payload,dict)
+     assert len(payload.keys()) == 1 and 'predictions' in payload
+
+     # prediction = payload['predictions'][0]
+     for prediction in payload['predictions']:
+
+         assert 'project_id' in prediction
+         if not isinstance(prediction['project_id'],int):
+             _ = int(prediction['project_id'])
+         assert 'ignore_data_file_checks' in prediction and \
+             isinstance(prediction['ignore_data_file_checks'],bool)
+         assert 'prediction' in prediction and \
+             isinstance(prediction['prediction'],str) and \
+             len(prediction['prediction'].split(';')) == 7
+         assert 'prediction_score' in prediction and \
+             isinstance(prediction['prediction_score'],float)
+         assert 'model_version' in prediction and \
+             isinstance(prediction['model_version'],str)
+         assert 'data_file_id' in prediction and \
+             isinstance(prediction['data_file_id'],str) and \
+             len(prediction['data_file_id']) == 36
+         assert 'classifications' in prediction and \
+             isinstance(prediction['classifications'],dict)
+         classifications = prediction['classifications']
+         assert 'classes' in classifications and isinstance(classifications['classes'],list)
+         assert 'scores' in classifications and isinstance(classifications['scores'],list)
+         assert len(classifications['classes']) == len(classifications['scores'])
+         for c in classifications['classes']:
+             assert is_valid_prediction_string(c)
+         for score in classifications['scores']:
+             assert isinstance(score,float) and score >= 0 and score <= 1.0
+         assert 'detections' in prediction and isinstance(prediction['detections'],list)
+
+         for detection in prediction['detections']:
+
+             assert isinstance(detection,dict)
+             assert 'category' in detection and detection['category'] in ('1','2','3')
+             assert 'label' in detection and detection['label'] in ('animal','person','vehicle')
+             assert 'conf' in detection and \
+                 isinstance(detection['conf'],float) and \
+                 detection['conf'] >= 0 and detection['conf'] <= 1.0
+             assert 'bbox' in detection and \
+                 isinstance(detection['bbox'],list) and \
+                 len(detection['bbox']) == 4
+
+         # ...for each detection
+
+     # ...for each prediction
+
+     return True
+
+ # ...def validate_payload(...)
+
+
+ #%% Functions for working with WI results (from the API or from download bundles)
+
+ def wi_result_to_prediction_string(r):
+     """
+     Convert the dict [r] - typically loaded from a row in a downloaded .csv file - to
+     a valid prediction string, e.g.:
+
+     1f689929-883d-4dae-958c-3d57ab5b6c16;;;;;;animal
+     90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent
+
+     Args:
+         r (dict): dict containing WI prediction information, with at least the fields
+             specified in wi_result_fields.
+
+     Returns:
+         str: the result in [r], as a semicolon-delimited prediction string
+     """
+
+     values = []
+     for field in wi_result_fields:
+         if isinstance(r[field],str):
+             values.append(r[field].lower())
+         else:
+             assert isinstance(r[field],float) and np.isnan(r[field])
+             values.append('')
+     s = ';'.join(values)
+     assert is_valid_prediction_string(s)
+     return s
+
+
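For context, a small worked example: NaN taxonomy fields (how pandas renders empty .csv cells) map to empty slots in the prediction string, and strings are lowercased.

    wi_result = {'wi_taxon_id': '90d950db-2106-4bd9-a4c1-777604c3eada',
                 'class': 'Mammalia', 'order': 'Rodentia',
                 'family': np.nan, 'genus': np.nan, 'species': np.nan,
                 'common_name': 'Rodent'}
    assert wi_result_to_prediction_string(wi_result) == \
        '90d950db-2106-4bd9-a4c1-777604c3eada;mammalia;rodentia;;;;rodent'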
+ def record_is_unidentified(record):
+     """
+     A record is considered "unidentified" if the "identified_by" field is either NaN or
+     "Computer vision".
+
+     Args:
+         record (dict): dict representing a WI result loaded from a .csv file, with at least the
+             field "identified_by"
+
+     Returns:
+         bool: True if the "identified_by" field is either NaN or a string indicating that this
+             record has not yet been human-reviewed.
+     """
+
+     identified_by = record['identified_by']
+     assert isinstance(identified_by,float) or isinstance(identified_by,str)
+     if isinstance(identified_by,float):
+         assert np.isnan(identified_by)
+         return True
+     else:
+         return identified_by == 'Computer vision'
+
+
+ def record_lists_are_identical(records_0,records_1,verbose=False):
+     """
+     Takes two lists of records in the form returned by read_images_from_download_bundle and
+     determines whether they are the same.
+
+     Args:
+         records_0 (list of dict): the first list of records to compare
+         records_1 (list of dict): the second list of records to compare
+         verbose (bool, optional): enable additional debug output
+
+     Returns:
+         bool: True if the two lists are identical
+     """
+
+     if len(records_0) != len(records_1):
+         return False
+
+     # i_record = 0; record_0 = records_0[i_record]
+     for i_record,record_0 in enumerate(records_0):
+         record_1 = records_1[i_record]
+         assert set(record_0.keys()) == set(record_1.keys())
+         for k in record_0.keys():
+             if not compare_values_nan_equal(record_0[k],record_1[k]):
+                 if verbose:
+                     print('Image ID: {} ({})\nRecord 0/{}: {}\nRecord 1/{}: {}'.format(
+                         record_0['image_id'],record_1['image_id'],
+                         k,record_0[k],k,record_1[k]))
+                 return False
+
+     return True
+
+
+ #%% Validate constants
+
+ # This is executed at the time this module gets imported.
+
+ blank_payload = generate_blank_prediction_payload('70ede9c6-d056-4dd1-9a0b-3098d8113e0e','1234')
+ validate_payload(sample_update_payload)
+ validate_payload(blank_payload)
+