megadetector-10.0.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (147)
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +701 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +563 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +192 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +665 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +984 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2172 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1604 -0
  81. megadetector/detection/run_tiled_inference.py +1044 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1943 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2140 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +231 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2872 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1766 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1973 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +498 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.15.dist-info/METADATA +115 -0
  144. megadetector-10.0.15.dist-info/RECORD +147 -0
  145. megadetector-10.0.15.dist-info/WHEEL +5 -0
  146. megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.15.dist-info/top_level.txt +1 -0
megadetector/classification/detect_and_crop.py
@@ -0,0 +1,853 @@
1
+ """
2
+
3
+ detect_and_crop.py
4
+
5
+ Run MegaDetector on images via Batch API, then save crops of the detected
6
+ bounding boxes.
7
+
8
+ The input to this script is a "queried images" JSON file, whose keys are paths
9
+ to images and values are dicts containing information relevant for training
10
+ a classifier, including labels and (optionally) ground-truth bounding boxes.
11
+ The image paths are in the format `<dataset-name>/<blob-name>` where we assume
12
+ that the dataset name does not contain '/'.
13
+
14
+ {
15
+ "caltech/cct_images/59f79901-23d2-11e8-a6a3-ec086b02610b.jpg": {
16
+ "dataset": "caltech",
17
+ "location": 13,
18
+ "class": "mountain_lion", # class from dataset
19
+ "bbox": [{"category": "animal",
20
+ "bbox": [0, 0.347, 0.237, 0.257]}], # ground-truth bbox
21
+ "label": ["mountain_lion"] # labels to use in classifier
22
+ },
23
+ "caltech/cct_images/59f5fe2b-23d2-11e8-a6a3-ec086b02610b.jpg": {
24
+ "dataset": "caltech",
25
+ "location": 13,
26
+ "class": "mountain_lion", # class from dataset
27
+ "label": ["mountain_lion"] # labels to use in classifier
28
+ },
29
+ ...
30
+ }
31
+
32
+ We assume that no image contains over 100 bounding boxes, and we always save
33
+ crops as RGB .jpg files for consistency. For each image, each bounding box is
34
+ cropped and saved to a file with a suffix "___cropXX.jpg" (ground truth bbox) or
35
+ "___cropXX_mdvY.Y.jpg" (detected bbox) added to the filename of the original
36
+ image. "XX" ranges from "00" to "99" and "Y.Y" indicates the MegaDetector
37
+ version. If an image has ground truth bounding boxes, we assume that they are
38
+ exhaustive--i.e., there are no other objects of interest, so we don't need to
39
+ run MegaDetector on the image. If an image does not have ground truth bounding
40
+ boxes, we run MegaDetector on the image and label the detected boxes in order
41
+ from 00 up to 99. Based on the given confidence threshold, we may skip saving
42
+ certain bounding box crops, but we still increment the bounding box number for
43
+ skipped boxes.
44
+
45
+ Example cropped image path (with ground truth bbox from MegaDB)
46
+
47
+ "path/to/crops/image.jpg___crop00.jpg"
48
+
49
+ Example cropped image path (with MegaDetector bbox)
50
+
51
+ "path/to/crops/image.jpg___crop00_mdv4.1.jpg"
52
+
53
+ By default, the images are cropped exactly per the given bounding box
54
+ coordinates. However, if square crops are desired, pass the --square-crops
55
+ flag. This will always generate a square crop whose size is the larger of the
56
+ bounding box width or height. In the case that the square crop boundaries exceed
57
+ the original image size, the crop is padded with 0s.
58
+
59
+ This script currently only supports running MegaDetector via the Batch Detection
60
+ API. See the classification README for instructions on running MegaDetector
61
+ locally. If running the Batch Detection API, set the following environment
62
+ variables for the Azure Blob Storage container in which we save the intermediate
63
+ task lists:
64
+
65
+ BATCH_DETECTION_API_URL # API URL
66
+ CLASSIFICATION_BLOB_STORAGE_ACCOUNT # storage account name
67
+ CLASSIFICATION_BLOB_CONTAINER # container name
68
+ CLASSIFICATION_BLOB_CONTAINER_WRITE_SAS # SAS token, without leading '?'
69
+ DETECTION_API_CALLER # allow-listed API caller
70
+
71
+ This script allows specifying a directory where MegaDetector outputs are cached
72
+ via the --detector-output-cache-dir argument. This directory must be
73
+ organized as:
74
+
75
+ <cache-dir>/<MegaDetector-version>/<dataset-name>.json
76
+
77
+ Example: If the `cameratrapssc/classifier-training` Azure blob storage
78
+ container is mounted to the local machine via blobfuse, it may be used as
79
+ a MegaDetector output cache directory by passing
80
+ "cameratrapssc/classifier-training/mdcache/"
81
+ as the value for --detector-output-cache-dir.
82
+
83
+ This script outputs either 1 or 3 files, depending on whether the Batch Detection API
84
+ is run:
85
+
86
+ - <output_dir>/detect_and_crop_log_{timestamp}.json
87
+ log of images missing detections and images that failed to properly
88
+ download and crop
89
+ - <output_dir>/batchapi_tasklists/{task_id}.json
90
+ (if --run-detector) task lists uploaded to the Batch Detection API
91
+ - <output_dir>/batchapi_response/{task_id}.json
92
+ (if --run-detector) task status responses for completed tasks
93
+
94
+ """
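A minimal sketch of the --square-crops behavior described above, assuming a PIL image and a normalized [x_min, y_min, width, height] box; this is an illustrative helper, not the package's implementation (the real cropping logic lives in crop_detections.py and may position the square differently):

    from PIL import Image

    def square_crop_sketch(img: Image.Image, bbox_norm: list) -> Image.Image:
        # bbox_norm: [x_min, y_min, width, height], normalized to [0, 1]
        img_w, img_h = img.size
        x = int(bbox_norm[0] * img_w)
        y = int(bbox_norm[1] * img_h)
        box_w = int(bbox_norm[2] * img_w)
        box_h = int(bbox_norm[3] * img_h)
        side = max(box_w, box_h)
        # grow the shorter dimension symmetrically around the original box
        x -= (side - box_w) // 2
        y -= (side - box_h) // 2
        # PIL fills any part of the crop box that falls outside the image with 0s
        return img.crop((x, y, x + side, y + side))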
95
+
96
+ #%% Imports
97
+
98
+ from __future__ import annotations
99
+
100
+ import argparse
101
+ from collections.abc import Collection, Iterable, Mapping, Sequence
102
+ from concurrent import futures
103
+ from datetime import datetime
104
+ import json
105
+ import os
106
+ import pprint
107
+ import time
108
+ from typing import Any, Optional
109
+
110
+ from azure.storage.blob import ContainerClient
111
+ import requests
112
+ from tqdm import tqdm
113
+
114
+ from api.batch_processing.data_preparation.prepare_api_submission import (
115
+ BatchAPIResponseError, Task, TaskStatus, divide_list_into_tasks)
116
+ from megadetector.classification.cache_batchapi_outputs import cache_detections
117
+ from megadetector.classification.crop_detections import load_and_crop
118
+ from megadetector.data_management.megadb import megadb_utils
119
+ from megadetector.utils import path_utils
120
+ from megadetector.utils import sas_blob_utils
121
+ from megadetector.utils import ct_utils
122
+
123
+
124
+ #%% Example usage
125
+
126
+ """
127
+ python detect_and_crop.py \
128
+ base_logdir/queried_images.json \
129
+ base_logdir \
130
+ --detector-output-cache-dir /path/to/classifier-training/mdcache \
131
+ --detector-version 4.1 \
132
+ --run-detector --resume-file base_logdir/resume.json \
133
+ --cropped-images-dir /path/to/crops --square-crops --threshold 0.9 \
134
+ --save-full-images --images-dir /path/to/images --threads 50
135
+ """
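If --run-detector is passed, the environment variables listed in the module docstring must be set before invoking the script; a minimal pre-flight check (illustrative only, not part of this module) might look like:

    import os

    REQUIRED_ENV_VARS = [
        'BATCH_DETECTION_API_URL',
        'CLASSIFICATION_BLOB_STORAGE_ACCOUNT',
        'CLASSIFICATION_BLOB_CONTAINER',
        'CLASSIFICATION_BLOB_CONTAINER_WRITE_SAS',
        'DETECTION_API_CALLER',
    ]
    missing = [v for v in REQUIRED_ENV_VARS if v not in os.environ]
    if missing:
        raise RuntimeError(f'Missing required environment variables: {missing}')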
136
+
137
+
138
+ #%% Main function
139
+
140
+ def main(queried_images_json_path: str,
141
+ output_dir: str,
142
+ detector_version: str,
143
+ detector_output_cache_base_dir: str,
144
+ run_detector: bool,
145
+ resume_file_path: Optional[str],
146
+ cropped_images_dir: Optional[str],
147
+ save_full_images: bool,
148
+ square_crops: bool,
149
+ check_crops_valid: bool,
150
+ confidence_threshold: float,
151
+ images_dir: Optional[str],
152
+ threads: int) -> None:
153
+ """
154
+ Args:
155
+ queried_images_json_path: str, path to output of json_validator.py
156
+ detector_version: str, detector version string, e.g., '4.1',
157
+ see {batch_detection_api_url}/supported_model_versions,
158
+ determines the subfolder of detector_output_cache_base_dir in
159
+ which to find and save detector outputs
160
+ detector_output_cache_base_dir: str, path to local directory
161
+ where detector outputs are cached, 1 JSON file per dataset
162
+ cropped_images_dir: str, path to local directory for saving crops of
163
+ bounding boxes
164
+ run_detector: bool, whether to run Batch Detection API, or to skip
165
+ running the detector entirely
166
+ output_dir: str, path to directory to save outputs, see module docstring
167
+ save_full_images: bool, whether to save downloaded images to images_dir,
168
+ images_dir must be given if save_full_images=True
169
+ square_crops: bool, whether to crop bounding boxes as squares
170
+ check_crops_valid: bool, whether to load each crop to ensure the file is
171
+ valid (i.e., not truncated)
172
+ confidence_threshold: float, only crop bounding boxes above this value
173
+ images_dir: optional str, path to local directory where images are saved
174
+ threads: int, number of threads to use for downloading images
175
+ resume_file_path: optional str, path to save JSON file with list of info
176
+ dicts on running tasks, or to resume from running tasks, only used
177
+ if run_detector=True
178
+ """
179
+
180
+ # This dictionary will get written out at the end of this process; store
181
+ # diagnostic variables here
182
+ log: dict[str, Any] = {}
183
+
184
+ # error checking
185
+ assert 0 <= confidence_threshold <= 1
186
+ if save_full_images:
187
+ assert images_dir is not None
188
+ if not os.path.exists(images_dir):
189
+ os.makedirs(images_dir, exist_ok=True)
190
+ print(f'Created images_dir at {images_dir}')
191
+
192
+ with open(queried_images_json_path, 'r') as f:
193
+ js = json.load(f)
194
+ detector_output_cache_dir = os.path.join(
195
+ detector_output_cache_base_dir, f'v{detector_version}')
196
+ if not os.path.exists(detector_output_cache_dir):
197
+ os.makedirs(detector_output_cache_dir)
198
+ print(f'Created directory at {detector_output_cache_dir}')
199
+ images_without_ground_truth_bbox = [k for k in js if 'bbox' not in js[k]]
200
+ images_to_detect, detection_cache, categories = filter_detected_images(
201
+ potential_images_to_detect=images_without_ground_truth_bbox,
202
+ detector_output_cache_dir=detector_output_cache_dir)
203
+ print(f'{len(images_to_detect)} images not in detection cache')
204
+
205
+ if run_detector:
206
+ log['images_submitted_for_detection'] = images_to_detect
207
+
208
+ assert resume_file_path is not None
209
+ assert not os.path.isdir(resume_file_path)
210
+ batch_detection_api_url = os.environ['BATCH_DETECTION_API_URL']
211
+
212
+ if os.path.exists(resume_file_path):
213
+ tasks_by_dataset = resume_tasks(
214
+ resume_file_path,
215
+ batch_detection_api_url=batch_detection_api_url)
216
+ else:
217
+ task_lists_dir = os.path.join(output_dir, 'batchapi_tasklists')
218
+ tasks_by_dataset = submit_batch_detection_api(
219
+ images_to_detect=images_to_detect,
220
+ task_lists_dir=task_lists_dir,
221
+ detector_version=detector_version,
222
+ account=os.environ['CLASSIFICATION_BLOB_STORAGE_ACCOUNT'],
223
+ container=os.environ['CLASSIFICATION_BLOB_CONTAINER'],
224
+ sas_token=os.environ['CLASSIFICATION_BLOB_CONTAINER_WRITE_SAS'],
225
+ caller=os.environ['DETECTION_API_CALLER'],
226
+ batch_detection_api_url=batch_detection_api_url,
227
+ resume_file_path=resume_file_path)
228
+
229
+ wait_for_tasks(tasks_by_dataset, detector_output_cache_dir,
230
+ output_dir=output_dir)
231
+
232
+ # refresh detection cache
233
+ print('Refreshing detection cache...')
234
+ images_to_detect, detection_cache, categories = filter_detected_images(
235
+ potential_images_to_detect=images_without_ground_truth_bbox,
236
+ detector_output_cache_dir=detector_output_cache_dir)
237
+ print(f'{len(images_to_detect)} images not in detection cache')
238
+
239
+ log['images_missing_detections'] = images_to_detect
240
+
241
+ if cropped_images_dir is not None:
242
+
243
+ images_failed_dload_crop, num_downloads, num_crops = download_and_crop(
244
+ queried_images_json=js,
245
+ detection_cache=detection_cache,
246
+ detection_categories=categories,
247
+ detector_version=detector_version,
248
+ cropped_images_dir=cropped_images_dir,
249
+ confidence_threshold=confidence_threshold,
250
+ save_full_images=save_full_images,
251
+ square_crops=square_crops,
252
+ check_crops_valid=check_crops_valid,
253
+ images_dir=images_dir,
254
+ threads=threads,
255
+ images_missing_detections=images_to_detect)
256
+ log['images_failed_download_or_crop'] = images_failed_dload_crop
257
+ log['num_new_downloads'] = num_downloads
258
+ log['num_new_crops'] = num_crops
259
+
260
+ print(f'{len(images_to_detect)} images with missing detections.')
261
+ if cropped_images_dir is not None:
262
+ print(f'{len(images_failed_dload_crop)} images failed to download or '
263
+ 'crop.')
264
+
265
+ # save log of bad images
266
+ date = datetime.now().strftime('%Y%m%d_%H%M%S') # e.g., '20200722_110816'
267
+ log_path = os.path.join(output_dir, f'detect_and_crop_log_{date}.json')
268
+ ct_utils.write_json(log_path, log)
269
+
270
+
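For reference, the detect_and_crop_log_{timestamp}.json file written at the end of main() holds the keys assigned to the log dict above; an illustrative example with placeholder values ("images_submitted_for_detection" is only present when --run-detector is used, and the last three keys only when --cropped-images-dir is given):

    {
      "images_submitted_for_detection": ["caltech/cct_images/....jpg"],
      "images_missing_detections": [],
      "images_failed_download_or_crop": [],
      "num_new_downloads": 0,
      "num_new_crops": 0
    }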
271
+ #%% Support functions
272
+
273
+ def load_detection_cache(detector_output_cache_dir: str,
274
+ datasets: Collection[str]) -> tuple[
275
+ dict[str, dict[str, dict[str, Any]]],
276
+ dict[str, str]
277
+ ]:
278
+ """
279
+ Loads cached detections for the given datasets. Returns empty dictionaries
280
+ for any dataset whose cache file does not exist.
281
+
282
+ Args:
283
+ detector_output_cache_dir: str, path to local directory where detector
284
+ outputs are cached, 1 JSON file per dataset
285
+ datasets: list of str, names of datasets
286
+
287
+ Returns:
288
+ detection_cache: dict, maps dataset name to dict, which maps
289
+ image file to corresponding entry in 'images' list from the
290
+ Batch Detection API output. detection_cache[ds] is an empty dict
291
+ if no cached detections were found for the given dataset ds.
292
+ detection_categories: dict, maps str category ID to str category name
293
+ """
294
+
295
+ # cache of Detector outputs: dataset name => {img_path => detection_dict}
296
+ detection_cache = {}
297
+ detection_categories: dict[str, str] = {}
298
+
299
+ pbar = tqdm(datasets)
300
+ for ds in pbar:
301
+ pbar.set_description(f'Loading dataset {ds} into detection cache')
302
+ cache_path = os.path.join(detector_output_cache_dir, f'{ds}.json')
303
+ if os.path.exists(cache_path):
304
+ with open(cache_path, 'r') as f:
305
+ js = json.load(f)
306
+ detection_cache[ds] = {img['file']: img for img in js['images']}
307
+ if len(detection_categories) == 0:
308
+ detection_categories = js['detection_categories']
309
+ assert detection_categories == js['detection_categories']
310
+ else:
311
+ tqdm.write(f'No detection cache found for dataset {ds}')
312
+ detection_cache[ds] = {}
313
+ return detection_cache, detection_categories
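Assuming a cache file at <cache-dir>/v4.1/caltech.json in the standard MegaDetector batch output format, the returned structures can be indexed roughly as follows (illustrative values only):

    detection_cache, detection_categories = load_detection_cache(
        '/path/to/mdcache/v4.1', datasets=['caltech'])

    # detection_cache['caltech'] maps each blob name to its entry from the cached
    # output, e.g.:
    # {'file': 'cct_images/59f79901-23d2-11e8-a6a3-ec086b02610b.jpg',
    #  'detections': [{'category': '1', 'conf': 0.94,
    #                  'bbox': [0.0, 0.347, 0.237, 0.257]}]}
    #
    # detection_categories maps category IDs to names, typically
    # {'1': 'animal', '2': 'person', '3': 'vehicle'}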
314
+
315
+
316
+ def filter_detected_images(
317
+ potential_images_to_detect: Iterable[str],
318
+ detector_output_cache_dir: str
319
+ ) -> tuple[list[str],
320
+ dict[str, dict[str, dict[str, Any]]],
321
+ dict[str, str]]:
322
+ """
323
+ Checks image paths against cached Detector outputs, and prepares
324
+ the SAS URIs for each image not in the cache.
325
+
326
+ Args:
327
+ potential_images_to_detect: list of str, paths to images that do not
328
+ have ground truth bounding boxes, each path has format
329
+ <dataset-name>/<img-filename>, where <img-filename> is the blob name
330
+ detector_output_cache_dir: str, path to local directory where detector
331
+ outputs are cached, 1 JSON file per dataset
332
+
333
+ Returns:
334
+ images_to_detect: list of str, paths to images not in the detector
335
+ output cache, with the format <dataset-name>/<img-filename>
336
+ detection_cache: dict, maps str dataset name to dict,
337
+ detection_cache[dataset_name] maps image file to the corresponding
338
+ entry in the 'images' list from the Batch Detection API output
339
+ detection_categories: dict, maps str category ID to str category name,
340
+ empty dict if no cached detections are found
341
+ """
342
+
343
+ datasets = set(img_path[:img_path.find('/')]
344
+ for img_path in potential_images_to_detect)
345
+ detection_cache, detection_categories = load_detection_cache(
346
+ detector_output_cache_dir, datasets)
347
+
348
+ images_to_detect = []
349
+ for img_path in potential_images_to_detect:
350
+ # img_path: <dataset-name>/<img-filename>
351
+ ds, img_file = img_path.split('/', maxsplit=1)
352
+ if img_file not in detection_cache[ds]:
353
+ images_to_detect.append(img_path)
354
+
355
+ return images_to_detect, detection_cache, detection_categories
356
+
357
+
358
+ def split_images_list_by_dataset(images_to_detect: Iterable[str]
359
+ ) -> dict[str, list[str]]:
360
+ """
361
+ Args:
362
+ images_to_detect: list of str, image paths with the format
363
+ <dataset-name>/<image-filename>
364
+
365
+ Returns: dict, maps dataset name to a list of image paths
366
+ """
367
+
368
+ images_by_dataset: dict[str, list[str]] = {}
369
+ for img_path in images_to_detect:
370
+ dataset = img_path[:img_path.find('/')]
371
+ if dataset not in images_by_dataset:
372
+ images_by_dataset[dataset] = []
373
+ images_by_dataset[dataset].append(img_path)
374
+ return images_by_dataset
375
+
376
+
377
+ def submit_batch_detection_api(images_to_detect: Iterable[str],
378
+ task_lists_dir: str,
379
+ detector_version: str,
380
+ account: str,
381
+ container: str,
382
+ sas_token: str,
383
+ caller: str,
384
+ batch_detection_api_url: str,
385
+ resume_file_path: str
386
+ ) -> dict[str, list[Task]]:
387
+ """
388
+ Args:
389
+ images_to_detect: list of str, image paths with the format
390
+ <dataset-name>/<image-filename>
391
+ task_lists_dir: str, path to local directory for saving JSON files
392
+ each containing a list of image URLs corresponding to an API task
393
+ detector_version: str, MegaDetector version string, e.g., '4.1',
394
+ see {batch_detection_api_url}/supported_model_versions
395
+ account: str, Azure Storage account name
396
+ container: str, Azure Blob Storage container name, where the task lists
397
+ will be uploaded
398
+ sas_token: str, SAS token with write permissions for the container
399
+ caller: str, allow-listed caller
400
+ batch_detection_api_url: str, URL to batch detection API
401
+ resume_file_path: str, path to save resume file
402
+
403
+ Returns: dict, maps str dataset name to list of Task objects
404
+ """
405
+
406
+ filtered_images_to_detect = [
407
+ x for x in images_to_detect if path_utils.is_image_file(x)]
408
+ not_images = set(images_to_detect) - set(filtered_images_to_detect)
409
+ if len(not_images) == 0:
410
+ print('Good! All image files have valid file extensions.')
411
+ else:
412
+ print(f'Skipping {len(not_images)} files with non-image extensions:')
413
+ pprint.pprint(sorted(not_images))
414
+ images_to_detect = filtered_images_to_detect
415
+
416
+ datasets_table = megadb_utils.MegadbUtils().get_datasets_table()
417
+
418
+ images_by_dataset = split_images_list_by_dataset(images_to_detect)
419
+ tasks_by_dataset = {}
420
+ for dataset, image_paths in images_by_dataset.items():
421
+ # get SAS URL for images container
422
+ images_sas_token = datasets_table[dataset]['container_sas_key']
423
+ if images_sas_token[0] == '?':
424
+ images_sas_token = images_sas_token[1:]
425
+ images_container_url = sas_blob_utils.build_azure_storage_uri(
426
+ account=datasets_table[dataset]['storage_account'],
427
+ container=datasets_table[dataset]['container'],
428
+ sas_token=images_sas_token)
429
+
430
+ # strip image paths of dataset name
431
+ image_blob_names = [path[path.find('/') + 1:] for path in image_paths]
432
+
433
+ tasks_by_dataset[dataset] = submit_batch_detection_api_by_dataset(
434
+ dataset=dataset,
435
+ image_blob_names=image_blob_names,
436
+ images_container_url=images_container_url,
437
+ task_lists_dir=task_lists_dir,
438
+ detector_version=detector_version,
439
+ account=account, container=container, sas_token=sas_token,
440
+ caller=caller, batch_detection_api_url=batch_detection_api_url)
441
+
442
+ # save list of dataset names and task IDs for resuming
443
+ resume_json = [
444
+ {
445
+ 'dataset': dataset,
446
+ 'task_name': task.name,
447
+ 'task_id': task.id,
448
+ 'local_images_list_path': task.local_images_list_path
449
+ }
450
+ for dataset in tasks_by_dataset
451
+ for task in tasks_by_dataset[dataset]
452
+ ]
453
+ ct_utils.write_json(resume_file_path, resume_json)
454
+ return tasks_by_dataset
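For reference, the resume file written just above is a JSON list with one entry per submitted task, e.g. (placeholder values):

    [
      {
        "dataset": "caltech",
        "task_name": "detect_for_classifier_caltech_20200722_110816_task00",
        "task_id": "<task-id>",
        "local_images_list_path": "<task_lists_dir>/<task-list-file>.json"
      }
    ]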
455
+
456
+
457
+ def submit_batch_detection_api_by_dataset(
458
+ dataset: str,
459
+ image_blob_names: Sequence[str],
460
+ images_container_url: str,
461
+ task_lists_dir: str,
462
+ detector_version: str,
463
+ account: str,
464
+ container: str,
465
+ sas_token: str,
466
+ caller: str,
467
+ batch_detection_api_url: str
468
+ ) -> list[Task]:
469
+ """
470
+ Args:
471
+ dataset: str, name of dataset
472
+ image_blob_names: list of str, image blob names from the same dataset
473
+ images_container_url: str, URL to blob storage container where images
474
+ from this dataset are stored, including SAS token with read
475
+ permissions if container is not public
476
+ **see submit_batch_detection_api() for description of other args
477
+
478
+ Returns: list of Task objects
479
+ """
480
+
481
+ os.makedirs(task_lists_dir, exist_ok=True)
482
+
483
+ date = datetime.now().strftime('%Y%m%d_%H%M%S') # e.g., '20200722_110816'
484
+ task_list_base_filename = f'task_list_{dataset}_{date}.json'
485
+
486
+ task_list_paths, _ = divide_list_into_tasks(
487
+ file_list=image_blob_names,
488
+ save_path=os.path.join(task_lists_dir, task_list_base_filename))
489
+
490
+ # complete task name: 'detect_for_classifier_caltech_20200722_110816_task01'
491
+ task_name_template = 'detect_for_classifier_{dataset}_{date}_task{n:>02d}'
492
+ tasks: list[Task] = []
493
+ for i, task_list_path in enumerate(task_list_paths):
494
+ task = Task(
495
+ name=task_name_template.format(dataset=dataset, date=date, n=i),
496
+ images_list_path=task_list_path, api_url=batch_detection_api_url)
497
+ task.upload_images_list(
498
+ account=account, container=container, sas_token=sas_token)
499
+ task.generate_api_request(
500
+ caller=caller,
501
+ input_container_url=images_container_url,
502
+ model_version=detector_version)
503
+ print(f'Submitting task for: {task_list_path}')
504
+ task.submit()
505
+ print(f'- task ID: {task.id}')
506
+ tasks.append(task)
507
+
508
+ # HACK! Sleep for 10s between task submissions in the hopes that it
509
+ # decreases the chance of backend JSON "database" corruption
510
+ time.sleep(10)
511
+ return tasks
512
+
513
+
514
+ def resume_tasks(resume_file_path: str, batch_detection_api_url: str
515
+ ) -> dict[str, list[Task]]:
516
+ """
517
+ Args:
518
+ resume_file_path: str, path to resume file with list of info dicts on
519
+ running tasks
520
+ batch_detection_api_url: str, URL to batch detection API
521
+
522
+ Returns: dict, maps str dataset name to list of Task objects
523
+ """
524
+
525
+ with open(resume_file_path, 'r') as f:
526
+ resume_json = json.load(f)
527
+
528
+ tasks_by_dataset: dict[str, list[Task]] = {}
529
+ for info_dict in resume_json:
530
+ dataset = info_dict['dataset']
531
+ if dataset not in tasks_by_dataset:
532
+ tasks_by_dataset[dataset] = []
533
+ task = Task(name=info_dict['task_name'],
534
+ task_id=info_dict['task_id'],
535
+ images_list_path=info_dict['local_images_list_path'],
536
+ validate=False,
537
+ api_url=batch_detection_api_url)
538
+ tasks_by_dataset[dataset].append(task)
539
+ return tasks_by_dataset
540
+
541
+
542
+ def wait_for_tasks(tasks_by_dataset: Mapping[str, Iterable[Task]],
543
+ detector_output_cache_dir: str,
544
+ output_dir: Optional[str] = None,
545
+ poll_interval: int = 120) -> None:
546
+ """
547
+ Waits for the Batch Detection API tasks to finish running.
548
+
549
+ For jobs that finish successfully, merges the output with cached detector
550
+ outputs.
551
+
552
+ Args:
553
+ tasks_by_dataset: dict, maps str dataset name to list of Task objects
554
+ detector_output_cache_dir: str, path to local directory where detector
555
+ outputs are cached, 1 JSON file per dataset, directory must
556
+ already exist
557
+ output_dir: optional str, task status responses for completed tasks are
558
+ saved to <output_dir>/batchapi_response/{task_id}.json
559
+ poll_interval: int, # of seconds between pinging the task status API
560
+ """
561
+
562
+ remaining_tasks: list[tuple[str, Task]] = [
563
+ (dataset, task) for dataset, tasks in tasks_by_dataset.items()
564
+ for task in tasks]
565
+
566
+ progbar = tqdm(total=len(remaining_tasks))
567
+ while True:
568
+ new_remaining_tasks = []
569
+ for dataset, task in remaining_tasks:
570
+ try:
571
+ task.check_status()
572
+ except (BatchAPIResponseError, requests.HTTPError) as e:
573
+ exception_type = type(e).__name__
574
+ tqdm.write(f'Error in checking status of task {task.id}: '
575
+ f'({exception_type}) {e}')
576
+ tqdm.write(f'Skipping task {task.id}.')
577
+ continue
578
+
579
+ # task still running => continue
580
+ if task.status == TaskStatus.RUNNING:
581
+ new_remaining_tasks.append((dataset, task))
582
+ continue
583
+
584
+ progbar.update(1)
585
+ tqdm.write(f'Task {task.id} stopped with status {task.status}')
586
+
587
+ if task.status in [TaskStatus.PROBLEM, TaskStatus.FAILED]:
588
+ tqdm.write('API response:')
589
+ tqdm.write(str(task.response))
590
+ continue
591
+
592
+ # task finished successfully, save response to disk
593
+ assert task.status == TaskStatus.COMPLETED
594
+ if output_dir is not None:
595
+ save_dir = os.path.join(output_dir, 'batchapi_response')
596
+ if not os.path.exists(save_dir):
597
+ tqdm.write(f'Creating API output dir: {save_dir}')
598
+ os.makedirs(save_dir)
599
+ ct_utils.write_json(os.path.join(save_dir, f'{task.id}.json'), task.response)
600
+ message = task.response['Status']['message']
601
+ num_failed_shards = message['num_failed_shards']
602
+ if num_failed_shards != 0:
603
+ tqdm.write(f'Task {task.id} completed with {num_failed_shards} '
604
+ 'failed shards.')
605
+
606
+ detections_url = message['output_file_urls']['detections']
607
+ if task.id not in detections_url:
608
+ tqdm.write('Invalid detections URL in response. Skipping task.')
609
+ continue
610
+
611
+ detections = requests.get(detections_url).json()
612
+ msg = cache_detections(
613
+ detections=detections, dataset=dataset,
614
+ detector_output_cache_dir=detector_output_cache_dir)
615
+ tqdm.write(msg)
616
+
617
+ remaining_tasks = new_remaining_tasks
618
+ if len(remaining_tasks) == 0:
619
+ break
620
+ tqdm.write(f'Sleeping for {poll_interval} seconds...')
621
+ time.sleep(poll_interval)
622
+
623
+ progbar.close()
624
+
625
+
626
+ def download_and_crop(
627
+ queried_images_json: Mapping[str, Mapping[str, Any]],
628
+ detection_cache: Mapping[str, Mapping[str, Mapping[str, Any]]],
629
+ detection_categories: Mapping[str, str],
630
+ detector_version: str,
631
+ cropped_images_dir: str,
632
+ confidence_threshold: float,
633
+ save_full_images: bool,
634
+ square_crops: bool,
635
+ check_crops_valid: bool,
636
+ images_dir: Optional[str] = None,
637
+ threads: int = 1,
638
+ images_missing_detections: Optional[Iterable[str]] = None
639
+ ) -> tuple[list[str], int, int]:
640
+ """
641
+ Saves crops to a file with the same name as the original image with an
642
+ additional suffix appended, starting with 3 underscores:
643
+ - if image has ground truth bboxes: "___cropXX.jpg", where "XX" indicates
644
+ the bounding box index
645
+ - if image has bboxes from MegaDetector: "___cropXX_mdvY.Y.jpg", where
646
+ "Y.Y" indicates the MegaDetector version
647
+ See module docstring for more info and examples.
648
+
649
+ Note: this function is very similar to the "download_and_crop()" function in
650
+ crop_detections.py. The main difference is that this function uses
651
+ MegaDB to look up Azure Storage container information for images based
652
+ on the dataset, whereas the crop_detections.py version has no concept
653
+ of a "dataset" and "ground-truth" bounding boxes from MegaDB.
654
+
655
+ Args:
656
+ queried_images_json: dict, represents JSON output of json_validator.py,
657
+ all images in queried_images_json are assumed to have either ground
658
+ truth or cached detected bounding boxes unless
659
+ images_missing_detections is given
660
+ detection_cache: dict, dataset_name => {img_path => detection_dict}
661
+ detector_version: str, detector version string, e.g., '4.1'
662
+ cropped_images_dir: str, path to folder where cropped images are saved
663
+ confidence_threshold: float, only crop bounding boxes above this value
664
+ save_full_images: bool, whether to save downloaded images to images_dir,
665
+ images_dir must be given and must exist if save_full_images=True
666
+ square_crops: bool, whether to crop bounding boxes as squares
667
+ check_crops_valid: bool, whether to load each crop to ensure the file is
668
+ valid (i.e., not truncated)
669
+ images_dir: optional str, path to folder where full images are saved
670
+ threads: int, number of threads to use for downloading images
671
+ images_missing_detections: optional list of str, image files to skip
672
+ because they have no ground truth or cached detected bounding boxes
673
+
674
+ Returns: tuple of (list of str, int, int): images that failed to download or
675
+ crop properly, the number of images newly downloaded, and the number of new crops saved
676
+ """
677
+
678
+ # error checking before we download and crop any images
679
+ valid_img_paths = set(queried_images_json.keys())
680
+ if images_missing_detections is not None:
681
+ valid_img_paths -= set(images_missing_detections)
682
+ for img_path in valid_img_paths:
683
+ info_dict = queried_images_json[img_path]
684
+ ds, img_file = img_path.split('/', maxsplit=1)
685
+ assert ds == info_dict['dataset']
686
+
687
+ if 'bbox' in info_dict: # ground-truth bounding boxes
688
+ pass
689
+ elif img_file in detection_cache[ds]: # detected bounding boxes
690
+ bbox_dicts = detection_cache[ds][img_file]['detections']
691
+ assert all('conf' in bbox_dict for bbox_dict in bbox_dicts)
692
+ # convert from category ID to category name
693
+ for d in bbox_dicts:
694
+ d['category'] = detection_categories[d['category']]
695
+ else:
696
+ raise ValueError(f'{img_path} has no ground truth bounding boxes '
697
+ 'and was not found in the detection cache. Please '
698
+ 'include it in images_missing_detections.')
699
+
700
+ # we need the datasets table for getting SAS keys
701
+ datasets_table = megadb_utils.MegadbUtils().get_datasets_table()
702
+ container_clients = {} # dataset name => ContainerClient
703
+
704
+ pool = futures.ThreadPoolExecutor(max_workers=threads)
705
+ future_to_img_path = {}
706
+ images_failed_download = []
707
+
708
+ print(f'Getting bbox info for {len(valid_img_paths)} images...')
709
+ for img_path in tqdm(sorted(valid_img_paths)):
710
+ # we already did all error checking above, so we don't do any here
711
+ info_dict = queried_images_json[img_path]
712
+ ds, img_file = img_path.split('/', maxsplit=1)
713
+
714
+ # get ContainerClient
715
+ if ds not in container_clients:
716
+ sas_token = datasets_table[ds]['container_sas_key']
717
+ if sas_token[0] == '?':
718
+ sas_token = sas_token[1:]
719
+ url = sas_blob_utils.build_azure_storage_uri(
720
+ account=datasets_table[ds]['storage_account'],
721
+ container=datasets_table[ds]['container'],
722
+ sas_token=sas_token)
723
+ container_clients[ds] = ContainerClient.from_container_url(url)
724
+ container_client = container_clients[ds]
725
+
726
+ # get bounding boxes
727
+ # we must include the dataset <ds> in <crop_path_template> because
728
+ # '{img_path}' actually gets populated with <img_file> in
729
+ # load_and_crop()
730
+ is_ground_truth = ('bbox' in info_dict)
731
+ if is_ground_truth: # ground-truth bounding boxes
732
+ bbox_dicts = info_dict['bbox']
733
+ crop_path_template = os.path.join(
734
+ cropped_images_dir, ds, '{img_path}___crop{n:>02d}.jpg')
735
+ else: # detected bounding boxes
736
+ bbox_dicts = detection_cache[ds][img_file]['detections']
737
+ crop_path_template = os.path.join(
738
+ cropped_images_dir, ds,
739
+ '{img_path}___crop{n:>02d}_' + f'mdv{detector_version}.jpg')
740
+
741
+ ds_dir = None if images_dir is None else os.path.join(images_dir, ds)
742
+
743
+ # get the image, either from disk or from Blob Storage
744
+ future = pool.submit(
745
+ load_and_crop, img_file, ds_dir, container_client, bbox_dicts,
746
+ confidence_threshold, crop_path_template, save_full_images,
747
+ square_crops, check_crops_valid)
748
+ future_to_img_path[future] = img_path
749
+
750
+ total = len(future_to_img_path)
751
+ total_downloads = 0
752
+ total_new_crops = 0
753
+ print(f'Reading/downloading {total} images and cropping...')
754
+ for future in tqdm(futures.as_completed(future_to_img_path), total=total):
755
+ img_path = future_to_img_path[future]
756
+ try:
757
+ did_download, num_new_crops = future.result()
758
+ total_downloads += did_download
759
+ total_new_crops += num_new_crops
760
+ except Exception as e: # pylint: disable=broad-except
761
+ exception_type = type(e).__name__
762
+ tqdm.write(f'{img_path} - generated {exception_type}: {e}')
763
+ images_failed_download.append(img_path)
764
+
765
+ pool.shutdown()
766
+ for container_client in container_clients.values():
767
+ # inelegant way to close the container_clients
768
+ with container_client:
769
+ pass
770
+
771
+ print(f'Downloaded {total_downloads} images.')
772
+ print(f'Made {total_new_crops} new crops.')
773
+ return images_failed_download, total_downloads, total_new_crops
774
+
775
+
776
+ #%% Command-line driver
777
+
778
+ def _parse_args() -> argparse.Namespace:
779
+
780
+ parser = argparse.ArgumentParser(
781
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
782
+ description='Detects and crops images.')
783
+ parser.add_argument(
784
+ 'queried_images_json',
785
+ help='path to JSON file mapping image paths and classification info')
786
+ parser.add_argument(
787
+ 'output_dir',
788
+ help='path to directory to save log file. If --run-detector, then '
789
+ 'task lists and status responses are also saved here.')
790
+ parser.add_argument(
791
+ '-c', '--detector-output-cache-dir', required=True,
792
+ help='(required) path to directory where detector outputs are cached')
793
+ parser.add_argument(
794
+ '-v', '--detector-version', required=True,
795
+ help='(required) detector version string, e.g., "4.1"')
796
+ parser.add_argument(
797
+ '-d', '--run-detector', action='store_true',
798
+ help='Run the Batch Detection API. If not given, skips running the '
799
+ 'detector (and only use ground truth and cached bounding boxes).')
800
+ parser.add_argument(
801
+ '-r', '--resume-file',
802
+ help='path to save JSON file with list of info dicts on running tasks, '
803
+ 'or to resume from running tasks. Only used if --run-detector is '
804
+ 'set. Each dict has keys '
805
+ '["dataset", "task_id", "task_name", "local_images_list_path", '
806
+ '"remote_images_list_url"]')
807
+ parser.add_argument(
808
+ '-p', '--cropped-images-dir',
809
+ help='path to local directory for saving crops of bounding boxes. No '
810
+ 'images are downloaded or cropped if this argument is not given.')
811
+ parser.add_argument(
812
+ '--save-full-images', action='store_true',
813
+ help='if downloading an image, save the full image to --images-dir, '
814
+ 'only used if <cropped_images_dir> is not None')
815
+ parser.add_argument(
816
+ '--square-crops', action='store_true',
817
+ help='crop bounding boxes as squares, '
818
+ 'only used if <cropped_images_dir> is not None')
819
+ parser.add_argument(
820
+ '--check-crops-valid', action='store_true',
821
+ help='load each crop to ensure file is valid (i.e., not truncated), '
822
+ 'only used if <cropped_images_dir> is not None')
823
+ parser.add_argument(
824
+ '-t', '--threshold', type=float, default=0.0,
825
+ help='confidence threshold above which to crop bounding boxes, '
826
+ 'only used if <cropped_images_dir> is not None')
827
+ parser.add_argument(
828
+ '-i', '--images-dir',
829
+ help='path to local directory where images are saved, '
830
+ 'only used if <cropped_images_dir> is not None')
831
+ parser.add_argument(
832
+ '-n', '--threads', type=int, default=1,
833
+ help='number of threads to use for downloading images, '
834
+ 'only used if <cropped_images_dir> is not None')
835
+ return parser.parse_args()
836
+
837
+
838
+ if __name__ == '__main__':
839
+
840
+ args = _parse_args()
841
+ main(queried_images_json_path=args.queried_images_json,
842
+ output_dir=args.output_dir,
843
+ detector_version=args.detector_version,
844
+ detector_output_cache_base_dir=args.detector_output_cache_dir,
845
+ run_detector=args.run_detector,
846
+ resume_file_path=args.resume_file,
847
+ cropped_images_dir=args.cropped_images_dir,
848
+ save_full_images=args.save_full_images,
849
+ square_crops=args.square_crops,
850
+ check_crops_valid=args.check_crops_valid,
851
+ confidence_threshold=args.threshold,
852
+ images_dir=args.images_dir,
853
+ threads=args.threads)