megadetector 10.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +702 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +528 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +187 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +663 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +876 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2159 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1494 -0
- megadetector/detection/run_tiled_inference.py +1038 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1752 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2077 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +224 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2832 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1759 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1940 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +479 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.13.dist-info/METADATA +134 -0
- megadetector-10.0.13.dist-info/RECORD +147 -0
- megadetector-10.0.13.dist-info/WHEEL +5 -0
- megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.13.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,853 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
detect_and_crop.py
|
|
4
|
+
|
|
5
|
+
Run MegaDetector on images via Batch API, then save crops of the detected
|
|
6
|
+
bounding boxes.
|
|
7
|
+
|
|
8
|
+
The input to this script is a "queried images" JSON file, whose keys are paths
|
|
9
|
+
to images and values are dicts containing information relevant for training
|
|
10
|
+
a classifier, including labels and (optionally) ground-truth bounding boxes.
|
|
11
|
+
The image paths are in the format `<dataset-name>/<blob-name>` where we assume
|
|
12
|
+
that the dataset name does not contain '/'.
|
|
13
|
+
|
|
14
|
+
{
|
|
15
|
+
"caltech/cct_images/59f79901-23d2-11e8-a6a3-ec086b02610b.jpg": {
|
|
16
|
+
"dataset": "caltech",
|
|
17
|
+
"location": 13,
|
|
18
|
+
"class": "mountain_lion", # class from dataset
|
|
19
|
+
"bbox": [{"category": "animal",
|
|
20
|
+
"bbox": [0, 0.347, 0.237, 0.257]}], # ground-truth bbox
|
|
21
|
+
"label": ["mountain_lion"] # labels to use in classifier
|
|
22
|
+
},
|
|
23
|
+
"caltech/cct_images/59f5fe2b-23d2-11e8-a6a3-ec086b02610b.jpg": {
|
|
24
|
+
"dataset": "caltech",
|
|
25
|
+
"location": 13,
|
|
26
|
+
"class": "mountain_lion", # class from dataset
|
|
27
|
+
"label": ["mountain_lion"] # labels to use in classifier
|
|
28
|
+
},
|
|
29
|
+
...
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
We assume that no image contains over 100 bounding boxes, and we always save
|
|
33
|
+
crops as RGB .jpg files for consistency. For each image, each bounding box is
|
|
34
|
+
cropped and saved to a file with a suffix "___cropXX.jpg" (ground truth bbox) or
|
|
35
|
+
"___cropXX_mdvY.Y.jpg" (detected bbox) added to the filename of the original
|
|
36
|
+
image. "XX" ranges from "00" to "99" and "Y.Y" indicates the MegaDetector
|
|
37
|
+
version. If an image has ground truth bounding boxes, we assume that they are
|
|
38
|
+
exhaustive--i.e., there are no other objects of interest, so we don't need to
|
|
39
|
+
run MegaDetector on the image. If an image does not have ground truth bounding
|
|
40
|
+
boxes, we run MegaDetector on the image and label the detected boxes in order
|
|
41
|
+
from 00 up to 99. Based on the given confidence threshold, we may skip saving
|
|
42
|
+
certain bounding box crops, but we still increment the bounding box number for
|
|
43
|
+
skipped boxes.
|
|
44
|
+
|
|
45
|
+
Example cropped image path (with ground truth bbox from MegaDB)
|
|
46
|
+
|
|
47
|
+
"path/to/crops/image.jpg___crop00.jpg"
|
|
48
|
+
|
|
49
|
+
Example cropped image path (with MegaDetector bbox)
|
|
50
|
+
|
|
51
|
+
"path/to/crops/image.jpg___crop00_mdv4.1.jpg"
|
|
52
|
+
|
|
53
|
+
By default, the images are cropped exactly per the given bounding box
|
|
54
|
+
coordinates. However, if square crops are desired, pass the --square-crops
|
|
55
|
+
flag. This will always generate a square crop whose size is the larger of the
|
|
56
|
+
bounding box width or height. In the case that the square crop boundaries exceed
|
|
57
|
+
the original image size, the crop is padded with 0s.
|
|
58
|
+
|
|
59
|
+
This script currently only supports running MegaDetector via the Batch Detection
|
|
60
|
+
API. See the classification README for instructions on running MegaDetector
|
|
61
|
+
locally. If running the Batch Detection API, set the following environment
|
|
62
|
+
variables for the Azure Blob Storage container in which we save the intermediate
|
|
63
|
+
task lists:
|
|
64
|
+
|
|
65
|
+
BATCH_DETECTION_API_URL # API URL
|
|
66
|
+
CLASSIFICATION_BLOB_STORAGE_ACCOUNT # storage account name
|
|
67
|
+
CLASSIFICATION_BLOB_CONTAINER # container name
|
|
68
|
+
CLASSIFICATION_BLOB_CONTAINER_WRITE_SAS # SAS token, without leading '?'
|
|
69
|
+
DETECTION_API_CALLER # allow-listed API caller
|
|
70
|
+
|
|
71
|
+
This script allows specifying a directory where MegaDetector outputs are cached
|
|
72
|
+
via the --detector-output-cache-dir argument. This directory must be
|
|
73
|
+
organized as:
|
|
74
|
+
|
|
75
|
+
<cache-dir>/<MegaDetector-version>/<dataset-name>.json
|
|
76
|
+
|
|
77
|
+
Example: If the `cameratrapssc/classifier-training` Azure blob storage
|
|
78
|
+
container is mounted to the local machine via blobfuse, it may be used as
|
|
79
|
+
a MegaDetector output cache directory by passing
|
|
80
|
+
"cameratrapssc/classifier-training/mdcache/"
|
|
81
|
+
as the value for --detector-output-cache-dir.
|
|
82
|
+
|
|
83
|
+
This script outputs either 1 or 3 files, depending on whether the Batch Detection API
|
|
84
|
+
is run:
|
|
85
|
+
|
|
86
|
+
- <output_dir>/detect_and_crop_log_{timestamp}.json
|
|
87
|
+
log of images missing detections and images that failed to properly
|
|
88
|
+
download and crop
|
|
89
|
+
- <output_dir>/batchapi_tasklists/{task_id}.json
|
|
90
|
+
(if --run-detector) task lists uploaded to the Batch Detection API
|
|
91
|
+
- <output_dir>/batchapi_response/{task_id}.json
|
|
92
|
+
(if --run-detector) task status responses for completed tasks
|
|
93
|
+
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
#%% Imports
|
|
97
|
+
|
|
98
|
+
from __future__ import annotations
|
|
99
|
+
|
|
100
|
+
import argparse
|
|
101
|
+
from collections.abc import Collection, Iterable, Mapping, Sequence
|
|
102
|
+
from concurrent import futures
|
|
103
|
+
from datetime import datetime
|
|
104
|
+
import json
|
|
105
|
+
import os
|
|
106
|
+
import pprint
|
|
107
|
+
import time
|
|
108
|
+
from typing import Any, Optional
|
|
109
|
+
|
|
110
|
+
from azure.storage.blob import ContainerClient
|
|
111
|
+
import requests
|
|
112
|
+
from tqdm import tqdm
|
|
113
|
+
|
|
114
|
+
from api.batch_processing.data_preparation.prepare_api_submission import (
|
|
115
|
+
BatchAPIResponseError, Task, TaskStatus, divide_list_into_tasks)
|
|
116
|
+
from megadetector.classification.cache_batchapi_outputs import cache_detections
|
|
117
|
+
from megadetector.classification.crop_detections import load_and_crop
|
|
118
|
+
from megadetector.data_management.megadb import megadb_utils
|
|
119
|
+
from megadetector.utils import path_utils
|
|
120
|
+
from megadetector.utils import sas_blob_utils
|
|
121
|
+
from megadetector.utils import ct_utils
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
#%% Example usage
|
|
125
|
+
|
|
126
|
+
"""
|
|
127
|
+
python detect_and_crop.py \
|
|
128
|
+
base_logdir/queried_images.json \
|
|
129
|
+
base_logdir \
|
|
130
|
+
--detector-output-cache-dir /path/to/classifier-training/mdcache \
|
|
131
|
+
--detector-version 4.1 \
|
|
132
|
+
--run-detector --resume-file base_logdir/resume.json \
|
|
133
|
+
--cropped-images-dir /path/to/crops --square-crops --threshold 0.9 \
|
|
134
|
+
--save-full-images --images-dir /path/to/images --threads 50
|
|
135
|
+
"""
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
#%% Main function
|
|
139
|
+
|
|
140
|
+
def main(queried_images_json_path: str,
         output_dir: str,
         detector_version: str,
         detector_output_cache_base_dir: str,
         run_detector: bool,
         resume_file_path: Optional[str],
         cropped_images_dir: Optional[str],
         save_full_images: bool,
         square_crops: bool,
         check_crops_valid: bool,
         confidence_threshold: float,
         images_dir: Optional[str],
         threads: int) -> None:
    """
    Runs the detect-and-crop pipeline.

    Finds images without ground-truth boxes that are missing from the
    detector output cache, optionally runs the Batch Detection API on them,
    optionally downloads and crops the images, and finally writes a JSON log
    of problem images to output_dir.

    Args:
        queried_images_json_path: str, path to output of json_validator.py
        output_dir: str, directory for the log file and (if run_detector)
            the batchapi_tasklists/ folder, see module docstring
        detector_version: str, detector version string, e.g., '4.1',
            see {batch_detection_api_url}/supported_model_versions;
            selects the 'v<version>' subfolder of
            detector_output_cache_base_dir used to find and save outputs
        detector_output_cache_base_dir: str, path to local directory where
            detector outputs are cached, 1 JSON file per dataset
        run_detector: bool, whether to run the Batch Detection API, or to
            skip running the detector entirely
        resume_file_path: optional str, JSON file with info dicts on running
            tasks; resumed from if it exists, otherwise written after
            submission. Must be given when run_detector=True
        cropped_images_dir: optional str, local directory for saving crops of
            bounding boxes; if None, nothing is downloaded or cropped
        save_full_images: bool, whether to save downloaded images to
            images_dir; images_dir must be given if save_full_images=True
        square_crops: bool, whether to crop bounding boxes as squares
        check_crops_valid: bool, whether to load each crop to ensure the file
            is valid (i.e., not truncated)
        confidence_threshold: float in [0, 1], only crop bounding boxes above
            this value
        images_dir: optional str, local directory where images are saved
        threads: int, number of threads to use for downloading images
    """

    # Diagnostics accumulated during the run; dumped to JSON at the very end.
    run_log: dict[str, Any] = {}

    # Argument validation
    assert 0 <= confidence_threshold <= 1
    if save_full_images:
        assert images_dir is not None
        if not os.path.exists(images_dir):
            os.makedirs(images_dir, exist_ok=True)
            print(f'Created images_dir at {images_dir}')

    with open(queried_images_json_path, 'r') as f:
        queried_images = json.load(f)

    cache_dir = os.path.join(detector_output_cache_base_dir,
                             f'v{detector_version}')
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
        print(f'Created directory at {cache_dir}')

    # Images that already carry ground-truth boxes are never sent to the
    # detector (see module docstring: GT boxes are assumed exhaustive).
    unlabeled_images = [path for path, info in queried_images.items()
                        if 'bbox' not in info]

    def scan_cache():
        # Reads the on-disk cache and reports how many images it still lacks.
        found = filter_detected_images(
            potential_images_to_detect=unlabeled_images,
            detector_output_cache_dir=cache_dir)
        print(f'{len(found[0])} images not in detection cache')
        return found

    images_to_detect, detection_cache, categories = scan_cache()

    if run_detector:
        run_log['images_submitted_for_detection'] = images_to_detect

        assert resume_file_path is not None
        assert not os.path.isdir(resume_file_path)
        api_url = os.environ['BATCH_DETECTION_API_URL']

        if os.path.exists(resume_file_path):
            tasks_by_dataset = resume_tasks(
                resume_file_path,
                batch_detection_api_url=api_url)
        else:
            tasks_by_dataset = submit_batch_detection_api(
                images_to_detect=images_to_detect,
                task_lists_dir=os.path.join(output_dir, 'batchapi_tasklists'),
                detector_version=detector_version,
                account=os.environ['CLASSIFICATION_BLOB_STORAGE_ACCOUNT'],
                container=os.environ['CLASSIFICATION_BLOB_CONTAINER'],
                sas_token=os.environ['CLASSIFICATION_BLOB_CONTAINER_WRITE_SAS'],
                caller=os.environ['DETECTION_API_CALLER'],
                batch_detection_api_url=api_url,
                resume_file_path=resume_file_path)

        wait_for_tasks(tasks_by_dataset, cache_dir, output_dir=output_dir)

        # wait_for_tasks() is handed the cache dir (presumably it stores new
        # results there), so re-read the cache before cropping.
        print('Refreshing detection cache...')
        images_to_detect, detection_cache, categories = scan_cache()

    run_log['images_missing_detections'] = images_to_detect

    if cropped_images_dir is not None:
        failed_images, num_downloads, num_crops = download_and_crop(
            queried_images_json=queried_images,
            detection_cache=detection_cache,
            detection_categories=categories,
            detector_version=detector_version,
            cropped_images_dir=cropped_images_dir,
            confidence_threshold=confidence_threshold,
            save_full_images=save_full_images,
            square_crops=square_crops,
            check_crops_valid=check_crops_valid,
            images_dir=images_dir,
            threads=threads,
            images_missing_detections=images_to_detect)
        run_log['images_failed_download_or_crop'] = failed_images
        run_log['num_new_downloads'] = num_downloads
        run_log['num_new_crops'] = num_crops

    print(f'{len(images_to_detect)} images with missing detections.')
    if cropped_images_dir is not None:
        print(f'{len(failed_images)} images failed to download or crop.')

    # Timestamp the log file name, e.g., '20200722_110816'
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    ct_utils.write_json(
        os.path.join(output_dir, f'detect_and_crop_log_{timestamp}.json'),
        run_log)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
#%% Support functions
|
|
272
|
+
|
|
273
|
+
def load_detection_cache(detector_output_cache_dir: str,
                         datasets: Collection[str]) -> tuple[
                             dict[str, dict[str, dict[str, Any]]],
                             dict[str, str]
                         ]:
    """
    Loads cached detector outputs for each of the given datasets.

    A dataset whose cache file '<detector_output_cache_dir>/<dataset>.json'
    does not exist gets an empty dict rather than an error.

    Args:
        detector_output_cache_dir: str, path to local directory where detector
            outputs are cached, 1 JSON file per dataset
        datasets: collection of str, names of datasets

    Returns:
        detection_cache: dict, maps dataset name to dict, which maps
            image file to corresponding entry in 'images' list from the
            Batch Detection API output. detection_cache[ds] is an empty dict
            if no cached detections were found for the given dataset ds.
        detection_categories: dict, maps str category ID to str category name;
            taken from the first non-empty cache file, and every loaded file
            is asserted to agree with it
    """

    # dataset name => {image file => detection entry}
    cache_by_dataset: dict[str, dict[str, dict[str, Any]]] = {}
    categories: dict[str, str] = {}

    progress = tqdm(datasets)
    for dataset_name in progress:
        progress.set_description(
            f'Loading dataset {dataset_name} into detection cache')
        json_path = os.path.join(detector_output_cache_dir,
                                 f'{dataset_name}.json')

        if not os.path.exists(json_path):
            tqdm.write(f'No detection cache found for dataset {dataset_name}')
            cache_by_dataset[dataset_name] = {}
            continue

        with open(json_path, 'r') as f:
            cached = json.load(f)
        cache_by_dataset[dataset_name] = {
            entry['file']: entry for entry in cached['images']}

        # All cache files must share one category mapping.
        if not categories:
            categories = cached['detection_categories']
        assert categories == cached['detection_categories']

    return cache_by_dataset, categories
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def filter_detected_images(
        potential_images_to_detect: Iterable[str],
        detector_output_cache_dir: str
        ) -> tuple[list[str],
                   dict[str, dict[str, dict[str, Any]]],
                   dict[str, str]]:
    """
    Checks image paths against cached detector outputs and returns the images
    that are not yet in the cache (i.e., still need to be run through the
    detector), along with the loaded cache itself.

    Args:
        potential_images_to_detect: iterable of str, paths to images that do
            not have ground truth bounding boxes, each path has format
            <dataset-name>/<img-filename>, where <img-filename> is the blob
            name. May be a one-shot iterator; it is consumed exactly once.
        detector_output_cache_dir: str, path to local directory where detector
            outputs are cached, 1 JSON file per dataset

    Returns:
        images_to_detect: list of str, paths to images not in the detector
            output cache, with the format <dataset-name>/<img-filename>
        detection_cache: dict, maps str dataset name to a dict that maps
            image file (blob name) to its entry in the 'images' list of the
            cached Batch Detection API output
        detection_categories: dict, maps str category ID to str category name,
            empty dict if no cached detections are found

    Raises:
        ValueError: if a path does not contain a '/' separating the dataset
            name from the blob name
    """

    # Parse every path exactly once. The input is typed as an Iterable, and
    # it is needed twice (to collect dataset names, then to check each image),
    # so a one-shot iterator must be materialized rather than re-iterated.
    parsed_paths: list[tuple[str, str, str]] = []
    for img_path in potential_images_to_detect:
        ds, sep, img_file = img_path.partition('/')
        if not sep:
            raise ValueError(
                f'Image path {img_path!r} is not of the form '
                '<dataset-name>/<img-filename>')
        parsed_paths.append((img_path, ds, img_file))

    datasets = {ds for _, ds, _ in parsed_paths}
    detection_cache, detection_categories = load_detection_cache(
        detector_output_cache_dir, datasets)

    images_to_detect = [
        img_path for img_path, ds, img_file in parsed_paths
        if img_file not in detection_cache[ds]]

    return images_to_detect, detection_cache, detection_categories
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def split_images_list_by_dataset(images_to_detect: Iterable[str]
                                 ) -> dict[str, list[str]]:
    """
    Groups image paths by the dataset name that prefixes them.

    Args:
        images_to_detect: iterable of str, image paths with the format
            <dataset-name>/<image-filename>

    Returns: dict, maps dataset name to a list of the (unmodified) image
        paths belonging to it, in input order

    Raises:
        ValueError: if a path does not contain a '/' separating the dataset
            name from the image filename
    """

    images_by_dataset: dict[str, list[str]] = {}
    for img_path in images_to_detect:
        # partition() splits on the first '/', matching the assumption that
        # dataset names never contain '/'. Without the explicit check, a path
        # lacking '/' would previously be filed under img_path[:-1].
        dataset, sep, _ = img_path.partition('/')
        if not sep:
            raise ValueError(
                f'Image path {img_path!r} is not of the form '
                '<dataset-name>/<image-filename>')
        images_by_dataset.setdefault(dataset, []).append(img_path)
    return images_by_dataset
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def submit_batch_detection_api(images_to_detect: Iterable[str],
                               task_lists_dir: str,
                               detector_version: str,
                               account: str,
                               container: str,
                               sas_token: str,
                               caller: str,
                               batch_detection_api_url: str,
                               resume_file_path: str
                               ) -> dict[str, list[Task]]:
    """
    Submits Batch Detection API tasks for the given images (one or more tasks
    per dataset) and writes a resume file listing every submitted task.

    Paths without a recognized image extension are reported and skipped.

    Args:
        images_to_detect: iterable of str, image paths with the format
            <dataset-name>/<image-filename>; consumed exactly once
        task_lists_dir: str, path to local directory for saving JSON files
            each containing a list of image URLs corresponding to an API task
        detector_version: str, MegaDetector version string, e.g., '4.1',
            see {batch_detection_api_url}/supported_model_versions
        account: str, Azure Storage account name
        container: str, Azure Blob Storage container name, where the task lists
            will be uploaded
        sas_token: str, SAS token with write permissions for the container
        caller: str, allow-listed caller
        batch_detection_api_url: str, URL to batch detection API
        resume_file_path: str, path to save resume file

    Returns: dict, maps str dataset name to list of Task objects
    """

    # Materialize once: the input is typed as an Iterable, and it is read
    # twice below (filter pass + rejected-file report). Re-iterating an
    # exhausted iterator would make the report wrongly claim no files were
    # skipped.
    all_paths = list(images_to_detect)
    images_to_detect = [x for x in all_paths if path_utils.is_image_file(x)]
    not_images = set(all_paths) - set(images_to_detect)
    if len(not_images) == 0:
        print('Good! All image files have valid file extensions.')
    else:
        print(f'Skipping {len(not_images)} files with non-image extensions:')
        pprint.pprint(sorted(not_images))

    datasets_table = megadb_utils.MegadbUtils().get_datasets_table()

    images_by_dataset = split_images_list_by_dataset(images_to_detect)
    tasks_by_dataset = {}
    for dataset, image_paths in images_by_dataset.items():
        dataset_info = datasets_table[dataset]

        # Get SAS URL for the images container. The stored key may carry a
        # leading '?'; startswith() (unlike indexing [0]) does not raise
        # IndexError on an empty token.
        images_sas_token = dataset_info['container_sas_key']
        if images_sas_token.startswith('?'):
            images_sas_token = images_sas_token[1:]
        images_container_url = sas_blob_utils.build_azure_storage_uri(
            account=dataset_info['storage_account'],
            container=dataset_info['container'],
            sas_token=images_sas_token)

        # strip the '<dataset-name>/' prefix to get blob names
        image_blob_names = [path[path.find('/') + 1:] for path in image_paths]

        tasks_by_dataset[dataset] = submit_batch_detection_api_by_dataset(
            dataset=dataset,
            image_blob_names=image_blob_names,
            images_container_url=images_container_url,
            task_lists_dir=task_lists_dir,
            detector_version=detector_version,
            account=account, container=container, sas_token=sas_token,
            caller=caller, batch_detection_api_url=batch_detection_api_url)

    # save list of dataset names and task IDs for resuming
    resume_json = [
        {
            'dataset': dataset,
            'task_name': task.name,
            'task_id': task.id,
            'local_images_list_path': task.local_images_list_path
        }
        for dataset in tasks_by_dataset
        for task in tasks_by_dataset[dataset]
    ]
    ct_utils.write_json(resume_file_path, resume_json)
    return tasks_by_dataset
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def submit_batch_detection_api_by_dataset(
        dataset: str,
        image_blob_names: Sequence[str],
        images_container_url: str,
        task_lists_dir: str,
        detector_version: str,
        account: str,
        container: str,
        sas_token: str,
        caller: str,
        batch_detection_api_url: str
        ) -> list[Task]:
    """
    Splits one dataset's images into task lists, uploads each list, and
    submits one Batch Detection API task per list.

    Args:
        dataset: str, name of dataset
        image_blob_names: list of str, image blob names from the same dataset
        images_container_url: str, URL to blob storage container where images
            from this dataset are stored, including SAS token with read
            permissions if container is not public
        **see submit_batch_detection_api() for description of other args

    Returns: list of Task objects, in submission order
    """

    os.makedirs(task_lists_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # e.g., '20200722_110816'
    list_file_path = os.path.join(task_lists_dir,
                                  f'task_list_{dataset}_{timestamp}.json')
    split_list_paths, _ = divide_list_into_tasks(
        file_list=image_blob_names,
        save_path=list_file_path)

    submitted: list[Task] = []
    for task_index, list_path in enumerate(split_list_paths):
        # Tasks are numbered from 00, e.g.,
        # 'detect_for_classifier_caltech_20200722_110816_task00'
        task_name = (f'detect_for_classifier_{dataset}_{timestamp}'
                     f'_task{task_index:>02d}')
        task = Task(name=task_name,
                    images_list_path=list_path,
                    api_url=batch_detection_api_url)
        task.upload_images_list(account=account,
                                container=container,
                                sas_token=sas_token)
        task.generate_api_request(caller=caller,
                                  input_container_url=images_container_url,
                                  model_version=detector_version)
        print(f'Submitting task for: {list_path}')
        task.submit()
        print(f'- task ID: {task.id}')
        submitted.append(task)

        # HACK! Sleep for 10s after every submission (including the last,
        # which also spaces out submissions across consecutive datasets) in
        # the hopes that it decreases the chance of backend JSON "database"
        # corruption.
        time.sleep(10)

    return submitted
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def resume_tasks(resume_file_path: str, batch_detection_api_url: str
                 ) -> dict[str, list[Task]]:
    """
    Rebuilds Task objects for previously submitted Batch Detection API tasks
    from a resume file.

    Args:
        resume_file_path: str, path to a JSON resume file holding a list of
            info dicts on running tasks; each dict must have the keys
            'dataset', 'task_name', 'task_id' and 'local_images_list_path'
        batch_detection_api_url: str, URL to batch detection API

    Returns: dict, maps str dataset name to list of Task objects, in the
        order they appear in the resume file
    """
    with open(resume_file_path, 'r') as resume_file:
        saved_infos = json.load(resume_file)

    tasks_by_dataset: dict[str, list[Task]] = {}
    for saved in saved_infos:
        # create the dataset's bucket first, then attach the restored task
        dataset_tasks = tasks_by_dataset.setdefault(saved['dataset'], [])
        dataset_tasks.append(Task(
            name=saved['task_name'],
            task_id=saved['task_id'],
            images_list_path=saved['local_images_list_path'],
            validate=False,  # task already exists; skip re-validation
            api_url=batch_detection_api_url))
    return tasks_by_dataset
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def wait_for_tasks(tasks_by_dataset: Mapping[str, Iterable[Task]],
                   detector_output_cache_dir: str,
                   output_dir: Optional[str] = None,
                   poll_interval: int = 120) -> None:
    """
    Waits for the Batch Detection API tasks to finish running.

    For jobs that finish successfully, merges the output with cached detector
    outputs.

    A task is reported via tqdm.write() and dropped (not retried) in any of
    these cases:
    - its status cannot be checked
    - it stops with PROBLEM/FAILED status
    - its response has an unexpected detections URL
    - its detections file cannot be downloaded or parsed

    Args:
        tasks_by_dataset: dict, maps str dataset name to list of Task objects
        detector_output_cache_dir: str, path to local directory where detector
            outputs are cached, 1 JSON file per dataset, directory must
            already exist
        output_dir: optional str, task status responses for completed tasks are
            saved to <output_dir>/batchapi_response/{task_id}.json
        poll_interval: int, # of seconds between pinging the task status API
    """
    remaining_tasks: list[tuple[str, Task]] = [
        (dataset, task) for dataset, tasks in tasks_by_dataset.items()
        for task in tasks]

    # the context manager closes the progress bar even if an exception escapes
    with tqdm(total=len(remaining_tasks)) as progbar:
        while True:
            new_remaining_tasks = []
            for dataset, task in remaining_tasks:
                try:
                    task.check_status()
                except (BatchAPIResponseError, requests.HTTPError) as e:
                    exception_type = type(e).__name__
                    tqdm.write(f'Error in checking status of task {task.id}: '
                               f'({exception_type}) {e}')
                    tqdm.write(f'Skipping task {task.id}.')
                    # the task leaves the polling set, so count it as finished
                    progbar.update(1)
                    continue

                # task still running => continue
                if task.status == TaskStatus.RUNNING:
                    new_remaining_tasks.append((dataset, task))
                    continue

                progbar.update(1)
                tqdm.write(f'Task {task.id} stopped with status {task.status}')

                if task.status in [TaskStatus.PROBLEM, TaskStatus.FAILED]:
                    tqdm.write('API response:')
                    tqdm.write(str(task.response))
                    continue

                # task finished successfully, save response to disk
                assert task.status == TaskStatus.COMPLETED
                if output_dir is not None:
                    save_dir = os.path.join(output_dir, 'batchapi_response')
                    if not os.path.exists(save_dir):
                        tqdm.write(f'Creating API output dir: {save_dir}')
                        os.makedirs(save_dir, exist_ok=True)
                    ct_utils.write_json(
                        os.path.join(save_dir, f'{task.id}.json'), task.response)
                message = task.response['Status']['message']
                num_failed_shards = message['num_failed_shards']
                if num_failed_shards != 0:
                    tqdm.write(f'Task {task.id} completed with {num_failed_shards} '
                               'failed shards.')

                detections_url = message['output_file_urls']['detections']
                if task.id not in detections_url:
                    tqdm.write('Invalid detections URL in response. Skipping task.')
                    continue

                # Check the HTTP status before parsing: an error page must not
                # crash the whole polling loop, nor be cached as detections.
                try:
                    detections_response = requests.get(detections_url)
                    detections_response.raise_for_status()
                    detections = detections_response.json()
                except (requests.RequestException, ValueError) as e:
                    exception_type = type(e).__name__
                    tqdm.write(f'Error in downloading detections of task '
                               f'{task.id}: ({exception_type}) {e}')
                    tqdm.write(f'Skipping task {task.id}.')
                    continue

                msg = cache_detections(
                    detections=detections, dataset=dataset,
                    detector_output_cache_dir=detector_output_cache_dir)
                tqdm.write(msg)

            remaining_tasks = new_remaining_tasks
            if not remaining_tasks:
                break
            tqdm.write(f'Sleeping for {poll_interval} seconds...')
            time.sleep(poll_interval)
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def download_and_crop(
        queried_images_json: Mapping[str, Mapping[str, Any]],
        detection_cache: Mapping[str, Mapping[str, Mapping[str, Any]]],
        detection_categories: Mapping[str, str],
        detector_version: str,
        cropped_images_dir: str,
        confidence_threshold: float,
        save_full_images: bool,
        square_crops: bool,
        check_crops_valid: bool,
        images_dir: Optional[str] = None,
        threads: int = 1,
        images_missing_detections: Optional[Iterable[str]] = None
        ) -> tuple[list[str], int, int]:
    """
    Saves crops to a file with the same name as the original image with an
    additional suffix appended, starting with 3 underscores:
    - if image has ground truth bboxes: "___cropXX.jpg", where "XX" indicates
        the bounding box index
    - if image has bboxes from MegaDetector: "___cropXX_mdvY.Y.jpg", where
        "Y.Y" indicates the MegaDetector version
    See module docstring for more info and examples.

    Note: this function is very similar to the "download_and_crop()" function in
        crop_detections.py. The main difference is that this function uses
        MegaDB to look up Azure Storage container information for images based
        on the dataset, whereas the crop_detections.py version has no concept
        of a "dataset" and "ground-truth" bounding boxes from MegaDB.

    Args:
        queried_images_json: dict, represents JSON output of json_validator.py,
            all images in queried_images_json are assumed to have either ground
            truth or cached detected bounding boxes unless
            images_missing_detections is given
        detection_cache: dict, dataset_name => {img_path => detection_dict};
            read only, this function does not modify it
        detection_categories: dict, maps the detector's category ID (the
            'category' field of each detected bbox dict) to a category name
        detector_version: str, detector version string, e.g., '4.1'
        cropped_images_dir: str, path to folder where cropped images are saved
        confidence_threshold: float, only crop bounding boxes above this value
        save_full_images: bool, whether to save downloaded images to images_dir,
            images_dir must be given and must exist if save_full_images=True
        square_crops: bool, whether to crop bounding boxes as squares
        check_crops_valid: bool, whether to load each crop to ensure the file is
            valid (i.e., not truncated)
        images_dir: optional str, path to folder where full images are saved
        threads: int, number of threads to use for downloading images
        images_missing_detections: optional list of str, image files to skip
            because they have no ground truth or cached detected bounding boxes

    Returns: tuple of
        images_failed_download: list of str, images with bounding boxes that
            failed to download or crop properly
        total_downloads: int, number of images for which load_and_crop()
            reported a download
        total_new_crops: int, total number of new crops reported by
            load_and_crop()

    Raises:
        ValueError: if an image (not listed in images_missing_detections) has
            no ground truth bounding boxes and is not in detection_cache
    """
    # error checking before we download and crop any images
    valid_img_paths = set(queried_images_json.keys())
    if images_missing_detections is not None:
        valid_img_paths -= set(images_missing_detections)

    # img_path => detected bbox dicts with category IDs converted to category
    # names. These are shallow copies: converting in place would mutate the
    # caller's detection_cache (and re-converting an already-converted entry
    # on a later call could fail the category lookup).
    detected_bboxes: dict[str, list[dict[str, Any]]] = {}
    for img_path in sorted(valid_img_paths):  # sorted => deterministic errors
        info_dict = queried_images_json[img_path]
        ds, img_file = img_path.split('/', maxsplit=1)
        assert ds == info_dict['dataset']

        if 'bbox' in info_dict:  # ground-truth bounding boxes
            continue

        # a dataset missing from the cache is reported like a missing image
        ds_cache = detection_cache.get(ds, {})
        if img_file not in ds_cache:
            raise ValueError(f'{img_path} has no ground truth bounding boxes '
                             'and was not found in the detection cache. Please '
                             'include it in images_missing_detections.')

        bbox_dicts = ds_cache[img_file]['detections']
        assert all('conf' in bbox_dict for bbox_dict in bbox_dicts)
        # convert from category ID to category name
        detected_bboxes[img_path] = [
            {**bbox_dict,
             'category': detection_categories[bbox_dict['category']]}
            for bbox_dict in bbox_dicts]

    # we need the datasets table for getting SAS keys
    datasets_table = megadb_utils.MegadbUtils().get_datasets_table()
    container_clients = {}  # dataset name => ContainerClient

    future_to_img_path = {}
    images_failed_download = []
    total_downloads = 0
    total_new_crops = 0

    try:
        # the context manager shuts the pool down (waiting for submitted jobs)
        # even if submitting or collecting a job raises
        with futures.ThreadPoolExecutor(max_workers=threads) as pool:
            print(f'Getting bbox info for {len(valid_img_paths)} images...')
            for img_path in tqdm(sorted(valid_img_paths)):
                # we already did all error checking above, so we don't do any
                # here
                info_dict = queried_images_json[img_path]
                ds, img_file = img_path.split('/', maxsplit=1)

                # get ContainerClient
                if ds not in container_clients:
                    sas_token = datasets_table[ds]['container_sas_key']
                    # startswith() also tolerates an empty SAS token
                    if sas_token.startswith('?'):
                        sas_token = sas_token[1:]
                    url = sas_blob_utils.build_azure_storage_uri(
                        account=datasets_table[ds]['storage_account'],
                        container=datasets_table[ds]['container'],
                        sas_token=sas_token)
                    container_clients[ds] = ContainerClient.from_container_url(url)
                container_client = container_clients[ds]

                # get bounding boxes
                # we must include the dataset <ds> in <crop_path_template>
                # because '{img_path}' actually gets populated with <img_file>
                # in load_and_crop()
                if 'bbox' in info_dict:  # ground-truth bounding boxes
                    bbox_dicts = info_dict['bbox']
                    crop_path_template = os.path.join(
                        cropped_images_dir, ds, '{img_path}___crop{n:>02d}.jpg')
                else:  # detected bounding boxes
                    bbox_dicts = detected_bboxes[img_path]
                    crop_path_template = os.path.join(
                        cropped_images_dir, ds,
                        '{img_path}___crop{n:>02d}_' + f'mdv{detector_version}.jpg')

                ds_dir = None if images_dir is None else os.path.join(images_dir, ds)

                # get the image, either from disk or from Blob Storage
                future = pool.submit(
                    load_and_crop, img_file, ds_dir, container_client, bbox_dicts,
                    confidence_threshold, crop_path_template, save_full_images,
                    square_crops, check_crops_valid)
                future_to_img_path[future] = img_path

            total = len(future_to_img_path)
            print(f'Reading/downloading {total} images and cropping...')
            for future in tqdm(futures.as_completed(future_to_img_path),
                               total=total):
                img_path = future_to_img_path[future]
                try:
                    did_download, num_new_crops = future.result()
                    total_downloads += did_download
                    total_new_crops += num_new_crops
                except Exception as e:  # pylint: disable=broad-except
                    exception_type = type(e).__name__
                    tqdm.write(f'{img_path} - generated {exception_type}: {e}')
                    images_failed_download.append(img_path)
    finally:
        for container_client in container_clients.values():
            # inelegant way to close the container_clients
            with container_client:
                pass

    print(f'Downloaded {total_downloads} images.')
    print(f'Made {total_new_crops} new crops.')
    return images_failed_download, total_downloads, total_new_crops
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
#%% Command-line driver
|
|
777
|
+
|
|
778
|
+
def _parse_args() -> argparse.Namespace:
|
|
779
|
+
|
|
780
|
+
parser = argparse.ArgumentParser(
|
|
781
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
782
|
+
description='Detects and crops images.')
|
|
783
|
+
parser.add_argument(
|
|
784
|
+
'queried_images_json',
|
|
785
|
+
help='path to JSON file mapping image paths and classification info')
|
|
786
|
+
parser.add_argument(
|
|
787
|
+
'output_dir',
|
|
788
|
+
help='path to directory to save log file. If --run-detector, then '
|
|
789
|
+
'task lists and status responses are also saved here.')
|
|
790
|
+
parser.add_argument(
|
|
791
|
+
'-c', '--detector-output-cache-dir', required=True,
|
|
792
|
+
help='(required) path to directory where detector outputs are cached')
|
|
793
|
+
parser.add_argument(
|
|
794
|
+
'-v', '--detector-version', required=True,
|
|
795
|
+
help='(required) detector version string, e.g., "4.1"')
|
|
796
|
+
parser.add_argument(
|
|
797
|
+
'-d', '--run-detector', action='store_true',
|
|
798
|
+
help='Run the Batch Detection API. If not given, skips running the '
|
|
799
|
+
'detector (and only use ground truth and cached bounding boxes).')
|
|
800
|
+
parser.add_argument(
|
|
801
|
+
'-r', '--resume-file',
|
|
802
|
+
help='path to save JSON file with list of info dicts on running tasks, '
|
|
803
|
+
'or to resume from running tasks. Only used if --run-detector is '
|
|
804
|
+
'set. Each dict has keys '
|
|
805
|
+
'["dataset", "task_id", "task_name", "local_images_list_path", '
|
|
806
|
+
'"remote_images_list_url"]')
|
|
807
|
+
parser.add_argument(
|
|
808
|
+
'-p', '--cropped-images-dir',
|
|
809
|
+
help='path to local directory for saving crops of bounding boxes. No '
|
|
810
|
+
'images are downloaded or cropped if this argument is not given.')
|
|
811
|
+
parser.add_argument(
|
|
812
|
+
'--save-full-images', action='store_true',
|
|
813
|
+
help='if downloading an image, save the full image to --images-dir, '
|
|
814
|
+
'only used if <cropped_images_dir> is not None')
|
|
815
|
+
parser.add_argument(
|
|
816
|
+
'--square-crops', action='store_true',
|
|
817
|
+
help='crop bounding boxes as squares, '
|
|
818
|
+
'only used if <cropped_images_dir> is not None')
|
|
819
|
+
parser.add_argument(
|
|
820
|
+
'--check-crops-valid', action='store_true',
|
|
821
|
+
help='load each crop to ensure file is valid (i.e., not truncated), '
|
|
822
|
+
'only used if <cropped_images_dir> is not None')
|
|
823
|
+
parser.add_argument(
|
|
824
|
+
'-t', '--threshold', type=float, default=0.0,
|
|
825
|
+
help='confidence threshold above which to crop bounding boxes, '
|
|
826
|
+
'only used if <cropped_images_dir> is not None')
|
|
827
|
+
parser.add_argument(
|
|
828
|
+
'-i', '--images-dir',
|
|
829
|
+
help='path to local directory where images are saved, '
|
|
830
|
+
'only used if <cropped_images_dir> is not None')
|
|
831
|
+
parser.add_argument(
|
|
832
|
+
'-n', '--threads', type=int, default=1,
|
|
833
|
+
help='number of threads to use for downloading images, '
|
|
834
|
+
'only used if <cropped_images_dir> is not None')
|
|
835
|
+
return parser.parse_args()
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
if __name__ == '__main__':

    # Translate the parsed command-line flags into main()'s keyword arguments.
    args = _parse_args()
    main_kwargs = dict(
        queried_images_json_path=args.queried_images_json,
        output_dir=args.output_dir,
        detector_version=args.detector_version,
        detector_output_cache_base_dir=args.detector_output_cache_dir,
        run_detector=args.run_detector,
        resume_file_path=args.resume_file,
        cropped_images_dir=args.cropped_images_dir,
        save_full_images=args.save_full_images,
        square_crops=args.square_crops,
        check_crops_valid=args.check_crops_valid,
        confidence_threshold=args.threshold,
        images_dir=args.images_dir,
        threads=args.threads,
    )
    main(**main_kwargs)
|