megadetector 10.0.1__py3-none-any.whl → 10.0.3__py3-none-any.whl
- megadetector/detection/process_video.py +120 -913
- megadetector/detection/pytorch_detector.py +572 -263
- megadetector/detection/run_detector.py +13 -6
- megadetector/detection/run_detector_batch.py +525 -143
- megadetector/detection/run_md_and_speciesnet.py +1301 -0
- megadetector/detection/video_utils.py +240 -105
- megadetector/postprocessing/classification_postprocessing.py +12 -1
- megadetector/postprocessing/compare_batch_results.py +21 -2
- megadetector/postprocessing/merge_detections.py +16 -12
- megadetector/postprocessing/validate_batch_results.py +25 -2
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/ct_utils.py +16 -5
- megadetector/utils/extract_frames_from_video.py +303 -0
- megadetector/utils/md_tests.py +578 -520
- megadetector/utils/wi_utils.py +20 -4
- megadetector/visualization/visualize_db.py +8 -22
- megadetector/visualization/visualize_detector_output.py +1 -1
- megadetector/visualization/visualize_video_output.py +607 -0
- {megadetector-10.0.1.dist-info → megadetector-10.0.3.dist-info}/METADATA +134 -135
- {megadetector-10.0.1.dist-info → megadetector-10.0.3.dist-info}/RECORD +24 -19
- {megadetector-10.0.1.dist-info → megadetector-10.0.3.dist-info}/licenses/LICENSE +0 -0
- {megadetector-10.0.1.dist-info → megadetector-10.0.3.dist-info}/top_level.txt +0 -0
- {megadetector-10.0.1.dist-info → megadetector-10.0.3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1301 @@
"""

run_md_and_speciesnet.py

Script to run MegaDetector and SpeciesNet on a folder of images and/or videos.
Runs MD first, then runs SpeciesNet on every above-threshold crop.

"""

#%% Constants, imports, environment

import argparse
import json
import multiprocessing
import os
import sys
import time

from tqdm import tqdm
from multiprocessing import JoinableQueue, Process, Queue

import humanfriendly

from megadetector.detection import run_detector_batch
from megadetector.detection.video_utils import find_videos, run_callback_on_frames, is_video_file
from megadetector.detection.run_detector_batch import load_and_run_detector_batch
from megadetector.detection.run_detector import DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
from megadetector.detection.run_detector import CONF_DIGITS
from megadetector.detection.run_detector_batch import write_results_to_file
from megadetector.utils.ct_utils import round_float
from megadetector.utils.ct_utils import write_json
from megadetector.utils.ct_utils import make_temp_folder
from megadetector.utils.ct_utils import is_list_sorted
from megadetector.utils import path_utils
from megadetector.visualization import visualization_utils as vis_utils
from megadetector.postprocessing.validate_batch_results import \
    validate_batch_results, ValidateBatchResultsOptions
from megadetector.detection.process_video import \
    process_videos, ProcessVideoOptions
from megadetector.postprocessing.combine_batch_outputs import combine_batch_output_files

# We aren't taking an explicit dependency on the speciesnet package yet,
# so we wrap this in a try/except so sphinx can still document this module.
try:
    from speciesnet import SpeciesNetClassifier
    from speciesnet.utils import BBox
    from speciesnet.ensemble import SpeciesNetEnsemble
    from speciesnet.geofence_utils import roll_up_labels_to_first_matching_level
    from speciesnet.geofence_utils import geofence_animal_classification
except Exception:
    pass


#%% Constants

DEFAULT_DETECTOR_MODEL = 'MDV5A'
DEFAULT_CLASSIFIER_MODEL = 'kaggle:google/speciesnet/pyTorch/v4.0.1a'
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION = 0.1
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
DEFAULT_DETECTOR_BATCH_SIZE = 1
DEFAULT_CLASSIFIER_BATCH_SIZE = 8
DEFAULT_LOADER_WORKERS = 4
MAX_QUEUE_SIZE_IMAGES_PER_WORKER = 10
DEFAULT_SECONDS_PER_VIDEO_FRAME = 1.0

# Max number of classification scores to include per detection
DEFAULT_TOP_N_SCORES = 2

# Unless --norollup is specified, roll up taxonomic levels until the
# cumulative confidence is above this value
ROLLUP_TARGET_CONFIDENCE = 0.5

verbose = False


#%% Support classes

class CropMetadata:
    """
    Metadata for a crop extracted from an image detection.
    """

    def __init__(self,
                 image_file: str,
                 detection_index: int,
                 bbox: list[float],
                 original_width: int,
                 original_height: int):
        """
        Args:
            image_file (str): path to the original image file
            detection_index (int): index of this detection in the image
            bbox (list[float]): normalized bounding box [x_min, y_min, width, height]
            original_width (int): width of the original image
            original_height (int): height of the original image
        """

        self.image_file = image_file
        self.detection_index = detection_index
        self.bbox = bbox
        self.original_width = original_width
        self.original_height = original_height


class CropBatch:
    """
    A batch of crops with their metadata for classification.
    """

    def __init__(self):
        # List of preprocessed images
        self.crops = []

        # List of CropMetadata objects
        self.metadata = []

    def add_crop(self, crop_data, metadata):
        """
        Args:
            crop_data (PreprocessedImage): preprocessed image data from
                SpeciesNetClassifier.preprocess()
            metadata (CropMetadata): metadata for this crop
        """

        self.crops.append(crop_data)
        self.metadata.append(metadata)

    def __len__(self):
        return len(self.crops)


#%% Support functions for classification

def _process_image_detections(file_path: str,
                              absolute_file_path: str,
                              detection_results: dict,
                              classifier: 'SpeciesNetClassifier',
                              detection_confidence_threshold: float,
                              batch_queue: Queue):
    """
    Process detections from a single image.

    Args:
        file_path (str): relative path to the image file
        absolute_file_path (str): absolute path to the image file
        detection_results (dict): detection results for this image
        classifier (SpeciesNetClassifier): classifier instance for preprocessing
        detection_confidence_threshold (float): classify detections above this threshold
        batch_queue (Queue): queue to send crops to
    """

    detections = detection_results['detections']

    # Load the image
    try:
        image = vis_utils.load_image(absolute_file_path)
        original_width, original_height = image.size
    except Exception as e:
        print('Warning: failed to load image {}: {}'.format(file_path, str(e)))

        # Send failure information to the consumer
        failure_metadata = CropMetadata(
            image_file=file_path,
            detection_index=-1,  # -1 indicates whole-image failure
            bbox=[],
            original_width=0,
            original_height=0
        )
        batch_queue.put(('failure',
                         'Failed to load image: {}'.format(str(e)),
                         failure_metadata))
        return

    # Process each detection above threshold
    for detection_index, detection in enumerate(detections):

        conf = detection['conf']
        if conf < detection_confidence_threshold:
            continue

        bbox = detection['bbox']
        assert len(bbox) == 4

        # Convert the normalized bbox to a BBox object for SpeciesNet
        speciesnet_bbox = BBox(
            xmin=bbox[0],
            ymin=bbox[1],
            width=bbox[2],
            height=bbox[3]
        )

        # Preprocess the crop
        try:
            preprocessed_crop = classifier.preprocess(
                image,
                bboxes=[speciesnet_bbox],
                resize=True
            )

            if preprocessed_crop is not None:
                metadata = CropMetadata(
                    image_file=file_path,
                    detection_index=detection_index,
                    bbox=bbox,
                    original_width=original_width,
                    original_height=original_height
                )

                # Send individual crop immediately to the consumer
                batch_queue.put(('crop', preprocessed_crop, metadata))

        except Exception as e:
            print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
                file_path, detection_index, str(e)))

            # Send failure information to the consumer
            failure_metadata = CropMetadata(
                image_file=file_path,
                detection_index=detection_index,
                bbox=bbox,
                original_width=original_width,
                original_height=original_height
            )
            batch_queue.put(('failure',
                             'Failed to preprocess crop: {}'.format(str(e)),
                             failure_metadata))

    # ...for each detection in this image

# ...def _process_image_detections(...)


def _process_video_detections(file_path: str,
                              absolute_file_path: str,
                              detection_results: dict,
                              classifier: 'SpeciesNetClassifier',
                              detection_confidence_threshold: float,
                              batch_queue: Queue):
    """
    Process detections from a single video.

    Args:
        file_path (str): relative path to the video file
        absolute_file_path (str): absolute path to the video file
        detection_results (dict): detection results for this video
        classifier (SpeciesNetClassifier): classifier instance for preprocessing
        detection_confidence_threshold (float): classify detections above this threshold
        batch_queue (Queue): queue to send crops to
    """

    detections = detection_results['detections']

    # Find frames with above-threshold detections
    frames_with_detections = set()
    frame_to_detections = {}

    for detection_index, detection in enumerate(detections):
        conf = detection['conf']
        if conf < detection_confidence_threshold:
            continue

        frame_number = detection['frame_number']
        frames_with_detections.add(frame_number)

        if frame_number not in frame_to_detections:
            frame_to_detections[frame_number] = []
        frame_to_detections[frame_number].append((detection_index, detection))

    if len(frames_with_detections) == 0:
        return

    frames_to_process = sorted(list(frames_with_detections))

    # Define callback for processing each frame
    def frame_callback(frame_array, frame_id):
        """
        Callback to process a single frame.

        Args:
            frame_array (numpy.ndarray): frame data in PIL format
            frame_id (str): frame identifier like "frame0006.jpg"
        """

        # Extract frame number from frame_id (e.g., "frame0006.jpg" -> 6)
        import re
        match = re.match(r'frame(\d+)\.jpg', frame_id)
        if not match:
            print('Warning: could not parse frame number from {}'.format(frame_id))
            return
        frame_number = int(match.group(1))

        if frame_number not in frame_to_detections:
            return

        # Convert numpy array to PIL Image
        from PIL import Image
        if frame_array.dtype != 'uint8':
            frame_array = (frame_array * 255).astype('uint8')
        frame_image = Image.fromarray(frame_array)
        original_width, original_height = frame_image.size

        # Process each detection in this frame
        for detection_index, detection in frame_to_detections[frame_number]:

            bbox = detection['bbox']
            assert len(bbox) == 4

            # Convert the normalized bbox to a BBox object for SpeciesNet
            speciesnet_bbox = BBox(
                xmin=bbox[0],
                ymin=bbox[1],
                width=bbox[2],
                height=bbox[3]
            )

            # Preprocess the crop
            try:

                preprocessed_crop = classifier.preprocess(
                    frame_image,
                    bboxes=[speciesnet_bbox],
                    resize=True
                )

                if preprocessed_crop is not None:
                    metadata = CropMetadata(
                        image_file=file_path,
                        detection_index=detection_index,
                        bbox=bbox,
                        original_width=original_width,
                        original_height=original_height
                    )

                    # Send individual crop immediately to the consumer
                    batch_queue.put(('crop', preprocessed_crop, metadata))

            except Exception as e:

                print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
                    file_path, detection_index, str(e)))

                # Send failure information to the consumer
                failure_metadata = CropMetadata(
                    image_file=file_path,
                    detection_index=detection_index,
                    bbox=bbox,
                    original_width=original_width,
                    original_height=original_height
                )
                batch_queue.put(('failure',
                                 'Failed to preprocess crop: {}'.format(str(e)),
                                 failure_metadata))

            # ...try/except

        # ...for each detection

    # ...def frame_callback(...)

    # Process the video frames
    try:
        run_callback_on_frames(
            input_video_file=absolute_file_path,
            frame_callback=frame_callback,
            frames_to_process=frames_to_process,
            verbose=verbose
        )
    except Exception as e:
        print('Warning: failed to process video {}: {}'.format(file_path, str(e)))

        # Send failure information to the consumer for the whole video
        failure_metadata = CropMetadata(
            image_file=file_path,
            detection_index=-1,  # -1 indicates whole-file failure
            bbox=[],
            original_width=0,
            original_height=0
        )
        batch_queue.put(('failure',
                         'Failed to process video: {}'.format(str(e)),
                         failure_metadata))
    # ...try/except

# ...def _process_video_detections(...)


def _crop_producer_func(image_queue: JoinableQueue,
                        batch_queue: Queue,
                        classifier_model: str,
                        detection_confidence_threshold: float,
                        source_folder: str,
                        producer_id: int = -1):
    """
    Producer function for classification workers.

    Reads images and videos from [image_queue], crops detections above a threshold,
    preprocesses them, and sends individual crops to [batch_queue].
    See the documentation of _crop_consumer_func for the format of the
    tuples placed on batch_queue.

    Args:
        image_queue (JoinableQueue): queue containing detection_results dicts (for both images and videos)
        batch_queue (Queue): queue to put individual crops into
        classifier_model (str): classifier model identifier to load in this process
        detection_confidence_threshold (float): classify detections above this threshold
        source_folder (str): source folder to resolve relative paths
        producer_id (int, optional): identifier for this producer worker
    """

    if verbose:
        print('Classification producer starting: ID {}'.format(producer_id))

    # Load classifier; this is just being used as a preprocessor, so we force device=cpu.
    #
    # There are a number of reasons loading the model might fail; note to self: *don't*
    # catch Exceptions here. This should be a catastrophic failure that stops the whole
    # process.
    classifier = SpeciesNetClassifier(classifier_model, device='cpu')
    if verbose:
        print('Classification producer {}: loaded classifier'.format(producer_id))

    while True:

        # Pull one file's detection results from the queue
        detection_results = image_queue.get()

        # Pulling None from the queue indicates that this producer is done
        if detection_results is None:
            image_queue.task_done()
            break

        file_path = detection_results['file']

        # Skip files that failed at the detection stage
        if 'failure' in detection_results:
            image_queue.task_done()
            continue

        # Skip files with no detections
        detections = detection_results['detections']
        if len(detections) == 0:
            image_queue.task_done()
            continue

        # Determine whether this is an image or a video
        absolute_file_path = os.path.join(source_folder, file_path)
        is_video = is_video_file(file_path)

        if is_video:
            # Process video
            _process_video_detections(
                file_path=file_path,
                absolute_file_path=absolute_file_path,
                detection_results=detection_results,
                classifier=classifier,
                detection_confidence_threshold=detection_confidence_threshold,
                batch_queue=batch_queue
            )
        else:
            # Process image
            _process_image_detections(
                file_path=file_path,
                absolute_file_path=absolute_file_path,
                detection_results=detection_results,
                classifier=classifier,
                detection_confidence_threshold=detection_confidence_threshold,
                batch_queue=batch_queue
            )

        image_queue.task_done()

    # ...while(we still have items to process)

    # Send sentinel to indicate this producer is done
    batch_queue.put(None)

    if verbose:
        print('Classification producer {} finished'.format(producer_id))

# ...def _crop_producer_func(...)


def _crop_consumer_func(batch_queue: Queue,
                        results_queue: Queue,
                        classifier_model: str,
                        batch_size: int,
                        num_producers: int,
                        enable_rollup: bool,
                        country: str = None,
                        admin1_region: str = None):
    """
    Consumer function for classification inference.

    Pulls individual crops from batch_queue, assembles them into batches,
    runs inference, and puts results into results_queue.

    Args:
        batch_queue (Queue): queue containing individual crop tuples or failures.
            Items on this queue are either None (to indicate that a producer finished)
            or tuples formatted as (type, image, metadata). [type] is a string (either
            "crop" or "failure"), [image] is a PreprocessedImage, and [metadata] is
            a CropMetadata object.
        results_queue (Queue): queue to put classification results into
        classifier_model (str): classifier model identifier to load
        batch_size (int): batch size for inference
        num_producers (int): number of producer workers
        enable_rollup (bool): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
    """

    if verbose:
        print('Classification consumer starting')

    # Load classifier
    try:
        classifier = SpeciesNetClassifier(classifier_model)
        if verbose:
            print('Classification consumer: loaded classifier')
    except Exception as e:
        print('Classification consumer: failed to load classifier: {}'.format(str(e)))
        results_queue.put({})
        return

    all_results = {}  # image_file -> {detection_index -> classification_result}
    current_batch = CropBatch()
    producers_finished = 0

    # Load ensemble metadata if rollup/geofencing is enabled
    taxonomy_map = {}
    geofence_map = {}

    if enable_rollup or (country is not None):

        # Note to self: there are a number of reasons loading the ensemble
        # could fail here; don't catch this exception, this should be a
        # catastrophic failure.
        ensemble = SpeciesNetEnsemble(
            classifier_model, geofence=(country is not None))
        taxonomy_map = ensemble.taxonomy_map
        geofence_map = ensemble.geofence_map

    # ...if we need to load ensemble components

    while True:

        # Pull an item from the queue
        item = batch_queue.get()

        # This indicates that a producer worker finished
        if item is None:

            producers_finished += 1
            if producers_finished == num_producers:
                # Process any remaining images
                if len(current_batch) > 0:
                    _process_classification_batch(
                        current_batch, classifier, all_results,
                        enable_rollup, taxonomy_map, geofence_map,
                        country, admin1_region
                    )
                break
            continue

        # ...if a producer finished

        # If we got here, we know we have a crop to process, or
        # a failure to ignore.
        assert isinstance(item, tuple) and len(item) == 3
        item_type, data, metadata = item

        if metadata.image_file not in all_results:
            all_results[metadata.image_file] = {}

        # We should never be processing the same detection twice
        assert metadata.detection_index not in all_results[metadata.image_file]

        if item_type == 'failure':

            all_results[metadata.image_file][metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(data)
            }

        else:

            assert item_type == 'crop'
            current_batch.add_crop(data, metadata)
            assert len(current_batch) <= batch_size

            # Process batch if necessary
            if len(current_batch) == batch_size:
                _process_classification_batch(
                    current_batch, classifier, all_results,
                    enable_rollup, taxonomy_map, geofence_map,
                    country, admin1_region
                )
                current_batch = CropBatch()

        # ...was this item a failure or a crop?

    # ...while (we have items to process)

    results_queue.put(all_results)

    if verbose:
        print('Classification consumer finished')

# ...def _crop_consumer_func(...)


def _process_classification_batch(batch: CropBatch,
                                  classifier: 'SpeciesNetClassifier',
                                  all_results: dict,
                                  enable_rollup: bool,
                                  taxonomy_map: dict,
                                  geofence_map: dict,
                                  country: str = None,
                                  admin1_region: str = None):
    """
    Run a batch of crops through the classifier.

    Args:
        batch (CropBatch): batch of crops to process
        classifier (SpeciesNetClassifier): classifier instance
        all_results (dict): dictionary to store results in, modified in-place with format:
            {image_file: {detection_index: {'classifications': [[class_name, score], ...]}}}
            or {image_file: {detection_index: {'failure': error_message}}}.
        enable_rollup (bool): whether to apply rollup
        taxonomy_map (dict): taxonomy mapping for rollup
        geofence_map (dict): geofence mapping
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
    """

    if len(batch) == 0:
        print('Warning: _process_classification_batch received empty batch')
        return

    # Prepare batch for inference
    filepaths = [f"{metadata.image_file}_{metadata.detection_index}"
                 for metadata in batch.metadata]

    # Run batch inference
    try:
        batch_results = classifier.batch_predict(filepaths, batch.crops)
    except Exception as e:
        print('*** Batch classification failed: {} ***'.format(str(e)))
        # Mark all crops in this batch as failed
        for metadata in batch.metadata:
            if metadata.image_file not in all_results:
                all_results[metadata.image_file] = {}
            all_results[metadata.image_file][metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(str(e))
            }
        return

    # Process results
    assert len(batch_results) == len(batch.metadata)
    assert len(batch_results) == len(filepaths)

    for i_result in range(0, len(batch_results)):

        result = batch_results[i_result]
        metadata = batch.metadata[i_result]

        assert metadata.image_file in all_results, \
            'File {} not in results dict'.format(metadata.image_file)

        detection_index = metadata.detection_index

        # Handle classification failure
        if 'failures' in result:
            print('*** Classification failure for image: {} ***'.format(
                filepaths[i_result]))
            all_results[metadata.image_file][detection_index] = {
                'failure': 'Failure classification: SpeciesNet classifier failed'
            }
            continue

        # Extract classification results; this is a dict with keys "classes"
        # and "scores", each of which points to a list.
        classifications = result['classifications']
        classes = classifications['classes']
        scores = classifications['scores']

        classification_was_geofenced = False

        predicted_class = classes[0]
        predicted_score = scores[0]

        # Possibly apply geofencing
        if country:

            geofence_result = geofence_animal_classification(
                labels=classes,
                scores=scores,
                country=country,
                admin1_region=admin1_region,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=True
            )

            geofenced_class, geofenced_score, prediction_source = geofence_result

            if prediction_source != 'classifier':
                classification_was_geofenced = True
                predicted_class = geofenced_class
                predicted_score = geofenced_score

        # ...if we might need to apply geofencing

        # Possibly apply rollup; this was already done if geofencing was applied
        if enable_rollup and (not classification_was_geofenced):

            rollup_result = roll_up_labels_to_first_matching_level(
                labels=classes,
                scores=scores,
                country=country,
                admin1_region=admin1_region,
                target_taxonomy_levels=['species', 'genus', 'family', 'order', 'class', 'kingdom'],
                non_blank_threshold=ROLLUP_TARGET_CONFIDENCE,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=(country is not None)
            )

            if rollup_result is not None:
                rolled_up_class, rolled_up_score, prediction_source = rollup_result
                if rolled_up_class != predicted_class:
                    predicted_class = rolled_up_class
                    predicted_score = rolled_up_score

        # ...if we might need to apply taxonomic rollup

        # For now, we'll store category names as strings; these will be assigned to integer
        # IDs before writing results to file later.
        classification = [predicted_class, predicted_score]

        # Also report raw model classifications
        raw_classifications = []
        for i_class in range(0, len(classes)):
            raw_classifications.append([classes[i_class], scores[i_class]])

        all_results[metadata.image_file][detection_index] = {
            'classifications': [classification],
            'raw_classifications': raw_classifications
        }

    # ...for each result in this batch

# ...def _process_classification_batch(...)


#%% Inference functions

def _run_detection_step(source_folder: str,
                        detector_output_file: str,
                        detector_model: str = DEFAULT_DETECTOR_MODEL,
                        detector_batch_size: int = DEFAULT_DETECTOR_BATCH_SIZE,
                        detection_confidence_threshold: float = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
                        detector_worker_threads: int = DEFAULT_LOADER_WORKERS,
                        skip_images: bool = False,
                        skip_video: bool = False,
                        frame_sample: int = None,
                        time_sample: float = None):
    """
    Run MegaDetector on all images/videos in [source_folder].

    Args:
        source_folder (str): folder containing images/videos
        detector_output_file (str): output .json file
        detector_model (str, optional): detector model identifier
        detector_batch_size (int, optional): batch size for detection
        detection_confidence_threshold (float, optional): confidence threshold for detections
            (to include in the output file)
        detector_worker_threads (int, optional): number of workers to use for preprocessing
        skip_images (bool, optional): ignore images, only process videos
        skip_video (bool, optional): ignore videos, only process images
        frame_sample (int, optional): sample every Nth frame from videos
        time_sample (float, optional): sample frames every N seconds from videos
    """

    print('Starting detection step...')

    # Validate arguments
    assert not (frame_sample is None and time_sample is None), \
        'Must specify either frame_sample or time_sample'

    # Find image and video files
    if not skip_images:
        image_files = path_utils.find_images(source_folder, recursive=True,
                                             return_relative_paths=False)
    else:
        image_files = []

    if not skip_video:
        video_files = find_videos(source_folder, recursive=True,
                                  return_relative_paths=False)
    else:
        video_files = []

    if len(image_files) == 0 and len(video_files) == 0:
        raise ValueError(
            'No images or videos found in {}'.format(source_folder))

    print('Found {} images and {} videos'.format(len(image_files), len(video_files)))

    files_to_merge = []

    # Process images if any
    if len(image_files) > 0:
        print('Running MegaDetector on {} images...'.format(len(image_files)))

        image_results = load_and_run_detector_batch(
            model_file=detector_model,
            image_file_names=image_files,
            checkpoint_path=None,
            confidence_threshold=detection_confidence_threshold,
            checkpoint_frequency=-1,
            results=None,
            n_cores=0,
            use_image_queue=True,
            quiet=True,
            image_size=None,
            batch_size=detector_batch_size,
            include_image_size=False,
            include_image_timestamp=False,
            include_exif_data=False,
            loader_workers=detector_worker_threads,
            preprocess_on_image_queue=True
        )

        # Write image results to a temporary file
        image_output_file = detector_output_file.replace('.json', '_images.json')
        write_results_to_file(image_results,
                              image_output_file,
                              relative_path_base=source_folder,
                              detector_file=detector_model)

        print('Image detection results written to {}'.format(image_output_file))
        files_to_merge.append(image_output_file)

    # Process videos if any
    if len(video_files) > 0:
        print('Running MegaDetector on {} videos...'.format(len(video_files)))

        # Set up video processing options
        video_options = ProcessVideoOptions()
        video_options.model_file = detector_model
        video_options.input_video_file = source_folder
        video_options.output_json_file = detector_output_file.replace('.json', '_videos.json')
        video_options.json_confidence_threshold = detection_confidence_threshold
        video_options.frame_sample = frame_sample
        video_options.time_sample = time_sample

        # Process videos
        process_videos(video_options)

        print('Video detection results written to {}'.format(video_options.output_json_file))
        files_to_merge.append(video_options.output_json_file)

    # Merge results if we have both images and videos
    if len(files_to_merge) > 1:
        print('Merging image and video detection results...')
        combine_batch_output_files(files_to_merge, detector_output_file)
        print('Merged detection results written to {}'.format(detector_output_file))
    elif len(files_to_merge) == 1:
        # Just rename the single file
        if files_to_merge[0] != detector_output_file:
            os.rename(files_to_merge[0], detector_output_file)
        print('Detection results written to {}'.format(detector_output_file))

# ...def _run_detection_step(...)


def _run_classification_step(detector_results_file: str,
                             merged_results_file: str,
                             source_folder: str,
                             classifier_model: str = DEFAULT_CLASSIFIER_MODEL,
                             classifier_batch_size: int = DEFAULT_CLASSIFIER_BATCH_SIZE,
                             classifier_worker_threads: int = DEFAULT_LOADER_WORKERS,
                             detection_confidence_threshold: float =
                                 DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                             enable_rollup: bool = True,
                             country: str = None,
                             admin1_region: str = None,
                             top_n_scores: int = DEFAULT_TOP_N_SCORES):
    """
    Run SpeciesNet classification on detections from MegaDetector results.

    Args:
        detector_results_file (str): path to MegaDetector output .json file
        merged_results_file (str): path to which we should write the merged results
        source_folder (str): source folder for resolving relative paths
        classifier_model (str, optional): classifier model identifier
        classifier_batch_size (int, optional): batch size for classification
        classifier_worker_threads (int, optional): number of worker threads
        detection_confidence_threshold (float, optional): classify detections above this threshold
        enable_rollup (bool, optional): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region (typically a state code) for geofencing
        top_n_scores (int, optional): maximum number of scores to include for each detection
    """

    print('Starting SpeciesNet classification step...')

    # Load MegaDetector results
    with open(detector_results_file, 'r') as f:
        detector_results = json.load(f)

    print('Classification step loaded detection results for {} images'.format(
        len(detector_results['images'])))

    images = detector_results['images']
    if len(images) == 0:
        raise ValueError('No images found in detector results')

    print('Using SpeciesNet classifier: {}'.format(classifier_model))

    # Set multiprocessing start method to 'spawn' for CUDA compatibility
    original_start_method = multiprocessing.get_start_method()
    if original_start_method != 'spawn':
        multiprocessing.set_start_method('spawn', force=True)
        print('Set multiprocessing start method to spawn (was {})'.format(
            original_start_method))

    # Set up multiprocessing queues
    max_queue_size = classifier_worker_threads * MAX_QUEUE_SIZE_IMAGES_PER_WORKER
    image_queue = JoinableQueue(max_queue_size)
    batch_queue = Queue()
    results_queue = Queue()

    # Start producer workers
    producers = []
    for i_worker in range(classifier_worker_threads):
        p = Process(target=_crop_producer_func,
                    args=(image_queue, batch_queue, classifier_model,
                          detection_confidence_threshold, source_folder, i_worker))
        p.start()
        producers.append(p)

    # Start consumer worker
    consumer = Process(target=_crop_consumer_func,
                       args=(batch_queue, results_queue, classifier_model,
                             classifier_batch_size, classifier_worker_threads,
                             enable_rollup, country, admin1_region))
    consumer.start()

    # This will block every time the queue reaches its maximum depth, so for
    # very small jobs, this will not be a useful progress bar.
    with tqdm(total=len(images)) as pbar:
        for image_data in images:
            image_queue.put(image_data)
            pbar.update()

    # Send sentinel signals to producers
    for _ in range(classifier_worker_threads):
        image_queue.put(None)

    # Wait for all work to complete
    image_queue.join()

    print('Finished waiting for input queue')

    # Wait for results
    classification_results = results_queue.get()

    # Clean up processes
    for p in producers:
        p.join()
    consumer.join()

    print('Finished waiting for workers')

    class CategoryState:
        """
        Helper class to manage classification category IDs.
        """

        def __init__(self):

            self.next_category_id = 0

            # Maps common names to string-int IDs
            self.common_name_to_id = {}

            # Maps string-ints to common names, as per the format standard
            self.classification_categories = {}

            # Maps string-ints to Latin taxonomy strings, as per the format standard
            self.classification_category_descriptions = {}

        def _get_category_id(self, class_name):
            """
            Get an integer-valued category ID for the 7-token string [class_name],
            creating a new one if necessary.
            """

            # E.g.:
            #
            # "cb553c4e-42c9-4fe0-9bd0-da2d6ed5bfa1;mammalia;carnivora;canidae;urocyon;littoralis;island fox"
            tokens = class_name.split(';')
            assert len(tokens) == 7
            taxonomy_string = ';'.join(tokens[1:6])
            common_name = tokens[6]
            if len(common_name) == 0:
                common_name = taxonomy_string

            if common_name not in self.common_name_to_id:
                self.common_name_to_id[common_name] = str(self.next_category_id)
                self.classification_categories[str(self.next_category_id)] = common_name
                self.classification_category_descriptions[str(self.next_category_id)] = taxonomy_string
                self.next_category_id += 1

            category_id = self.common_name_to_id[common_name]

            return category_id

    # ...class CategoryState

    category_state = CategoryState()

    # Merge classification results back into detector results with proper category IDs
    for image_data in images:

        image_file = image_data['file']

        if ('detections' not in image_data) or (image_data['detections'] is None):
            continue

        detections = image_data['detections']

        if image_file not in classification_results:
            continue

        image_classifications = classification_results[image_file]

        for detection_index, detection in enumerate(detections):

            if detection_index in image_classifications:

                result = image_classifications[detection_index]

                if 'failure' in result:
                    # Add the failure to the image, not the detection
                    if 'failure' not in image_data:
                        image_data['failure'] = result['failure']
                    else:
                        image_data['failure'] += ';' + result['failure']
                else:

                    # Convert class names to category IDs
                    classification_pairs = []
                    raw_classification_pairs = []

                    scores = [x[1] for x in result['classifications']]
                    assert is_list_sorted(scores, reverse=True)

                    # Only report the requested number of scores per detection
                    if len(result['classifications']) > top_n_scores:
                        result['classifications'] = \
                            result['classifications'][0:top_n_scores]

                    if len(result['raw_classifications']) > top_n_scores:
                        result['raw_classifications'] = \
                            result['raw_classifications'][0:top_n_scores]

                    for class_name, score in result['classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        classification_pairs.append([category_id, score])

                    for class_name, score in result['raw_classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        raw_classification_pairs.append([category_id, score])

                    # Add classifications to the detection
                    detection['classifications'] = classification_pairs
                    detection['raw_classifications'] = raw_classification_pairs

                # ...if this classification contains a failure

            # ...if this detection has classification information

        # ...for each detection

    # ...for each image

    # Update metadata in the results
    if 'info' not in detector_results:
        detector_results['info'] = {}

    detector_results['info']['classifier'] = classifier_model
    detector_results['info']['classification_completion_time'] = time.strftime(
        '%Y-%m-%d %H:%M:%S')

    # Add classification category mapping
    detector_results['classification_categories'] = \
        category_state.classification_categories
    detector_results['classification_category_descriptions'] = \
        category_state.classification_category_descriptions

    # Write results
    write_json(merged_results_file, detector_results)

    if verbose:
        print('Classification results written to {}'.format(merged_results_file))

# ...def _run_classification_step(...)


#%% Command-line driver

def main():
    """
    Command-line driver for run_md_and_speciesnet.py
    """

    parser = argparse.ArgumentParser(
        description='Run MegaDetector and SpeciesNet on a folder of images/videos',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required arguments
    parser.add_argument('source',
                        help='Folder containing images and/or videos to process')
    parser.add_argument('output_file',
                        help='Output file for results (JSON format)')

    # Optional arguments
    parser.add_argument('--detector_model',
                        default=DEFAULT_DETECTOR_MODEL,
                        help='MegaDetector model identifier')
    parser.add_argument('--classification_model',
                        default=DEFAULT_CLASSIFIER_MODEL,
                        help='SpeciesNet classifier model identifier')
    parser.add_argument('--detector_batch_size',
                        type=int,
                        default=DEFAULT_DETECTOR_BATCH_SIZE,
                        help='Batch size for MegaDetector inference')
    parser.add_argument('--classifier_batch_size',
                        type=int,
                        default=DEFAULT_CLASSIFIER_BATCH_SIZE,
                        help='Batch size for SpeciesNet classification')
    parser.add_argument('--loader_workers',
                        type=int,
                        default=DEFAULT_LOADER_WORKERS,
                        help='Number of worker threads for preprocessing')
    parser.add_argument('--detection_confidence_threshold_for_classification',
                        type=float,
                        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                        help='Classify detections above this threshold')
    parser.add_argument('--detection_confidence_threshold_for_output',
                        type=float,
                        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT,
                        help='Include detections above this threshold in the output')
    parser.add_argument('--intermediate_file_folder',
                        default=None,
                        help='Folder for intermediate files (default: system temp)')
    parser.add_argument('--keep_intermediate_files',
                        action='store_true',
                        help='Keep intermediate files for debugging')
    parser.add_argument('--norollup',
                        action='store_true',
                        help='Disable taxonomic rollup')
    parser.add_argument('--country',
                        default=None,
                        help='Country code (ISO 3166-1 alpha-3) for geofencing')
    parser.add_argument('--admin1_region', '--state',
                        default=None,
                        help='Admin1 region/state code for geofencing')
    parser.add_argument('--detections_file',
                        default=None,
                        help='Path to existing MegaDetector output file (skips detection step)')
    parser.add_argument('--skip_video',
                        action='store_true',
                        help='Ignore videos, only process images')
    parser.add_argument('--skip_images',
                        action='store_true',
                        help='Ignore images, only process videos')
    parser.add_argument('--frame_sample',
                        type=int,
                        default=None,
                        help='Sample every Nth frame from videos (mutually exclusive with --time_sample)')
    parser.add_argument('--time_sample',
                        type=float,
                        default=None,
                        help='Sample frames every N seconds from videos (default {})'.\
                            format(DEFAULT_SECONDS_PER_VIDEO_FRAME) + \
                            ' (mutually exclusive with --frame_sample)')
    parser.add_argument('--verbose',
                        action='store_true',
                        help='Enable additional debug output')

    if len(sys.argv[1:]) == 0:
        parser.print_help()
        parser.exit()

    args = parser.parse_args()

    # Set global verbose flag
    global verbose
    verbose = args.verbose

    # Also set the run_detector_batch verbose flag
    run_detector_batch.verbose = verbose

    # Validate arguments
    if not os.path.isdir(args.source):
        raise ValueError(
            'Source folder does not exist: {}'.format(args.source))

    if args.admin1_region and not args.country:
        raise ValueError('--admin1_region requires --country to be specified')

    if args.skip_images and args.skip_video:
        raise ValueError('Cannot skip both images and videos')

    if (args.frame_sample is not None) and (args.time_sample is not None):
        raise ValueError('--frame_sample and --time_sample are mutually exclusive')
    if (args.frame_sample is None) and (args.time_sample is None):
        args.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME

    # Set up intermediate file folder
    if args.intermediate_file_folder:
        temp_folder = args.intermediate_file_folder
        os.makedirs(temp_folder, exist_ok=True)
    else:
        temp_folder = make_temp_folder(subfolder='run_md_and_speciesnet')

    start_time = time.time()

    print('Processing folder: {}'.format(args.source))
    print('Output file: {}'.format(args.output_file))
    print('Intermediate files: {}'.format(temp_folder))

    # Determine detector output file path
    if args.detections_file:
        detector_output_file = args.detections_file
        print('Using existing detections file: {}'.format(detector_output_file))
        validation_options = ValidateBatchResultsOptions()
        validation_options.check_image_existence = True
        validation_options.relative_path_base = args.source
        validation_options.raise_errors = True
        validate_batch_results(detector_output_file, options=validation_options)
        print('Validated detections file')
    else:
        detector_output_file = os.path.join(temp_folder, 'detector_output.json')

        # Run MegaDetector
        _run_detection_step(
            source_folder=args.source,
            detector_output_file=detector_output_file,
            detector_model=args.detector_model,
            detector_batch_size=args.detector_batch_size,
            detection_confidence_threshold=args.detection_confidence_threshold_for_output,
            detector_worker_threads=args.loader_workers,
            skip_images=args.skip_images,
            skip_video=args.skip_video,
            frame_sample=args.frame_sample,
            time_sample=args.time_sample
        )

    # Run SpeciesNet
    _run_classification_step(
        detector_results_file=detector_output_file,
        merged_results_file=args.output_file,
        source_folder=args.source,
        classifier_model=args.classification_model,
        classifier_batch_size=args.classifier_batch_size,
        classifier_worker_threads=args.loader_workers,
        detection_confidence_threshold=args.detection_confidence_threshold_for_classification,
        enable_rollup=(not args.norollup),
        country=args.country,
        admin1_region=args.admin1_region,
    )

    elapsed_time = time.time() - start_time
    print(
        'Processing complete in {}'.format(humanfriendly.format_timespan(elapsed_time)))
    print('Results written to: {}'.format(args.output_file))

    # Clean up intermediate files if requested
    if (not args.keep_intermediate_files) and \
       (not args.intermediate_file_folder) and \
       (not args.detections_file):
        try:
            os.remove(detector_output_file)
        except Exception as e:
            print('Warning: error removing temporary output file {}: {}'.format(
                detector_output_file, str(e)))

# ...def main(...)


if __name__ == '__main__':
    main()
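
For reference, here is a minimal sketch (not part of the package diff) of consuming the merged output file this script writes. The input file name is hypothetical; the keys follow the code above: `classification_categories` maps string-valued integer IDs to common names, and each detection carries `[category_id, score]` pairs sorted by descending score.

import json

# Hypothetical output file produced by run_md_and_speciesnet.py
with open('md_speciesnet_results.json', 'r') as f:
    results = json.load(f)

# Maps string-valued integer category IDs to common names
category_id_to_name = results['classification_categories']

for im in results['images']:
    for det in im.get('detections') or []:
        # Each pair is [category_id, score]; the first pair is the top
        # prediction, possibly after geofencing and taxonomic rollup
        for category_id, score in det.get('classifications', []):
            print('{}: {} ({:.3f})'.format(
                im['file'], category_id_to_name[category_id], score))

The full Latin taxonomy string for each category is available under the same ID in `classification_category_descriptions`.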