megadetector 10.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +702 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +528 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +187 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +663 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +876 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2159 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1494 -0
- megadetector/detection/run_tiled_inference.py +1038 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1752 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2077 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +224 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2832 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1759 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1940 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +479 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.13.dist-info/METADATA +134 -0
- megadetector-10.0.13.dist-info/RECORD +147 -0
- megadetector-10.0.13.dist-info/WHEEL +5 -0
- megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1494 @@
"""

run_md_and_speciesnet.py

Script to run MegaDetector and SpeciesNet on a folder of images and/or videos.
Runs MD first, then runs SpeciesNet on every above-threshold crop.

"""

#%% Constants, imports, environment

import argparse
import json
import multiprocessing
import os
import sys
import time

from tqdm import tqdm
from multiprocessing import JoinableQueue, Process, Queue

import humanfriendly

from megadetector.detection import run_detector_batch
from megadetector.detection.video_utils import find_videos, run_callback_on_frames, is_video_file
from megadetector.detection.run_detector_batch import load_and_run_detector_batch
from megadetector.detection.run_detector import DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
from megadetector.detection.run_detector import CONF_DIGITS
from megadetector.detection.run_detector_batch import write_results_to_file
from megadetector.utils.ct_utils import round_float
from megadetector.utils.ct_utils import write_json
from megadetector.utils.ct_utils import make_temp_folder
from megadetector.utils.ct_utils import is_list_sorted
from megadetector.utils.ct_utils import is_sphinx_build
from megadetector.utils.ct_utils import args_to_object
from megadetector.utils import path_utils
from megadetector.visualization import visualization_utils as vis_utils
from megadetector.postprocessing.validate_batch_results import \
    validate_batch_results, ValidateBatchResultsOptions
from megadetector.detection.process_video import \
    process_videos, ProcessVideoOptions
from megadetector.postprocessing.combine_batch_outputs import combine_batch_output_files

# We aren't taking an explicit dependency on the speciesnet package yet,
# so we wrap this in a try/except so sphinx can still document this module.
try:
    from speciesnet import SpeciesNetClassifier
    from speciesnet.utils import BBox
    from speciesnet.ensemble import SpeciesNetEnsemble
    from speciesnet.geofence_utils import roll_up_labels_to_first_matching_level
    from speciesnet.geofence_utils import geofence_animal_classification
except Exception:
    pass


#%% Constants

DEFAULT_DETECTOR_MODEL = 'MDV5A'
DEFAULT_CLASSIFIER_MODEL = 'kaggle:google/speciesnet/pyTorch/v4.0.1a'
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION = 0.1
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
DEFAULT_DETECTOR_BATCH_SIZE = 1
DEFAULT_CLASSIFIER_BATCH_SIZE = 8
DEFAULT_LOADER_WORKERS = 4

# This determines the maximum number of image filenames that can be assigned to
# each of the producer workers before blocking. The actual size of the queue
# will be MAX_IMAGE_QUEUE_SIZE_PER_WORKER * n_workers. This is only used for
# the classification step.
MAX_IMAGE_QUEUE_SIZE_PER_WORKER = 30

# This determines the maximum number of crops that can accumulate in the queue
# used to communicate between the producers (which read and crop images) and the
# consumer (which runs the classifier). This is only used for the classification step.
MAX_BATCH_QUEUE_SIZE = 300

# Default interval between frames we should process when processing video.
# This is only used for the detection step.
DEFAULT_SECONDS_PER_VIDEO_FRAME = 1.0

# Max number of classification scores to include per detection
DEFAULT_TOP_N_SCORES = 2

# Unless --norollup is specified, roll up taxonomic levels until the
# cumulative confidence is above this value
ROLLUP_TARGET_CONFIDENCE = 0.5

# When the caller supplies an existing MD results file, should we validate it
# before starting classification? This tends to be slow for large results
# files, so it's off by default.
VALIDATE_DETECTION_FILE = False


verbose = False
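
# For reference, a minimal sketch of the per-file records in the MD results
# 'images' list that the functions below consume. Field names come from the
# MegaDetector batch output format; the values here are hypothetical.
#
# Image entry:
#   {'file': 'camera01/IMG_0001.JPG',
#    'detections': [{'category': '1', 'conf': 0.93, 'bbox': [0.1, 0.2, 0.3, 0.4]}]}
#
# Video entry: same shape, but each detection also carries a 'frame_number';
# files that failed at the detection stage carry a 'failure' field instead.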


#%% Support classes

class CropMetadata:
    """
    Metadata for a crop extracted from an image detection.
    """

    def __init__(self,
                 image_file: str,
                 detection_index: int,
                 bbox: list[float],
                 original_width: int,
                 original_height: int):
        """
        Args:
            image_file (str): path to the original image file
            detection_index (int): index of this detection in the image
            bbox (list[float]): normalized bounding box [x_min, y_min, width, height]
            original_width (int): width of the original image
            original_height (int): height of the original image
        """

        self.image_file = image_file
        self.detection_index = detection_index
        self.bbox = bbox
        self.original_width = original_width
        self.original_height = original_height


class CropBatch:
    """
    A batch of crops with their metadata for classification.
    """

    def __init__(self):

        #: List of preprocessed images
        self.crops = []

        #: List of CropMetadata objects
        self.metadata = []

    def add_crop(self, crop_data, metadata):
        """
        Args:
            crop_data (PreprocessedImage): preprocessed image data from
                SpeciesNetClassifier.preprocess()
            metadata (CropMetadata): metadata for this crop
        """

        self.crops.append(crop_data)
        self.metadata.append(metadata)

    def __len__(self):
        return len(self.crops)


#%% Support functions for classification

def _process_image_detections(file_path: str,
                              absolute_file_path: str,
                              detection_results: dict,
                              classifier: 'SpeciesNetClassifier',
                              detection_confidence_threshold: float,
                              batch_queue: Queue):
    """
    Process detections from a single image.

    Args:
        file_path (str): relative path to the image file
        absolute_file_path (str): absolute path to the image file
        detection_results (dict): detection results for this image
        classifier (SpeciesNetClassifier): classifier instance for preprocessing
        detection_confidence_threshold (float): classify detections above this threshold
        batch_queue (Queue): queue to send crops to
    """

    detections = detection_results['detections']

    # Don't bother loading images that have no above-threshold detections
    detections_above_threshold = \
        [d for d in detections if d['conf'] >= detection_confidence_threshold]
    if len(detections_above_threshold) == 0:
        return

    # Load the image
    try:
        image = vis_utils.load_image(absolute_file_path)
        original_width, original_height = image.size
    except Exception as e:
        print('Warning: failed to load image {}: {}'.format(file_path, str(e)))

        # Send failure information to consumer
        failure_metadata = CropMetadata(
            image_file=file_path,
            detection_index=-1, # -1 indicates whole-image failure
            bbox=[],
            original_width=0,
            original_height=0
        )
        batch_queue.put(('failure',
                         'Failed to load image: {}'.format(str(e)),
                         failure_metadata))
        return

    # Process each detection above threshold
    #
    # detection_index needs to index into the original list of detections
    # (this is how classification results will be associated with detections
    # later), so iterate over "detections" here, rather than
    # "detections_above_threshold".
    for detection_index, detection in enumerate(detections):

        conf = detection['conf']
        if conf < detection_confidence_threshold:
            continue

        bbox = detection['bbox']
        assert len(bbox) == 4

        # Convert normalized bbox to BBox object for SpeciesNet
        speciesnet_bbox = BBox(
            xmin=bbox[0],
            ymin=bbox[1],
            width=bbox[2],
            height=bbox[3]
        )

        # Preprocess the crop
        try:

            preprocessed_crop = classifier.preprocess(
                image,
                bboxes=[speciesnet_bbox],
                resize=True
            )

            if preprocessed_crop is not None:

                metadata = CropMetadata(
                    image_file=file_path,
                    detection_index=detection_index,
                    bbox=bbox,
                    original_width=original_width,
                    original_height=original_height
                )

                # Send individual crop to the consumer
                batch_queue.put(('crop', preprocessed_crop, metadata))

        except Exception as e:

            print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
                file_path, detection_index, str(e)))

            # Send failure information to consumer
            failure_metadata = CropMetadata(
                image_file=file_path,
                detection_index=detection_index,
                bbox=bbox,
                original_width=original_width,
                original_height=original_height
            )
            batch_queue.put(('failure',
                             'Failed to preprocess crop: {}'.format(str(e)),
                             failure_metadata))

        # ...try/except

    # ...for each detection in this image

# ...def _process_image_detections(...)
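
# For reference, a small sketch (not used by this module) showing how the
# normalized [x_min, y_min, width, height] bbox stored in CropMetadata maps to
# absolute pixel coordinates, given the recorded original image size.
def _normalized_bbox_to_pixels(bbox, image_width, image_height):

    x_min_px = round(bbox[0] * image_width)
    y_min_px = round(bbox[1] * image_height)
    box_width_px = round(bbox[2] * image_width)
    box_height_px = round(bbox[3] * image_height)
    return [x_min_px, y_min_px, box_width_px, box_height_px]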


def _process_video_detections(file_path: str,
                              absolute_file_path: str,
                              detection_results: dict,
                              classifier: 'SpeciesNetClassifier',
                              detection_confidence_threshold: float,
                              batch_queue: Queue):
    """
    Process detections from a single video.

    Args:
        file_path (str): relative path to the video file
        absolute_file_path (str): absolute path to the video file
        detection_results (dict): detection results for this video
        classifier (SpeciesNetClassifier): classifier instance for preprocessing
        detection_confidence_threshold (float): classify detections above this threshold
        batch_queue (Queue): queue to send crops to
    """

    detections = detection_results['detections']

    # Find frames with above-threshold detections
    frames_with_detections = set()
    frame_to_detections = {}

    for detection_index, detection in enumerate(detections):

        conf = detection['conf']
        if conf < detection_confidence_threshold:
            continue

        frame_number = detection['frame_number']
        frames_with_detections.add(frame_number)

        if frame_number not in frame_to_detections:
            frame_to_detections[frame_number] = []
        frame_to_detections[frame_number].append((detection_index, detection))

    # ...for each detection in this video

    if len(frames_with_detections) == 0:
        return

    frames_to_process = sorted(list(frames_with_detections))

    # Define callback for processing each frame
    def frame_callback(frame_array, frame_id):
        """
        Callback to process a single frame.

        Args:
            frame_array (numpy.ndarray): frame data as a numpy array
            frame_id (str): frame identifier like "frame0006.jpg"
        """

        # Extract frame number from frame_id (e.g., "frame0006.jpg" -> 6)
        import re
        match = re.match(r'frame(\d+)\.jpg', frame_id)
        if not match:
            print('Warning: could not parse frame number from {}'.format(frame_id))
            return
        frame_number = int(match.group(1))

        # Only process frames for which we have detection results
        if frame_number not in frame_to_detections:
            return

        # Convert numpy array to PIL Image
        from PIL import Image
        if frame_array.dtype != 'uint8':
            frame_array = (frame_array * 255).astype('uint8')
        frame_image = Image.fromarray(frame_array)
        original_width, original_height = frame_image.size

        # Process each detection in this frame
        for detection_index, detection in frame_to_detections[frame_number]:

            bbox = detection['bbox']
            assert len(bbox) == 4

            # Convert normalized bbox to BBox object for SpeciesNet
            speciesnet_bbox = BBox(
                xmin=bbox[0],
                ymin=bbox[1],
                width=bbox[2],
                height=bbox[3]
            )

            # Preprocess the crop
            try:

                preprocessed_crop = classifier.preprocess(
                    frame_image,
                    bboxes=[speciesnet_bbox],
                    resize=True
                )

                if preprocessed_crop is not None:
                    metadata = CropMetadata(
                        image_file=file_path,
                        detection_index=detection_index,
                        bbox=bbox,
                        original_width=original_width,
                        original_height=original_height
                    )

                    # Send individual crop immediately to consumer
                    batch_queue.put(('crop', preprocessed_crop, metadata))

            except Exception as e:

                print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
                    file_path, detection_index, str(e)))

                # Send failure information to consumer
                failure_metadata = CropMetadata(
                    image_file=file_path,
                    detection_index=detection_index,
                    bbox=bbox,
                    original_width=original_width,
                    original_height=original_height
                )
                batch_queue.put(('failure',
                                 'Failed to preprocess crop: {}'.format(str(e)),
                                 failure_metadata))

            # ...try/except

        # ...for each detection

    # ...def frame_callback(...)

    # Process the video frames
    try:

        run_callback_on_frames(
            input_video_file=absolute_file_path,
            frame_callback=frame_callback,
            frames_to_process=frames_to_process,
            verbose=verbose
        )

    except Exception as e:

        print('Warning: failed to process video {}: {}'.format(file_path, str(e)))

        # Send failure information to consumer for the whole video
        failure_metadata = CropMetadata(
            image_file=file_path,
            detection_index=-1, # -1 indicates whole-file failure
            bbox=[],
            original_width=0,
            original_height=0
        )
        batch_queue.put(('failure',
                         'Failed to process video: {}'.format(str(e)),
                         failure_metadata))

    # ...try/except

# ...def _process_video_detections(...)
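
# A standalone sketch of the frame-ID parsing used in frame_callback above;
# 'frame0006.jpg' is a hypothetical identifier of the form produced by
# run_callback_on_frames.
def _parse_frame_number(frame_id):

    import re
    match = re.match(r'frame(\d+)\.jpg', frame_id)
    return int(match.group(1)) if match else None

assert _parse_frame_number('frame0006.jpg') == 6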


def _crop_producer_func(image_queue: JoinableQueue,
                        batch_queue: Queue,
                        classifier_model: str,
                        detection_confidence_threshold: float,
                        source_folder: str,
                        producer_id: int = -1):
    """
    Producer function for classification workers.

    Reads images and videos from [image_queue], crops detections above a threshold,
    preprocesses them, and sends individual crops to [batch_queue].
    See the documentation of _crop_consumer_func for the format of the
    tuples placed on batch_queue.

    Args:
        image_queue (JoinableQueue): queue containing detection_results dicts (for both images and videos)
        batch_queue (Queue): queue to put individual crops into
        classifier_model (str): classifier model identifier to load in this process
        detection_confidence_threshold (float): classify detections above this threshold
        source_folder (str): source folder to resolve relative paths
        producer_id (int, optional): identifier for this producer worker
    """

    if verbose:
        print('Classification producer starting: ID {}'.format(producer_id))

    # Load classifier; this is just being used as a preprocessor, so we force device=cpu.
    #
    # There are a number of reasons loading the model might fail; note to self: *don't*
    # catch Exceptions here. This should be a catastrophic failure that stops the whole
    # process.
    classifier = SpeciesNetClassifier(classifier_model, device='cpu')
    if verbose:
        print('Classification producer {}: loaded classifier'.format(producer_id))

    while True:

        # Pull one file's detection results from the queue
        detection_results = image_queue.get()

        # Pulling None from the queue indicates that this producer is done
        if detection_results is None:
            image_queue.task_done()
            break

        file_path = detection_results['file']

        # Skip files that failed at the detection stage
        if 'failure' in detection_results:
            image_queue.task_done()
            continue

        # Skip files with no detections
        detections = detection_results['detections']
        if len(detections) == 0:
            image_queue.task_done()
            continue

        # Determine if this is an image or video
        absolute_file_path = os.path.join(source_folder, file_path)
        is_video = is_video_file(file_path)

        if is_video:

            # Process video
            _process_video_detections(
                file_path=file_path,
                absolute_file_path=absolute_file_path,
                detection_results=detection_results,
                classifier=classifier,
                detection_confidence_threshold=detection_confidence_threshold,
                batch_queue=batch_queue
            )

        else:

            # Process image
            _process_image_detections(
                file_path=file_path,
                absolute_file_path=absolute_file_path,
                detection_results=detection_results,
                classifier=classifier,
                detection_confidence_threshold=detection_confidence_threshold,
                batch_queue=batch_queue
            )

        image_queue.task_done()

    # ...while(we still have items to process)

    # Send sentinel to indicate this producer is done
    batch_queue.put(None)

    if verbose:
        print('Classification producer {} finished'.format(producer_id))

# ...def _crop_producer_func(...)
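
# For reference, the message shapes on batch_queue (all three appear in the
# producer above and the consumer below):
#
#   ('crop', preprocessed_image, crop_metadata)    - a preprocessed crop to classify
#   ('failure', error_message, crop_metadata)      - a load/preprocess failure
#   None                                           - one producer has finished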


def _crop_consumer_func(batch_queue: Queue,
                        results_queue: Queue,
                        classifier_model: str,
                        batch_size: int,
                        num_producers: int,
                        enable_rollup: bool,
                        country: str = None,
                        admin1_region: str = None):
    """
    Consumer function for classification inference.

    Pulls individual crops from batch_queue, assembles them into batches,
    runs inference, and puts results into results_queue.

    Args:
        batch_queue (Queue): queue containing individual crop tuples or failures.
            Items on this queue are either None (to indicate that a producer finished)
            or tuples formatted as (type, image, metadata). [type] is a string (either
            "crop" or "failure"), [image] is a PreprocessedImage, and [metadata] is
            a CropMetadata object.
        results_queue (Queue): queue to put classification results into
        classifier_model (str): classifier model identifier to load
        batch_size (int): batch size for inference
        num_producers (int): number of producer workers
        enable_rollup (bool): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
    """

    if verbose:
        print('Classification consumer starting')

    # Load classifier
    try:
        classifier = SpeciesNetClassifier(classifier_model)
        if verbose:
            print('Classification consumer: loaded classifier')
    except Exception as e:
        print('Classification consumer: failed to load classifier: {}'.format(str(e)))
        results_queue.put({})
        return

    all_results = {} # image_file -> {detection_index -> classification_result}
    current_batch = CropBatch()
    producers_finished = 0

    # Load ensemble metadata if rollup/geofencing is enabled
    taxonomy_map = {}
    geofence_map = {}

    if (enable_rollup) or (country is not None):

        # Note to self: there are a number of reasons loading the ensemble
        # could fail here; don't catch this exception, this should be a
        # catastrophic failure.
        ensemble = SpeciesNetEnsemble(
            classifier_model, geofence=(country is not None))
        taxonomy_map = ensemble.taxonomy_map
        geofence_map = ensemble.geofence_map

    # ...if we need to load ensemble components

    while True:

        # Pull an item from the queue
        item = batch_queue.get()

        # This indicates that a producer worker finished
        if item is None:

            producers_finished += 1
            if producers_finished == num_producers:
                # Process any remaining crops
                if len(current_batch) > 0:
                    _process_classification_batch(
                        current_batch, classifier, all_results,
                        enable_rollup, taxonomy_map, geofence_map,
                        country, admin1_region
                    )
                break
            continue

        # ...if a producer finished

        # If we got here, we know we have a crop to process, or
        # a failure to record.
        assert isinstance(item, tuple) and len(item) == 3
        item_type, data, metadata = item

        if metadata.image_file not in all_results:
            all_results[metadata.image_file] = {}

        # We should never be processing the same detection twice
        assert metadata.detection_index not in all_results[metadata.image_file]

        if item_type == 'failure':

            all_results[metadata.image_file][metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(data)
            }

        else:

            assert item_type == 'crop'
            current_batch.add_crop(data, metadata)
            assert len(current_batch) <= batch_size

            # Process batch if necessary
            if len(current_batch) == batch_size:
                _process_classification_batch(
                    current_batch, classifier, all_results,
                    enable_rollup, taxonomy_map, geofence_map,
                    country, admin1_region
                )
                current_batch = CropBatch()

        # ...was this item a failure or a crop?

    # ...while (we have items to process)

    # Send all the results at once back to the main process
    results_queue.put(all_results)

    if verbose:
        print('Classification consumer finished')

# ...def _crop_consumer_func(...)


def _process_classification_batch(batch: CropBatch,
                                  classifier: 'SpeciesNetClassifier',
                                  all_results: dict,
                                  enable_rollup: bool,
                                  taxonomy_map: dict,
                                  geofence_map: dict,
                                  country: str = None,
                                  admin1_region: str = None):
    """
    Run a batch of crops through the classifier.

    Args:
        batch (CropBatch): batch of crops to process
        classifier (SpeciesNetClassifier): classifier instance
        all_results (dict): dictionary to store results in, modified in-place with format
            {image_file: {detection_index: {'classifications': [[class_name, score], ...]}}}
            or {image_file: {detection_index: {'failure': error_message}}}
        enable_rollup (bool): whether to apply rollup
        taxonomy_map (dict): taxonomy mapping for rollup
        geofence_map (dict): geofence mapping
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
    """

    if len(batch) == 0:
        print('Warning: _process_classification_batch received empty batch')
        return

    # Prepare batch for inference
    filepaths = [f"{metadata.image_file}_{metadata.detection_index}"
                 for metadata in batch.metadata]

    # Run batch inference
    try:
        batch_results = classifier.batch_predict(filepaths, batch.crops)
    except Exception as e:
        print('*** Batch classification failed: {} ***'.format(str(e)))
        # Mark all crops in this batch as failed
        for metadata in batch.metadata:
            if metadata.image_file not in all_results:
                all_results[metadata.image_file] = {}
            all_results[metadata.image_file][metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(str(e))
            }
        return

    # Process results
    assert len(batch_results) == len(batch.metadata)
    assert len(batch_results) == len(filepaths)

    for i_result in range(0, len(batch_results)):

        result = batch_results[i_result]
        metadata = batch.metadata[i_result]

        assert metadata.image_file in all_results, \
            'File {} not in results dict'.format(metadata.image_file)

        detection_index = metadata.detection_index

        # Handle classification failure
        if 'failures' in result:
            print('*** Classification failure for image: {} ***'.format(
                filepaths[i_result]))
            all_results[metadata.image_file][detection_index] = {
                'failure': 'Failure classification: SpeciesNet classifier failed'
            }
            continue

        # Extract classification results; this is a dict with keys "classes"
        # and "scores", each of which points to a list.
        classifications = result['classifications']
        classes = classifications['classes']
        scores = classifications['scores']

        classification_was_geofenced = False

        predicted_class = classes[0]
        predicted_score = scores[0]

        # Possibly apply geofencing
        if country:

            geofence_result = geofence_animal_classification(
                labels=classes,
                scores=scores,
                country=country,
                admin1_region=admin1_region,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=True
            )

            geofenced_class, geofenced_score, prediction_source = geofence_result

            if prediction_source != 'classifier':
                classification_was_geofenced = True
                predicted_class = geofenced_class
                predicted_score = geofenced_score

        # ...if we might need to apply geofencing

        # Possibly apply rollup; this was already done if geofencing was applied
        if enable_rollup and (not classification_was_geofenced):

            rollup_result = roll_up_labels_to_first_matching_level(
                labels=classes,
                scores=scores,
                country=country,
                admin1_region=admin1_region,
                target_taxonomy_levels=['species', 'genus', 'family', 'order', 'class', 'kingdom'],
                non_blank_threshold=ROLLUP_TARGET_CONFIDENCE,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=(country is not None)
            )

            if rollup_result is not None:
                rolled_up_class, rolled_up_score, prediction_source = rollup_result
                if rolled_up_class != predicted_class:
                    predicted_class = rolled_up_class
                    predicted_score = rolled_up_score

        # ...if we might need to apply taxonomic rollup

        # For now, we'll store category names as strings; these will be assigned to integer
        # IDs before writing results to file later.
        classification = [predicted_class, predicted_score]

        # Also report raw model classifications
        raw_classifications = []
        for i_class in range(0, len(classes)):
            raw_classifications.append([classes[i_class], scores[i_class]])

        all_results[metadata.image_file][detection_index] = {
            'classifications': [classification],
            'raw_classifications': raw_classifications
        }

    # ...for each result in this batch

# ...def _process_classification_batch(...)
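
# For intuition only: a toy sketch of the rollup idea implemented by
# speciesnet's roll_up_labels_to_first_matching_level (the real function also
# handles geofencing and the SpeciesNet taxonomy format). Scores are summed at
# successively coarser taxonomy levels until the best cumulative score reaches
# the target confidence; 'get_ancestor' is a hypothetical lookup.
def _toy_rollup(label_scores, get_ancestor, levels, threshold):

    for level in levels:
        totals = {}
        for label, score in label_scores:
            ancestor = get_ancestor(label, level)
            if ancestor is not None:
                totals[ancestor] = totals.get(ancestor, 0.0) + score
        if len(totals) > 0:
            best_label = max(totals, key=totals.get)
            if totals[best_label] >= threshold:
                return (best_label, totals[best_label])
    return None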


#%% Inference functions

def _run_detection_step(source_folder: str,
                        detector_output_file: str,
                        detector_model: str = DEFAULT_DETECTOR_MODEL,
                        detector_batch_size: int = DEFAULT_DETECTOR_BATCH_SIZE,
                        detection_confidence_threshold: float = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
                        detector_worker_threads: int = DEFAULT_LOADER_WORKERS,
                        skip_images: bool = False,
                        skip_video: bool = False,
                        frame_sample: int = None,
                        time_sample: float = None):
    """
    Run MegaDetector on all images/videos in [source_folder].

    Args:
        source_folder (str): folder containing images/videos
        detector_output_file (str): output .json file
        detector_model (str, optional): detector model identifier
        detector_batch_size (int, optional): batch size for detection
        detection_confidence_threshold (float, optional): confidence threshold for detections
            (to include in the output file)
        detector_worker_threads (int, optional): number of workers to use for preprocessing
        skip_images (bool, optional): ignore images, only process videos
        skip_video (bool, optional): ignore videos, only process images
        frame_sample (int, optional): sample every Nth frame from videos
        time_sample (float, optional): sample frames every N seconds from videos
    """

    print('Starting detection step...')

    # Validate arguments
    assert not (frame_sample is None and time_sample is None), \
        'Must specify either frame_sample or time_sample'

    # Find image and video files
    if not skip_images:
        image_files = path_utils.find_images(source_folder, recursive=True,
                                             return_relative_paths=False)
    else:
        image_files = []

    if not skip_video:
        video_files = find_videos(source_folder, recursive=True,
                                  return_relative_paths=False)
    else:
        video_files = []

    if len(image_files) == 0 and len(video_files) == 0:
        raise ValueError(
            'No images or videos found in {}'.format(source_folder))

    print('Found {} images and {} videos'.format(len(image_files), len(video_files)))

    files_to_merge = []

    # Process images if necessary
    if len(image_files) > 0:

        print('Running MegaDetector on {} images...'.format(len(image_files)))

        image_results = load_and_run_detector_batch(
            model_file=detector_model,
            image_file_names=image_files,
            checkpoint_path=None,
            confidence_threshold=detection_confidence_threshold,
            checkpoint_frequency=-1,
            results=None,
            n_cores=0,
            use_image_queue=True,
            quiet=True,
            image_size=None,
            batch_size=detector_batch_size,
            include_image_size=False,
            include_image_timestamp=False,
            include_exif_tags=None,
            loader_workers=detector_worker_threads,
            preprocess_on_image_queue=True
        )

        # Write image results to temporary file
        image_output_file = detector_output_file.replace('.json', '_images.json')
        write_results_to_file(image_results,
                              image_output_file,
                              relative_path_base=source_folder,
                              detector_file=detector_model)

        print('Image detection results written to {}'.format(image_output_file))
        files_to_merge.append(image_output_file)

    # ...if we had images to process

    # Process videos if necessary
    if len(video_files) > 0:

        print('Running MegaDetector on {} videos...'.format(len(video_files)))

        # Set up video processing options
        video_options = ProcessVideoOptions()
        video_options.model_file = detector_model
        video_options.input_video_file = source_folder
        video_options.output_json_file = detector_output_file.replace('.json', '_videos.json')
        video_options.json_confidence_threshold = detection_confidence_threshold
        video_options.frame_sample = frame_sample
        video_options.time_sample = time_sample
        video_options.recursive = True

        # Process videos
        process_videos(video_options)

        print('Video detection results written to {}'.format(video_options.output_json_file))
        files_to_merge.append(video_options.output_json_file)

    # ...if we had videos to process

    # Merge results if we have both images and videos
    if len(files_to_merge) > 1:
        print('Merging image and video detection results...')
        combine_batch_output_files(files_to_merge, detector_output_file)
        print('Merged detection results written to {}'.format(detector_output_file))
    elif len(files_to_merge) == 1:
        # Just rename the single file
        if files_to_merge[0] != detector_output_file:
            if os.path.isfile(detector_output_file):
                print('Detector file {} exists, overwriting'.format(detector_output_file))
                os.remove(detector_output_file)
            os.rename(files_to_merge[0], detector_output_file)
        print('Detection results written to {}'.format(detector_output_file))

# ...def _run_detection_step(...)
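
# A hypothetical invocation of the detection step (paths are placeholders):
#
# _run_detection_step(source_folder='/data/camera-trap-project',
#                     detector_output_file='/tmp/md_results.json',
#                     time_sample=DEFAULT_SECONDS_PER_VIDEO_FRAME)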


def _run_classification_step(detector_results_file: str,
                             merged_results_file: str,
                             source_folder: str,
                             classifier_model: str = DEFAULT_CLASSIFIER_MODEL,
                             classifier_batch_size: int = DEFAULT_CLASSIFIER_BATCH_SIZE,
                             classifier_worker_threads: int = DEFAULT_LOADER_WORKERS,
                             detection_confidence_threshold: float = \
                                 DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                             enable_rollup: bool = True,
                             country: str = None,
                             admin1_region: str = None,
                             top_n_scores: int = DEFAULT_TOP_N_SCORES):
    """
    Run SpeciesNet classification on detections from MegaDetector results.

    Args:
        detector_results_file (str): path to MegaDetector output .json file
        merged_results_file (str): path to which we should write the merged results
        source_folder (str): source folder for resolving relative paths
        classifier_model (str, optional): classifier model identifier
        classifier_batch_size (int, optional): batch size for classification
        classifier_worker_threads (int, optional): number of worker threads
        detection_confidence_threshold (float, optional): classify detections above this threshold
        enable_rollup (bool, optional): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region (typically a state code) for geofencing
        top_n_scores (int, optional): maximum number of scores to include for each detection
    """

    print('Starting classification step...')

    # Load MegaDetector results
    print('Reading detection results from {}'.format(detector_results_file))

    with open(detector_results_file, 'r') as f:
        detector_results = json.load(f)

    print('Classification step loaded detection results for {} images'.format(
        len(detector_results['images'])))

    images = detector_results['images']
    if len(images) == 0:
        raise ValueError('No images found in detector results')

    print('Using SpeciesNet classifier: {}'.format(classifier_model))

    # Set multiprocessing start method to 'spawn' for CUDA compatibility
    original_start_method = multiprocessing.get_start_method()
    if original_start_method != 'spawn':
        multiprocessing.set_start_method('spawn', force=True)
        print('Set multiprocessing start method to spawn (was {})'.format(
            original_start_method))

    ## Set up multiprocessing queues

    # This queue receives per-file detection results dicts from the "main" process
    # (the one you're reading right now). Items are pulled off of this queue by
    # producer workers (in _crop_producer_func), where the corresponding images
    # are loaded from disk and preprocessed into crops.
    image_queue = JoinableQueue(maxsize= \
        classifier_worker_threads * MAX_IMAGE_QUEUE_SIZE_PER_WORKER)

    # This queue receives cropped images from producers (in _crop_producer_func); those
    # crops are pulled off of this queue by the consumer (in _crop_consumer_func).
    batch_queue = Queue(maxsize=MAX_BATCH_QUEUE_SIZE)

    # This is not really used as a queue, rather it's just used to send all the results
    # at once from the consumer process to the main process (the one you're reading right
    # now).
    results_queue = Queue()

    # Start producer workers
    producers = []
    for i_worker in range(classifier_worker_threads):
        p = Process(target=_crop_producer_func,
                    args=(image_queue, batch_queue, classifier_model,
                          detection_confidence_threshold, source_folder, i_worker))
        p.start()
        producers.append(p)


    ## Start consumer worker

    consumer = Process(target=_crop_consumer_func,
                       args=(batch_queue, results_queue, classifier_model,
                             classifier_batch_size, classifier_worker_threads,
                             enable_rollup, country, admin1_region))
    consumer.start()

    # This will block every time the queue reaches its maximum depth, so for
    # very small jobs, this will not be a useful progress bar.
    with tqdm(total=len(images), desc='Classification') as pbar:
        for image_data in images:
            image_queue.put(image_data)
            pbar.update()

    # Send sentinel signals to producers
    for _ in range(classifier_worker_threads):
        image_queue.put(None)

    # Wait for all work to complete
    image_queue.join()

    print('Finished waiting for input queue')


    ## Wait for results

    classification_results = results_queue.get()


    ## Clean up processes

    for p in producers:
        p.join()
    consumer.join()

    print('Finished waiting for workers')


    ## Format results and write output

    class CategoryState:
        """
        Helper class to manage classification category IDs.
        """

        def __init__(self):

            self.next_category_id = 0

            # Maps common names to string-int IDs
            self.common_name_to_id = {}

            # Maps string-ints to common names, as per format standard
            self.classification_categories = {}

            # Maps string-ints to Latin taxonomy strings, as per format standard
            self.classification_category_descriptions = {}

        def _get_category_id(self, class_name):
            """
            Get an integer-valued category ID for the 7-token string [class_name],
            creating a new one if necessary.
            """

            # E.g.:
            #
            # "cb553c4e-42c9-4fe0-9bd0-da2d6ed5bfa1;mammalia;carnivora;canidae;urocyon;littoralis;island fox"
            tokens = class_name.split(';')
            assert len(tokens) == 7
            taxonomy_string = ';'.join(tokens[1:6])
            common_name = tokens[6]
            if len(common_name) == 0:
                common_name = taxonomy_string

            if common_name not in self.common_name_to_id:
                self.common_name_to_id[common_name] = str(self.next_category_id)
                self.classification_categories[str(self.next_category_id)] = common_name
                # Store the full seven-token string, rather than the shortened five-token string, for
                # compatibility with what is expected by the classification_postprocessing module.
                # self.classification_category_descriptions[str(self.next_category_id)] = taxonomy_string
                self.classification_category_descriptions[str(self.next_category_id)] = class_name
                self.next_category_id += 1

            category_id = self.common_name_to_id[common_name]

            return category_id

    # ...class CategoryState
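
    # For example (using the hypothetical prediction string from the comment
    # above), CategoryState maps a 7-token SpeciesNet class string to a small
    # string-int ID keyed by common name:
    #
    # cs = CategoryState()
    # cs._get_category_id(
    #     'cb553c4e-42c9-4fe0-9bd0-da2d6ed5bfa1;mammalia;carnivora;canidae;'
    #     'urocyon;littoralis;island fox')      # -> '0'
    # cs.classification_categories              # -> {'0': 'island fox'}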

    category_state = CategoryState()

    # Merge classification results back into detector results with proper category IDs
    for image_data in images:

        image_file = image_data['file']

        if ('detections' not in image_data) or (image_data['detections'] is None):
            continue

        detections = image_data['detections']

        if image_file not in classification_results:
            continue

        image_classifications = classification_results[image_file]

        for detection_index, detection in enumerate(detections):

            if detection_index in image_classifications:

                result = image_classifications[detection_index]

                if 'failure' in result:
                    # Add failure to the image, not the detection
                    if 'failure' not in image_data:
                        image_data['failure'] = result['failure']
                    else:
                        image_data['failure'] += ';' + result['failure']
                else:

                    # Convert class names to category IDs
                    classification_pairs = []
                    raw_classification_pairs = []

                    scores = [x[1] for x in result['classifications']]
                    assert is_list_sorted(scores, reverse=True)

                    # Only report the requested number of scores per detection
                    if len(result['classifications']) > top_n_scores:
                        result['classifications'] = \
                            result['classifications'][0:top_n_scores]

                    if len(result['raw_classifications']) > top_n_scores:
                        result['raw_classifications'] = \
                            result['raw_classifications'][0:top_n_scores]

                    for class_name, score in result['classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        classification_pairs.append([category_id, score])

                    for class_name, score in result['raw_classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        raw_classification_pairs.append([category_id, score])

                    # Add classifications to the detection
                    detection['classifications'] = classification_pairs
                    # detection['raw_classifications'] = raw_classification_pairs

                # ...if this classification contains a failure

            # ...if this detection has classification information

        # ...for each detection

    # ...for each image

    # Update metadata in the results
    if 'info' not in detector_results:
        detector_results['info'] = {}

    detector_results['info']['classifier'] = classifier_model
    detector_results['info']['classification_completion_time'] = time.strftime(
        '%Y-%m-%d %H:%M:%S')

    # Add classification category mapping
    detector_results['classification_categories'] = \
        category_state.classification_categories
    detector_results['classification_category_descriptions'] = \
        category_state.classification_category_descriptions

    print('Writing output file')

    # Write results
    write_json(merged_results_file, detector_results)

    if verbose:
        print('Classification results written to {}'.format(merged_results_file))

# ...def _run_classification_step(...)
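
# Sketch of the merged output shape this step produces (illustrative values,
# not taken from the original source; detection-level keys other than
# "classifications" follow the standard MegaDetector batch format):
#
#   {
#     "info": {"classifier": "...", "classification_completion_time": "..."},
#     "classification_categories": {"0": "island fox"},
#     "classification_category_descriptions": {"0": "<guid>;mammalia;...;island fox"},
#     "images": [
#       {"file": "img0001.jpg",
#        "detections": [
#          {"category": "1", "conf": 0.93, "bbox": [0.1, 0.2, 0.3, 0.4],
#           "classifications": [["0", 0.88]]}
#        ]}
#     ]
#   }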


#%% Options class

class RunMDSpeciesNetOptions:
    """
    Class controlling the behavior of run_md_and_speciesnet()
    """

    def __init__(self):

        #: Folder containing images and/or videos to process
        self.source = None

        #: Output file for results (JSON format)
        self.output_file = None

        #: MegaDetector model identifier (MDv5a, MDv5b, MDv1000-redwood, etc.)
        self.detector_model = DEFAULT_DETECTOR_MODEL

        #: SpeciesNet classifier model identifier (e.g. kaggle:google/speciesnet/pyTorch/v4.0.1a)
        self.classification_model = DEFAULT_CLASSIFIER_MODEL

        #: Batch size for MegaDetector inference
        self.detector_batch_size = DEFAULT_DETECTOR_BATCH_SIZE

        #: Batch size for SpeciesNet classification
        self.classifier_batch_size = DEFAULT_CLASSIFIER_BATCH_SIZE

        #: Number of worker threads for preprocessing
        self.loader_workers = DEFAULT_LOADER_WORKERS

        #: Classify detections above this threshold
        self.detection_confidence_threshold_for_classification = \
            DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION

        #: Include detections above this threshold in the output
        self.detection_confidence_threshold_for_output = \
            DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT

        #: Folder for intermediate files (default: system temp)
        self.intermediate_file_folder = None

        #: Keep intermediate files (e.g. detection-only results file)
        self.keep_intermediate_files = False

        #: Disable taxonomic rollup
        self.norollup = False

        #: Country code (ISO 3166-1 alpha-3) for geofencing (default None, no geofencing)
        self.country = None

        #: Admin1 region/state code for geofencing
        self.admin1_region = None

        #: Path to existing MegaDetector output file (skips detection step)
        self.detections_file = None

        #: Ignore videos, only process images
        self.skip_video = False

        #: Ignore images, only process videos
        self.skip_images = False

        #: Sample every Nth frame from videos
        #:
        #: Mutually exclusive with time_sample
        self.frame_sample = None

        #: Sample frames every N seconds from videos
        #:
        #: Mutually exclusive with frame_sample.  Left as None here; if neither
        #: sampling option is set, run_md_and_speciesnet() defaults this to
        #: DEFAULT_SECONDS_PER_VIDEO_FRAME.  Initializing it to a non-None value
        #: here would cause the mutual-exclusivity check to fire whenever a
        #: caller sets frame_sample.
        self.time_sample = None

        #: Enable additional debug output
        self.verbose = False

# ...class RunMDSpeciesNetOptions
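
# Minimal usage sketch (illustrative paths, not part of the original source):
#
#   options = RunMDSpeciesNetOptions()
#   options.source = '/data/camera-trap-images'
#   options.output_file = '/data/results.json'
#   options.country = 'USA'
#   run_md_and_speciesnet(options)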


#%% Main function

def run_md_and_speciesnet(options):
    """
    Main entry point, runs MegaDetector and SpeciesNet on a folder.  See
    RunMDSpeciesNetOptions for available arguments.

    Args:
        options (RunMDSpeciesNetOptions): options controlling MD and SN inference
    """

    # Set global verbose flag
    global verbose
    verbose = options.verbose

    # Also set the run_detector_batch verbose flag
    run_detector_batch.verbose = verbose

    # Validate arguments
    if not os.path.isdir(options.source):
        raise ValueError(
            'Source folder does not exist: {}'.format(options.source))

    if (options.admin1_region is not None) and (options.country is None):
        raise ValueError('--admin1_region requires --country to be specified')

    if options.skip_images and options.skip_video:
        raise ValueError('Cannot skip both images and videos')

    if (options.frame_sample is not None) and (options.time_sample is not None):
        raise ValueError('--frame_sample and --time_sample are mutually exclusive')
    if (options.frame_sample is None) and (options.time_sample is None):
        options.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME

    # Set up intermediate file folder
    if options.intermediate_file_folder:
        temp_folder = options.intermediate_file_folder
        os.makedirs(temp_folder, exist_ok=True)
    else:
        temp_folder = make_temp_folder(subfolder='run_md_and_speciesnet')

    start_time = time.time()

    print('Processing folder: {}'.format(options.source))
    print('Output file: {}'.format(options.output_file))
    print('Intermediate files: {}'.format(temp_folder))

    # Determine detector output file path
    if options.detections_file is not None:
        detector_output_file = options.detections_file
        if VALIDATE_DETECTION_FILE:
            print('Using existing detections file: {}'.format(detector_output_file))
            validation_options = ValidateBatchResultsOptions()
            validation_options.check_image_existence = True
            validation_options.relative_path_base = options.source
            validation_options.raise_errors = True
            validate_batch_results(detector_output_file, options=validation_options)
            print('Validated detections file')
        else:
            print('Bypassing validation of {}'.format(options.detections_file))
    else:
        detector_output_file = os.path.join(temp_folder, 'detector_output.json')

        # Run MegaDetector
        _run_detection_step(
            source_folder=options.source,
            detector_output_file=detector_output_file,
            detector_model=options.detector_model,
            detector_batch_size=options.detector_batch_size,
            detection_confidence_threshold=options.detection_confidence_threshold_for_output,
            detector_worker_threads=options.loader_workers,
            skip_images=options.skip_images,
            skip_video=options.skip_video,
            frame_sample=options.frame_sample,
            time_sample=options.time_sample
        )

    # Run SpeciesNet
    _run_classification_step(
        detector_results_file=detector_output_file,
        merged_results_file=options.output_file,
        source_folder=options.source,
        classifier_model=options.classification_model,
        classifier_batch_size=options.classifier_batch_size,
        classifier_worker_threads=options.loader_workers,
        detection_confidence_threshold=options.detection_confidence_threshold_for_classification,
        enable_rollup=(not options.norollup),
        country=options.country,
        admin1_region=options.admin1_region,
    )

    elapsed_time = time.time() - start_time
    print('Processing complete in {}'.format(humanfriendly.format_timespan(elapsed_time)))
    print('Results written to: {}'.format(options.output_file))

    # Clean up intermediate files, unless the user asked to keep them, supplied
    # their own intermediate folder, or supplied an existing detections file
    if (not options.keep_intermediate_files) and \
       (not options.intermediate_file_folder) and \
       (not options.detections_file):
        try:
            os.remove(detector_output_file)
        except Exception as e:
            print('Warning: error removing temporary output file {}: {}'.format(
                detector_output_file, str(e)))

# ...def run_md_and_speciesnet(...)
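
# Sketch of the two-pass workflow this function supports (hypothetical paths,
# not from the original source): supplying detections_file skips the detection
# step, so only classification is re-run against an earlier MegaDetector output.
#
#   options = RunMDSpeciesNetOptions()
#   options.source = '/data/camera-trap-images'
#   options.output_file = '/data/results_with_species.json'
#   options.detections_file = '/data/detector_output.json'  # from an earlier run
#   run_md_and_speciesnet(options)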


#%% Command-line driver

def main():
    """
    Command-line driver for run_md_and_speciesnet.py
    """

    if 'speciesnet' not in sys.modules:
        print('It looks like the speciesnet package is not available, try "pip install speciesnet"')
        if not is_sphinx_build():
            sys.exit(-1)

    parser = argparse.ArgumentParser(
        description='Run MegaDetector and SpeciesNet on a folder of images/videos',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required arguments
    parser.add_argument('source',
                        help='Folder containing images and/or videos to process')
    parser.add_argument('output_file',
                        help='Output file for results (JSON format)')

    # Optional arguments
    parser.add_argument('--detector_model',
                        default=DEFAULT_DETECTOR_MODEL,
                        help='MegaDetector model identifier')
    parser.add_argument('--classification_model',
                        default=DEFAULT_CLASSIFIER_MODEL,
                        help='SpeciesNet classifier model identifier')
    parser.add_argument('--detector_batch_size',
                        type=int,
                        default=DEFAULT_DETECTOR_BATCH_SIZE,
                        help='Batch size for MegaDetector inference')
    parser.add_argument('--classifier_batch_size',
                        type=int,
                        default=DEFAULT_CLASSIFIER_BATCH_SIZE,
                        help='Batch size for SpeciesNet classification')
    parser.add_argument('--loader_workers',
                        type=int,
                        default=DEFAULT_LOADER_WORKERS,
                        help='Number of worker threads for preprocessing')
    parser.add_argument('--detection_confidence_threshold_for_classification',
                        type=float,
                        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                        help='Classify detections above this threshold')
    parser.add_argument('--detection_confidence_threshold_for_output',
                        type=float,
                        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT,
                        help='Include detections above this threshold in the output')
    parser.add_argument('--intermediate_file_folder',
                        default=None,
                        help='Folder for intermediate files (default: system temp)')
    parser.add_argument('--keep_intermediate_files',
                        action='store_true',
                        help='Keep intermediate files (e.g. detection-only results file)')
    parser.add_argument('--norollup',
                        action='store_true',
                        help='Disable taxonomic rollup')
    parser.add_argument('--country',
                        default=None,
                        help='Country code (ISO 3166-1 alpha-3) for geofencing')
    parser.add_argument('--admin1_region', '--state',
                        default=None,
                        help='Admin1 region/state code for geofencing')
    parser.add_argument('--detections_file',
                        default=None,
                        help='Path to existing MegaDetector output file (skips detection step)')
    parser.add_argument('--skip_video',
                        action='store_true',
                        help='Ignore videos, only process images')
    parser.add_argument('--skip_images',
                        action='store_true',
                        help='Ignore images, only process videos')
    parser.add_argument('--frame_sample',
                        type=int,
                        default=None,
                        help='Sample every Nth frame from videos (mutually exclusive with --time_sample)')
    parser.add_argument('--time_sample',
                        type=float,
                        default=None,
                        help='Sample frames every N seconds from videos (default {}) '
                             '(mutually exclusive with --frame_sample)'.format(
                                 DEFAULT_SECONDS_PER_VIDEO_FRAME))
    parser.add_argument('--verbose',
                        action='store_true',
                        help='Enable additional debug output')

    if len(sys.argv[1:]) == 0:
        parser.print_help()
        parser.exit()

    args = parser.parse_args()

    options = RunMDSpeciesNetOptions()
    args_to_object(args, options)

    run_md_and_speciesnet(options)

# ...def main(...)
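
# Example invocation (hypothetical paths and region; the flags are the ones
# defined above):
#
#   python run_md_and_speciesnet.py /data/camera-trap-images /data/results.json \
#       --country USA --admin1_region CA --time_sample 5 --verbose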


if __name__ == '__main__':
    main()