megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (147) hide show
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1494 @@
1
+ """
2
+
3
+ run_md_and_speciesnet.py
4
+
5
+ Script to run MegaDetector and SpeciesNet on a folder of images and/or videos.
6
+ Runs MD first, then runs SpeciesNet on every above-threshold crop.
7
+
8
+ """
9
+
10
+ #%% Constants, imports, environment
11
+
12
+ import argparse
13
+ import json
14
+ import multiprocessing
15
+ import os
16
+ import sys
17
+ import time
18
+
19
+ from tqdm import tqdm
20
+ from multiprocessing import JoinableQueue, Process, Queue
21
+
22
+ import humanfriendly
23
+
24
+ from megadetector.detection import run_detector_batch
25
+ from megadetector.detection.video_utils import find_videos, run_callback_on_frames, is_video_file
26
+ from megadetector.detection.run_detector_batch import load_and_run_detector_batch
27
+ from megadetector.detection.run_detector import DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
28
+ from megadetector.detection.run_detector import CONF_DIGITS
29
+ from megadetector.detection.run_detector_batch import write_results_to_file
30
+ from megadetector.utils.ct_utils import round_float
31
+ from megadetector.utils.ct_utils import write_json
32
+ from megadetector.utils.ct_utils import make_temp_folder
33
+ from megadetector.utils.ct_utils import is_list_sorted
34
+ from megadetector.utils.ct_utils import is_sphinx_build
35
+ from megadetector.utils.ct_utils import args_to_object
36
+ from megadetector.utils import path_utils
37
+ from megadetector.visualization import visualization_utils as vis_utils
38
+ from megadetector.postprocessing.validate_batch_results import \
39
+ validate_batch_results, ValidateBatchResultsOptions
40
+ from megadetector.detection.process_video import \
41
+ process_videos, ProcessVideoOptions
42
+ from megadetector.postprocessing.combine_batch_outputs import combine_batch_output_files
43
+
44
+ # We aren't taking an explicit dependency on the speciesnet package yet,
45
+ # so we wrap this in a try/except so sphinx can still document this module.
46
+ try:
47
+ from speciesnet import SpeciesNetClassifier
48
+ from speciesnet.utils import BBox
49
+ from speciesnet.ensemble import SpeciesNetEnsemble
50
+ from speciesnet.geofence_utils import roll_up_labels_to_first_matching_level
51
+ from speciesnet.geofence_utils import geofence_animal_classification
52
+ except Exception:
53
+ pass
54
+
55
+
56
#%% Constants

DEFAULT_DETECTOR_MODEL = 'MDV5A'
DEFAULT_CLASSIFIER_MODEL = 'kaggle:google/speciesnet/pyTorch/v4.0.1a'
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION = 0.1
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
DEFAULT_DETECTOR_BATCH_SIZE = 1
DEFAULT_CLASSIFIER_BATCH_SIZE = 8
DEFAULT_LOADER_WORKERS = 4

# This determines the maximum number of per-file detection results (images or
# videos) that can be queued for each of the producer workers before blocking.
# The actual size of the queue will be MAX_IMAGE_QUEUE_SIZE_PER_WORKER * n_workers.
# This is only used for the classification step.
MAX_IMAGE_QUEUE_SIZE_PER_WORKER = 30

# This determines the maximum number of crops that can accumulate in the queue
# used to communicate between the producers (which read and crop images) and the
# consumer (which runs the classifier). This is only used for the classification step.
MAX_BATCH_QUEUE_SIZE = 300

# Default interval between frames we should process when processing video.
# This is only used for the detection step.
DEFAULT_SECONDS_PER_VIDEO_FRAME = 1.0

# Max number of classification scores to include per detection
DEFAULT_TOP_N_SCORES = 2

# Unless --norollup is specified, roll up taxonomic levels until the
# cumulative confidence is above this value
ROLLUP_TARGET_CONFIDENCE = 0.5

# When the caller supplies an existing MD results file, should we validate it
# before starting classification?  Disabled by default.
VALIDATE_DETECTION_FILE = False


# Module-level verbosity flag; the worker functions in this module check it
# before printing progress messages
verbose = False
94
+
95
+
96
+ #%% Support classes
97
+
98
class CropMetadata:
    """
    Metadata for a crop extracted from an image (or video frame) detection.

    Producers also use this class to describe failures: a detection_index of -1
    (with an empty bbox and zero width/height) means the whole file could not
    be processed, rather than one specific detection.
    """

    def __init__(self,
                 image_file: str,
                 detection_index: int,
                 bbox: list[float],
                 original_width: int,
                 original_height: int):
        """
        Args:
            image_file (str): path to the original image file
            detection_index (int): index of this detection in the file's full
                detection list, or -1 for a whole-file failure
            bbox (list[float]): normalized bounding box [x_min, y_min, width, height]
            original_width (int): width of the original image
            original_height (int): height of the original image
        """

        self.image_file = image_file
        self.detection_index = detection_index
        self.bbox = bbox
        self.original_width = original_width
        self.original_height = original_height

    def __repr__(self):
        # Makes queue contents and failure records readable when debugging
        return ('{}(image_file={!r}, detection_index={!r}, bbox={!r}, '
                'original_width={!r}, original_height={!r})').format(
                    type(self).__name__, self.image_file, self.detection_index,
                    self.bbox, self.original_width, self.original_height)
+ self.original_height = original_height
123
+
124
+
125
class CropBatch:
    """
    Accumulates preprocessed crops, together with the CropMetadata describing
    where each one came from, until there are enough of them to run through
    the classifier as a single batch.
    """

    def __init__(self):

        #: Preprocessed crops, in the order they were added
        self.crops = list()

        #: CropMetadata objects; metadata[i] describes crops[i]
        self.metadata = list()

    def add_crop(self, crop_data, metadata):
        """
        Append one crop and its metadata to this batch.

        Args:
            crop_data (PreprocessedImage): output of SpeciesNetClassifier.preprocess()
            metadata (CropMetadata): where this crop came from
        """

        self.crops += [crop_data]
        self.metadata += [metadata]

    def __len__(self):
        # The batch size is the number of crops collected so far
        crop_count = len(self.crops)
        return crop_count
150
+
151
+
152
+ #%% Support functions for classification
153
+
154
def _process_image_detections(file_path: str,
                              absolute_file_path: str,
                              detection_results: dict,
                              classifier: 'SpeciesNetClassifier',
                              detection_confidence_threshold: float,
                              batch_queue: Queue):
    """
    Crop and preprocess every above-threshold detection in a single image.

    Each successfully preprocessed crop is sent to [batch_queue] as a
    ('crop', preprocessed_crop, CropMetadata) tuple. Problems are reported as
    ('failure', message, CropMetadata) tuples:
    - if the image cannot be loaded, one failure is sent with detection_index -1
      and an empty bbox;
    - if a single crop cannot be preprocessed, a failure is sent for that
      detection only.
    Images with no detection at or above the threshold are not loaded at all.

    Args:
        file_path (str): relative path to the image file; used as the image_file
            value in every message sent to [batch_queue]
        absolute_file_path (str): absolute path used to load the image
        detection_results (dict): detection results for this image; only the
            'detections' list (entries with 'conf' and 'bbox') is read
        classifier (SpeciesNetClassifier): classifier instance, used only for its
            preprocess() method
        detection_confidence_threshold (float): detections with confidence >= this
            value are classified
        batch_queue (Queue): queue to send crops and failures to
    """

    all_detections = detection_results['detections']

    def _report_failure(message, index, box, width, height):
        # Tell the consumer that one detection (or, with index -1, the whole
        # image) could not be turned into a crop
        failure_info = CropMetadata(image_file=file_path,
                                    detection_index=index,
                                    bbox=box,
                                    original_width=width,
                                    original_height=height)
        batch_queue.put(('failure', message, failure_info))

    # Loading the image is the expensive part, so bail out early when no
    # detection clears the threshold
    n_above_threshold = sum(
        1 for det in all_detections if det['conf'] >= detection_confidence_threshold)
    if n_above_threshold == 0:
        return

    try:
        pil_image = vis_utils.load_image(absolute_file_path)
        image_width, image_height = pil_image.size
    except Exception as load_error:
        print('Warning: failed to load image {}: {}'.format(file_path, str(load_error)))
        _report_failure('Failed to load image: {}'.format(str(load_error)), -1, [], 0, 0)
        return

    # Walk the *full* detection list so each index matches the detection's
    # position in detection_results['detections']; that position is how
    # classification results get attached to detections later on.
    for i_detection, det in enumerate(all_detections):

        if det['conf'] < detection_confidence_threshold:
            continue

        box = det['bbox']
        assert len(box) == 4

        # MD boxes are normalized [x_min, y_min, width, height]
        crop_region = BBox(xmin=box[0], ymin=box[1], width=box[2], height=box[3])

        try:
            crop = classifier.preprocess(pil_image, bboxes=[crop_region], resize=True)
            # A None result from preprocess() is skipped without a failure record
            if crop is not None:
                crop_info = CropMetadata(image_file=file_path,
                                         detection_index=i_detection,
                                         bbox=box,
                                         original_width=image_width,
                                         original_height=image_height)
                batch_queue.put(('crop', crop, crop_info))
        except Exception as crop_error:
            print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
                file_path, i_detection, str(crop_error)))
            _report_failure('Failed to preprocess crop: {}'.format(str(crop_error)),
                            i_detection, box, image_width, image_height)

    # ...def _process_image_detections(...)
267
+
268
+
269
def _process_video_detections(file_path: str,
                              absolute_file_path: str,
                              detection_results: dict,
                              classifier: 'SpeciesNetClassifier',
                              detection_confidence_threshold: float,
                              batch_queue: Queue):
    """
    Process detections from a single video.

    Only frames containing at least one detection at or above the threshold are
    decoded. Each such detection is cropped and preprocessed, then sent to
    [batch_queue] as a ('crop', preprocessed_crop, CropMetadata) tuple. A
    detection whose crop cannot be preprocessed is reported as a
    ('failure', message, CropMetadata) tuple. If frame extraction itself raises,
    a single failure with detection_index -1 is sent for the whole video; crops
    already sent for earlier frames are not retracted.

    Args:
        file_path (str): relative path to the video file; used as the image_file
            value in every message sent to [batch_queue]
        absolute_file_path (str): absolute path to the video file
        detection_results (dict): detection results for this video; each entry in
            'detections' is read for 'conf', 'frame_number', and 'bbox'
        classifier (SpeciesNetClassifier): classifier instance for preprocessing
        detection_confidence_threshold (float): classify detections with
            confidence >= this threshold
        batch_queue (Queue): queue to send crops and failures to
    """

    import re

    detections = detection_results['detections']

    # Map each frame number to the (detection_index, detection) pairs that clear
    # the threshold in that frame. detection_index is the position in the full
    # detection list (not the filtered one); that position is how classification
    # results are associated with detections later.
    frame_to_detections = {}

    for detection_index, detection in enumerate(detections):

        if detection['conf'] < detection_confidence_threshold:
            continue

        frame_number = detection['frame_number']
        frame_to_detections.setdefault(frame_number, []).append(
            (detection_index, detection))

    # ...for each detection in this video

    if len(frame_to_detections) == 0:
        return

    frames_to_process = sorted(frame_to_detections)

    # Frame identifiers look like "frame0006.jpg"; compile the pattern once per
    # video rather than re-parsing it for every frame
    frame_id_pattern = re.compile(r'frame(\d+)\.jpg')

    def frame_callback(frame_array, frame_id):
        """
        Callback to process a single decoded frame.

        Args:
            frame_array (numpy.ndarray): frame pixel data; converted to a PIL
                image here before cropping
            frame_id (str): frame identifier like "frame0006.jpg"
        """

        # Extract frame number from frame_id (e.g., "frame0006.jpg" -> 6)
        match = frame_id_pattern.match(frame_id)
        if not match:
            print('Warning: could not parse frame number from {}'.format(frame_id))
            return
        frame_number = int(match.group(1))

        # Only process frames for which we have detection results
        if frame_number not in frame_to_detections:
            return

        # Convert numpy array to PIL Image
        from PIL import Image

        # NOTE(review): non-uint8 frames are assumed to hold floats in [0, 1];
        # confirm against what run_callback_on_frames supplies
        if frame_array.dtype != 'uint8':
            frame_array = (frame_array * 255).astype('uint8')
        frame_image = Image.fromarray(frame_array)
        original_width, original_height = frame_image.size

        # Process each detection in this frame
        for detection_index, detection in frame_to_detections[frame_number]:

            bbox = detection['bbox']
            assert len(bbox) == 4

            # Convert normalized bbox to BBox object for SpeciesNet
            speciesnet_bbox = BBox(
                xmin=bbox[0],
                ymin=bbox[1],
                width=bbox[2],
                height=bbox[3]
            )

            # Preprocess the crop
            try:

                preprocessed_crop = classifier.preprocess(
                    frame_image,
                    bboxes=[speciesnet_bbox],
                    resize=True
                )

                if preprocessed_crop is not None:
                    metadata = CropMetadata(
                        image_file=file_path,
                        detection_index=detection_index,
                        bbox=bbox,
                        original_width=original_width,
                        original_height=original_height
                    )

                    # Send individual crop immediately to consumer
                    batch_queue.put(('crop', preprocessed_crop, metadata))

            except Exception as e:

                print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
                    file_path, detection_index, str(e)))

                # Send failure information to consumer
                failure_metadata = CropMetadata(
                    image_file=file_path,
                    detection_index=detection_index,
                    bbox=bbox,
                    original_width=original_width,
                    original_height=original_height
                )
                batch_queue.put(('failure',
                                 'Failed to preprocess crop: {}'.format(str(e)),
                                 failure_metadata))

            # ...try/except

        # ...for each detection

    # ...def frame_callback(...)

    # Process the video frames
    try:

        run_callback_on_frames(
            input_video_file=absolute_file_path,
            frame_callback=frame_callback,
            frames_to_process=frames_to_process,
            verbose=verbose
        )

    except Exception as e:

        print('Warning: failed to process video {}: {}'.format(file_path, str(e)))

        # Send failure information to consumer for the whole video
        failure_metadata = CropMetadata(
            image_file=file_path,
            detection_index=-1,  # -1 indicates whole-file failure
            bbox=[],
            original_width=0,
            original_height=0
        )
        batch_queue.put(('failure',
                         'Failed to process video: {}'.format(str(e)),
                         failure_metadata))
    # ...try/except

    # ...def _process_video_detections(...)
428
+
429
+
430
def _crop_producer_process_item(detection_results: dict,
                                classifier: 'SpeciesNetClassifier',
                                detection_confidence_threshold: float,
                                source_folder: str,
                                batch_queue: Queue):
    """
    Send crops for one file's detection results to [batch_queue]; helper for
    _crop_producer_func. Files that failed at the detection stage, or that have
    no detections, are skipped.
    """

    file_path = detection_results['file']

    # Skip files that failed at the detection stage
    if 'failure' in detection_results:
        return

    # Skip files with no detections
    detections = detection_results['detections']
    if len(detections) == 0:
        return

    # Determine if this is an image or video
    absolute_file_path = os.path.join(source_folder, file_path)
    if is_video_file(file_path):
        process_function = _process_video_detections
    else:
        process_function = _process_image_detections

    process_function(
        file_path=file_path,
        absolute_file_path=absolute_file_path,
        detection_results=detection_results,
        classifier=classifier,
        detection_confidence_threshold=detection_confidence_threshold,
        batch_queue=batch_queue
    )

    # ...def _crop_producer_process_item(...)


def _crop_producer_func(image_queue: JoinableQueue,
                        batch_queue: Queue,
                        classifier_model: str,
                        detection_confidence_threshold: float,
                        source_folder: str,
                        producer_id: int = -1):
    """
    Producer function for classification workers.

    Reads per-file detection results from [image_queue], crops detections above
    a threshold, preprocesses them, and sends individual crops to [batch_queue].
    See the documentation of _crop_consumer_func for the format of the tuples
    placed on batch_queue.

    Every item taken from [image_queue] is marked done with task_done(). Once
    the work loop is entered, a None sentinel is always placed on [batch_queue]
    when it exits, even if an unexpected error occurs; without it the consumer
    would wait forever for this producer. An unexpected error while handling one
    file is reported to the consumer as a whole-file failure (detection_index
    -1) instead of killing this worker. A failure to load the classifier is
    still deliberately fatal.

    Args:
        image_queue (JoinableQueue): queue containing detection_results dicts (for
            both images and videos); None tells this producer to stop
        batch_queue (Queue): queue to put individual crops into
        classifier_model (str): classifier model identifier to load in this process
        detection_confidence_threshold (float): classify detections above this threshold
        source_folder (str): source folder to resolve relative paths
        producer_id (int, optional): identifier for this producer worker
    """

    if verbose:
        print('Classification producer starting: ID {}'.format(producer_id))

    # Load classifier; this is just being used as a preprocessor, so we force device=cpu.
    #
    # There are a number of reasons loading the model might fail; note to self: *don't*
    # catch Exceptions here. This should be a catastrophic failure that stops the whole
    # process.
    classifier = SpeciesNetClassifier(classifier_model, device='cpu')
    if verbose:
        print('Classification producer {}: loaded classifier'.format(producer_id))

    try:

        while True:

            # Pull an image of detection results from the queue
            detection_results = image_queue.get()

            # Pulling None from the queue indicates that this producer is done
            if detection_results is None:
                image_queue.task_done()
                break

            try:

                _crop_producer_process_item(
                    detection_results=detection_results,
                    classifier=classifier,
                    detection_confidence_threshold=detection_confidence_threshold,
                    source_folder=source_folder,
                    batch_queue=batch_queue
                )

            except Exception as e:

                # Malformed results (e.g. missing keys or a bad bbox) should not
                # take down this worker; report them like other per-file failures
                file_path = None
                if isinstance(detection_results, dict):
                    file_path = detection_results.get('file')
                print('Warning: failed to process {}: {}'.format(file_path, str(e)))

                if file_path is not None:
                    failure_metadata = CropMetadata(
                        image_file=file_path,
                        detection_index=-1,  # -1 indicates whole-file failure
                        bbox=[],
                        original_width=0,
                        original_height=0
                    )
                    batch_queue.put(('failure',
                                     'Failed to process file: {}'.format(str(e)),
                                     failure_metadata))

            finally:

                # Always acknowledge the item, or joining image_queue would hang
                image_queue.task_done()

        # ...while(we still have items to process)

    finally:

        # Send sentinel to indicate this producer is done
        batch_queue.put(None)

    if verbose:
        print('Classification producer {} finished'.format(producer_id))

    # ...def _crop_producer_func(...)
527
+
528
+
529
def _crop_consumer_func(batch_queue: Queue,
                        results_queue: Queue,
                        classifier_model: str,
                        batch_size: int,
                        num_producers: int,
                        enable_rollup: bool,
                        country: str = None,
                        admin1_region: str = None):
    """
    Consumer function for classification inference.

    Pulls individual crops from batch_queue, assembles them into batches,
    runs inference, and puts results into results_queue. Runs until it has
    received one None sentinel from each of [num_producers] producers. With
    num_producers <= 0 it does not read from batch_queue at all.

    Args:
        batch_queue (Queue): queue containing individual crop tuples or failures.
            Items on this queue are either None (to indicate that a producer finished)
            or tuples formatted as (type,image,metadata). [type] is a string (either
            "crop" or "failure"), [image] is a PreprocessedImage (or, for failures,
            an error message), and [metadata] is a CropMetadata object.
        results_queue (Queue): queue to put classification results into; receives a
            single dict (image_file -> {detection_index -> result}), or {} if the
            classifier could not be loaded
        classifier_model (str): classifier model identifier to load
        batch_size (int): batch size for inference
        num_producers (int): number of producer workers
        enable_rollup (bool): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
    """

    if verbose:
        print('Classification consumer starting')

    # Load classifier
    try:
        classifier = SpeciesNetClassifier(classifier_model)
        if verbose:
            print('Classification consumer: loaded classifier')
    except Exception as e:
        print('Classification consumer: failed to load classifier: {}'.format(str(e)))
        results_queue.put({})
        return

    all_results = {}  # image_file -> {detection_index -> classification_result}
    current_batch = CropBatch()
    producers_finished = 0

    # Load ensemble metadata if rollup/geofencing is enabled
    taxonomy_map = {}
    geofence_map = {}

    if enable_rollup or (country is not None):

        # Note to self: there are a number of reasons loading the ensemble
        # could fail here; don't catch this exception, this should be a
        # catastrophic failure.
        ensemble = SpeciesNetEnsemble(
            classifier_model, geofence=(country is not None))
        taxonomy_map = ensemble.taxonomy_map
        geofence_map = ensemble.geofence_map

    # ...if we need to load ensemble components

    def _classify_batch(batch):
        # Run one batch through the classifier, storing results in all_results
        _process_classification_batch(
            batch, classifier, all_results,
            enable_rollup, taxonomy_map, geofence_map,
            country, admin1_region
        )

    # Loop until every producer has sent its sentinel; written as a loop
    # condition (rather than while True) so zero producers cannot block forever
    while producers_finished < num_producers:

        # Pull an item from the queue
        item = batch_queue.get()

        # This indicates that a producer worker finished
        if item is None:
            producers_finished += 1
            continue

        # If we got here, we know we have a crop to process, or
        # a failure to record.
        assert isinstance(item, tuple) and len(item) == 3
        item_type, data, metadata = item

        if metadata.image_file not in all_results:
            all_results[metadata.image_file] = {}

        # We should never be processing the same detection twice
        assert metadata.detection_index not in all_results[metadata.image_file]

        if item_type == 'failure':

            all_results[metadata.image_file][metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(data)
            }

        else:

            assert item_type == 'crop'
            current_batch.add_crop(data, metadata)
            assert len(current_batch) <= batch_size

            # Process batch if necessary
            if len(current_batch) == batch_size:
                _classify_batch(current_batch)
                current_batch = CropBatch()

        # ...was this item a failure or a crop?

    # ...while (we have items to process)

    # Process any remaining crops that didn't fill a complete batch
    if len(current_batch) > 0:
        _classify_batch(current_batch)

    # Send all the results at once back to the main process
    results_queue.put(all_results)

    if verbose:
        print('Classification consumer finished')

    # ...def _crop_consumer_func(...)
656
+
657
+
658
def _process_classification_batch(batch: CropBatch,
                                  classifier: 'SpeciesNetClassifier',
                                  all_results: dict,
                                  enable_rollup: bool,
                                  taxonomy_map: dict,
                                  geofence_map: dict,
                                  country: str = None,
                                  admin1_region: str = None):
    """
    Classify every crop in [batch] and record the outcome for each crop in [all_results].

    The top classifier label for each crop may be replaced by a geofenced label (only
    attempted when [country] is truthy). If geofencing did not change the label and
    [enable_rollup] is True, the label may instead be replaced by a taxonomic rollup.

    Args:
        batch (CropBatch): crops to classify, plus one metadata record per crop
            (providing image_file and detection_index)
        classifier (SpeciesNetClassifier): classifier used for batch inference
        all_results (dict): results dict, modified in place. Afterwards,
            all_results[image_file][detection_index] is either
            {'classifications': [[class_name, score]],
             'raw_classifications': [[class_name, score], ...]}
            or {'failure': error_message}.
        enable_rollup (bool): whether to apply taxonomic rollup
        taxonomy_map (dict): taxonomy mapping for geofencing/rollup
        geofence_map (dict): geofence mapping
        country (str, optional): country code; geofencing only runs when this is truthy
        admin1_region (str, optional): admin1 region for geofencing
    """

    if not len(batch):
        print('Warning: _process_classification_batch received empty batch')
        return

    # The classifier identifies each crop by a pseudo-path: <image_file>_<detection_index>
    crop_ids = ['{}_{}'.format(crop_md.image_file, crop_md.detection_index)
                for crop_md in batch.metadata]

    try:
        batch_results = classifier.batch_predict(crop_ids, batch.crops)
    except Exception as e:
        print('*** Batch classification failed: {} ***'.format(str(e)))
        # Record the same failure message for every crop in the batch
        failure_text = 'Failure classification: {}'.format(str(e))
        for crop_md in batch.metadata:
            all_results.setdefault(crop_md.image_file, {})[crop_md.detection_index] = {
                'failure': failure_text
            }
        return

    assert len(batch_results) == len(batch.metadata)
    assert len(batch_results) == len(crop_ids)

    for i_crop, crop_result in enumerate(batch_results):

        crop_md = batch.metadata[i_crop]

        assert crop_md.image_file in all_results, \
            'File {} not in results dict'.format(crop_md.image_file)

        results_for_image = all_results[crop_md.image_file]
        det_idx = crop_md.detection_index

        # The classifier reports per-crop failures via a 'failures' key
        if 'failures' in crop_result:
            print('*** Classification failure for image: {} ***'.format(crop_ids[i_crop]))
            results_for_image[det_idx] = {
                'failure': 'Failure classification: SpeciesNet classifier failed'
            }
            continue

        # 'classifications' holds parallel lists under 'classes' and 'scores'
        crop_classifications = crop_result['classifications']
        labels = crop_classifications['classes']
        label_scores = crop_classifications['scores']

        top_label, top_score = labels[0], label_scores[0]
        geofence_changed_label = False

        if country:
            geofenced_label, geofenced_score, label_source = geofence_animal_classification(
                labels=labels,
                scores=label_scores,
                country=country,
                admin1_region=admin1_region,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=True
            )
            if label_source != 'classifier':
                geofence_changed_label = True
                top_label, top_score = geofenced_label, geofenced_score

        # Geofencing already performs rollup when it changes the label, so only roll up
        # when geofencing left the classifier's answer alone
        if enable_rollup and not geofence_changed_label:
            rollup_result = roll_up_labels_to_first_matching_level(
                labels=labels,
                scores=label_scores,
                country=country,
                admin1_region=admin1_region,
                target_taxonomy_levels=['species', 'genus', 'family', 'order', 'class', 'kingdom'],
                non_blank_threshold=ROLLUP_TARGET_CONFIDENCE,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=(country is not None)
            )
            if rollup_result is not None:
                rolled_label, rolled_score, _ = rollup_result
                if rolled_label != top_label:
                    top_label, top_score = rolled_label, rolled_score

        # Class names are kept as strings here; integer category IDs are assigned later,
        # just before results are written to disk
        results_for_image[det_idx] = {
            'classifications': [[top_label, top_score]],
            'raw_classifications': [[labels[i_label], label_scores[i_label]]
                                    for i_label in range(len(labels))]
        }

# ...def _process_classification_batch(...)
800
+
801
+
802
+ #%% Inference functions
803
+
804
def _run_detection_step(source_folder: str,
                        detector_output_file: str,
                        detector_model: str = DEFAULT_DETECTOR_MODEL,
                        detector_batch_size: int = DEFAULT_DETECTOR_BATCH_SIZE,
                        detection_confidence_threshold: float = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
                        detector_worker_threads: int = DEFAULT_LOADER_WORKERS,
                        skip_images: bool = False,
                        skip_video: bool = False,
                        frame_sample: int = None,
                        time_sample: float = None) -> str:
    """
    Run MegaDetector on all images/videos in [source_folder].

    Images and videos are processed separately, each writing an intermediate file next
    to [detector_output_file] (<base>_images<ext> / <base>_videos<ext>). If both exist
    they are merged into [detector_output_file]; if only one exists it is moved there.

    Args:
        source_folder (str): folder containing images/videos
        detector_output_file (str): output .json file
        detector_model (str, optional): detector model identifier
        detector_batch_size (int, optional): batch size for detection
        detection_confidence_threshold (float, optional): confidence threshold for detections
            (to include in the output file)
        detector_worker_threads (int, optional): number of workers to use for preprocessing
        skip_images (bool, optional): ignore images, only process videos
        skip_video (bool, optional): ignore videos, only process images
        frame_sample (int, optional): sample every Nth frame from videos
        time_sample (float, optional): sample frames every N seconds from videos

    Returns:
        str: [detector_output_file], the path to which detection results were written

    Raises:
        ValueError: if no images or videos were found in [source_folder]
    """

    print('Starting detection step...')

    # Validate arguments
    assert not (frame_sample is None and time_sample is None), \
        'Must specify either frame_sample or time_sample'

    # Derive intermediate file names from the extension only. Using
    # str.replace('.json', ...) would also rewrite ".json" in directory names, and
    # would leave the name unchanged (colliding with the final output file) when the
    # output path doesn't contain ".json".
    output_base, output_ext = os.path.splitext(detector_output_file)
    image_output_file = output_base + '_images' + output_ext
    video_output_file = output_base + '_videos' + output_ext

    # Find image and video files
    if not skip_images:
        image_files = path_utils.find_images(source_folder, recursive=True,
                                             return_relative_paths=False)
    else:
        image_files = []

    if not skip_video:
        video_files = find_videos(source_folder, recursive=True,
                                  return_relative_paths=False)
    else:
        video_files = []

    if len(image_files) == 0 and len(video_files) == 0:
        raise ValueError(
            'No images or videos found in {}'.format(source_folder))

    print('Found {} images and {} videos'.format(len(image_files), len(video_files)))

    files_to_merge = []

    # Process images if necessary
    if len(image_files) > 0:

        print('Running MegaDetector on {} images...'.format(len(image_files)))

        image_results = load_and_run_detector_batch(
            model_file=detector_model,
            image_file_names=image_files,
            checkpoint_path=None,
            confidence_threshold=detection_confidence_threshold,
            checkpoint_frequency=-1,
            results=None,
            n_cores=0,
            use_image_queue=True,
            quiet=True,
            image_size=None,
            batch_size=detector_batch_size,
            include_image_size=False,
            include_image_timestamp=False,
            include_exif_tags=None,
            loader_workers=detector_worker_threads,
            preprocess_on_image_queue=True
        )

        # Write image results to an intermediate file (paths relative to source_folder)
        write_results_to_file(image_results,
                              image_output_file,
                              relative_path_base=source_folder,
                              detector_file=detector_model)

        print('Image detection results written to {}'.format(image_output_file))
        files_to_merge.append(image_output_file)

    # ...if we had images to process

    # Process videos if necessary
    if len(video_files) > 0:

        print('Running MegaDetector on {} videos...'.format(len(video_files)))

        # Set up video processing options; the whole source folder is passed as input
        video_options = ProcessVideoOptions()
        video_options.model_file = detector_model
        video_options.input_video_file = source_folder
        video_options.output_json_file = video_output_file
        video_options.json_confidence_threshold = detection_confidence_threshold
        video_options.frame_sample = frame_sample
        video_options.time_sample = time_sample
        video_options.recursive = True

        # Process videos
        process_videos(video_options)

        print('Video detection results written to {}'.format(video_options.output_json_file))
        files_to_merge.append(video_options.output_json_file)

    # ...if we had videos to process

    # Merge results if we have both images and videos
    if len(files_to_merge) > 1:
        print('Merging image and video detection results...')
        combine_batch_output_files(files_to_merge, detector_output_file)
        print('Merged detection results written to {}'.format(detector_output_file))
    elif len(files_to_merge) == 1:
        # Just move the single file into place; os.replace overwrites any existing file
        if files_to_merge[0] != detector_output_file:
            if os.path.isfile(detector_output_file):
                print('Detector file {} exists, over-writing'.format(detector_output_file))
            os.replace(files_to_merge[0], detector_output_file)
        print('Detection results written to {}'.format(detector_output_file))

    return detector_output_file

# ...def _run_detection_step(...)
932
+
933
+
934
def _run_classification_step(detector_results_file: str,
                             merged_results_file: str,
                             source_folder: str,
                             classifier_model: str = DEFAULT_CLASSIFIER_MODEL,
                             classifier_batch_size: int = DEFAULT_CLASSIFIER_BATCH_SIZE,
                             classifier_worker_threads: int = DEFAULT_LOADER_WORKERS,
                             detection_confidence_threshold: float = \
                                DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                             enable_rollup: bool = True,
                             country: str = None,
                             admin1_region: str = None,
                             top_n_scores: int = DEFAULT_TOP_N_SCORES):
    """
    Run SpeciesNet classification on detections from MegaDetector results.

    Pipeline:
      * the main process puts each image record from [detector_results_file] on a
        JoinableQueue;
      * [classifier_worker_threads] producer processes (_crop_producer_func) load images
        and turn detections into crops;
      * a single consumer process (_crop_consumer_func) batches and classifies the
        crops, then sends one results dict back over results_queue.

    The per-detection results are then converted from class-name strings to string-int
    category IDs, merged into the detection results, and written to [merged_results_file].

    Args:
        detector_results_file (str): path to MegaDetector output .json file
        merged_results_file (str): path to which we should write the merged results
        source_folder (str): source folder for resolving relative paths
        classifier_model (str, optional): classifier model identifier
        classifier_batch_size (int, optional): batch size for classification
        classifier_worker_threads (int, optional): number of producer processes
        detection_confidence_threshold (float, optional): classify detections above this threshold
        enable_rollup (bool, optional): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region (typically a state code) for geofencing
        top_n_scores (int, optional): maximum number of scores to include for each detection

    Raises:
        ValueError: if the detector results contain no images
    """

    print('Starting classification step...')

    # Load MegaDetector results
    print('Reading detection results from {}'.format(detector_results_file))

    with open(detector_results_file, 'r') as f:
        detector_results = json.load(f)

    print('Classification step loaded detection results for {} images'.format(
        len(detector_results['images'])))

    images = detector_results['images']
    if len(images) == 0:
        raise ValueError('No images found in detector results')

    print('Using SpeciesNet classifier: {}'.format(classifier_model))

    # Set multiprocessing start method to 'spawn' for CUDA compatibility.
    #
    # NOTE(review): original_start_method is only used in the message below; the
    # process-wide start method is never restored, so it stays 'spawn' after this
    # function returns. Confirm whether that is intentional.
    original_start_method = multiprocessing.get_start_method()
    if original_start_method != 'spawn':
        multiprocessing.set_start_method('spawn', force=True)
        print('Set multiprocessing start method to spawn (was {})'.format(
            original_start_method))

    ## Set up multiprocessing queues

    # This queue receives lists of image filenames (and associated detection results)
    # from the "main" thread (the one you're reading right now). Items are pulled off
    # of this queue by producer workers (on _crop_producer_func), where the corresponding
    # images are loaded from disk and preprocessed into crops.
    image_queue = JoinableQueue(maxsize= \
        classifier_worker_threads * MAX_IMAGE_QUEUE_SIZE_PER_WORKER)

    # This queue receives cropped images from producers (on _crop_producer_func); those
    # crops are pulled off of this queue by the consumer (on _crop_consumer_func).
    batch_queue = Queue(maxsize=MAX_BATCH_QUEUE_SIZE)

    # This is not really used as a queue, rather it's just used to send all the results
    # at once from the consumer process to the main process (the one you're reading right
    # now).
    results_queue = Queue()

    # Start producer workers
    producers = []
    for i_worker in range(classifier_worker_threads):
        p = Process(target=_crop_producer_func,
                    args=(image_queue, batch_queue, classifier_model,
                          detection_confidence_threshold, source_folder, i_worker))
        p.start()
        producers.append(p)


    ## Start consumer worker

    # The consumer is told how many producers exist (classifier_worker_threads), since
    # it counts one None per finished producer before it stops
    consumer = Process(target=_crop_consumer_func,
                       args=(batch_queue, results_queue, classifier_model,
                             classifier_batch_size, classifier_worker_threads,
                             enable_rollup, country, admin1_region))
    consumer.start()

    # This will block every time the queue reaches its maximum depth, so for
    # very small jobs, this will not be a useful progress bar.
    with tqdm(total=len(images),desc='Classification') as pbar:
        for image_data in images:
            image_queue.put(image_data)
            pbar.update()

    # Send sentinel signals to producers (one None per producer)
    for _ in range(classifier_worker_threads):
        image_queue.put(None)

    # Wait for all work to complete.
    #
    # NOTE(review): JoinableQueue.join() only returns once task_done() has been called
    # for every item put above, including the sentinels; presumably _crop_producer_func
    # does this (not visible here). If a producer dies, this blocks indefinitely.
    image_queue.join()

    print('Finished waiting for input queue')


    ## Wait for results

    # Read results before joining the consumer: a process that has put data on a
    # Queue may not exit until that data has been consumed.
    #
    # NOTE(review): no timeout here, so this blocks forever if the consumer crashes
    # before calling results_queue.put().
    classification_results = results_queue.get()


    ## Clean up processes

    for p in producers:
        p.join()
    consumer.join()

    print('Finished waiting for workers')


    ## Format results and write output

    class CategoryState:
        """
        Helper class to manage classification category IDs.
        """

        def __init__(self):

            self.next_category_id = 0

            # Maps common name to string-int IDs
            self.common_name_to_id = {}

            # Maps string-ints to common names, as per format standard
            self.classification_categories = {}

            # Maps string-ints to latin taxonomy strings, as per format standard
            self.classification_category_descriptions = {}

        def _get_category_id(self, class_name):
            """
            Get an integer-valued category ID for the 7-token string [class_name],
            creating a new one if necessary.

            IDs are keyed on the common name (token 6), falling back to the
            five-token taxonomy string when the common name is empty. IDs are
            returned as strings ('0', '1', ...) in order of first appearance.
            """

            # E.g.:
            #
            # "cb553c4e-42c9-4fe0-9bd0-da2d6ed5bfa1;mammalia;carnivora;canidae;urocyon;littoralis;island fox"
            tokens = class_name.split(';')
            assert len(tokens) == 7
            taxonomy_string = ';'.join(tokens[1:6])
            common_name = tokens[6]
            if len(common_name) == 0:
                common_name = taxonomy_string

            if common_name not in self.common_name_to_id:
                self.common_name_to_id[common_name] = str(self.next_category_id)
                self.classification_categories[str(self.next_category_id)] = common_name
                # Store the full seven-token string, rather than the shortened five-token string, for
                # compatibility with what is expected by the classification_postprocessing module.
                # self.classification_category_descriptions[str(self.next_category_id)] = taxonomy_string
                self.classification_category_descriptions[str(self.next_category_id)] = class_name
                self.next_category_id += 1

            category_id = self.common_name_to_id[common_name]

            return category_id

    # ...class CategoryState

    category_state = CategoryState()

    # Merge classification results back into detector results with proper category IDs
    for image_data in images:

        image_file = image_data['file']

        if ('detections' not in image_data) or (image_data['detections'] is None):
            continue

        detections = image_data['detections']

        # Images with no classification entry (e.g., presumably, no detections above
        # detection_confidence_threshold) are left unchanged
        if image_file not in classification_results:
            continue

        image_classifications = classification_results[image_file]

        for detection_index, detection in enumerate(detections):

            # Detections with no entry in image_classifications keep no 'classifications'
            if detection_index in image_classifications:

                result = image_classifications[detection_index]

                if 'failure' in result:
                    # Add failure to the image, not the detection; multiple failures
                    # for the same image are joined with ';'
                    if 'failure' not in image_data:
                        image_data['failure'] = result['failure']
                    else:
                        image_data['failure'] += ';' + result['failure']
                else:

                    # Convert class names to category IDs
                    classification_pairs = []
                    raw_classification_pairs = []

                    scores = [x[1] for x in result['classifications']]
                    assert is_list_sorted(scores, reverse=True)

                    # Only report the requested number of scores per detection.
                    #
                    # NOTE(review): _process_classification_batch (as written in this
                    # file) stores a single [class, score] pair under 'classifications',
                    # so in practice this truncation only affects 'raw_classifications'.
                    if len(result['classifications']) > top_n_scores:
                        result['classifications'] = \
                            result['classifications'][0:top_n_scores]

                    if len(result['raw_classifications']) > top_n_scores:
                        result['raw_classifications'] = \
                            result['raw_classifications'][0:top_n_scores]

                    for class_name, score in result['classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        classification_pairs.append([category_id, score])

                    # NOTE(review): raw pairs are computed, and register category IDs in
                    # category_state, but are not written to the detection (see the
                    # commented-out line below). As a result, classification_categories
                    # can contain categories that no detection references.
                    for class_name, score in result['raw_classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        raw_classification_pairs.append([category_id, score])

                    # Add classifications to the detection
                    detection['classifications'] = classification_pairs
                    # detection['raw_classifications'] = raw_classification_pairs

                # ...if this classification contains a failure

            # ...if this detection has classification information

        # ...for each detection

    # ...for each image

    # Update metadata in the results
    if 'info' not in detector_results:
        detector_results['info'] = {}

    detector_results['info']['classifier'] = classifier_model
    detector_results['info']['classification_completion_time'] = time.strftime(
        '%Y-%m-%d %H:%M:%S')

    # Add classification category mapping
    detector_results['classification_categories'] = \
        category_state.classification_categories
    detector_results['classification_category_descriptions'] = \
        category_state.classification_category_descriptions

    print('Writing output file')

    # Write results
    write_json(merged_results_file, detector_results)

    # 'verbose' is the module-level flag set by run_md_and_speciesnet()
    if verbose:
        print('Classification results written to {}'.format(merged_results_file))

# ...def _run_classification_step(...)
1199
+
1200
+
1201
+ #%% Options class
1202
+
1203
class RunMDSpeciesNetOptions:
    """
    Class controlling the behavior of run_md_and_speciesnet()
    """

    def __init__(self):

        #: Folder containing images and/or videos to process
        self.source = None

        #: Output file for results (JSON format)
        self.output_file = None

        #: MegaDetector model identifier (MDv5a, MDv5b, MDv1000-redwood, etc.)
        self.detector_model = DEFAULT_DETECTOR_MODEL

        #: SpeciesNet classifier model identifier (e.g. kaggle:google/speciesnet/pyTorch/v4.0.1a)
        self.classification_model = DEFAULT_CLASSIFIER_MODEL

        #: Batch size for MegaDetector inference
        self.detector_batch_size = DEFAULT_DETECTOR_BATCH_SIZE

        #: Batch size for SpeciesNet classification
        self.classifier_batch_size = DEFAULT_CLASSIFIER_BATCH_SIZE

        #: Number of worker threads for preprocessing
        self.loader_workers = DEFAULT_LOADER_WORKERS

        #: Classify detections above this threshold
        self.detection_confidence_threshold_for_classification = \
            DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION

        #: Include detections above this threshold in the output
        self.detection_confidence_threshold_for_output = \
            DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT

        #: Folder for intermediate files (default: system temp)
        self.intermediate_file_folder = None

        #: Keep intermediate files (e.g. detection-only results file)
        self.keep_intermediate_files = False

        #: Disable taxonomic rollup
        self.norollup = False

        #: Country code (ISO 3166-1 alpha-3) for geofencing (default None, no geofencing)
        self.country = None

        #: Admin1 region/state code for geofencing
        self.admin1_region = None

        #: Path to existing MegaDetector output file (skips detection step)
        self.detections_file = None

        #: Ignore videos, only process images
        self.skip_video = False

        #: Ignore images, only process videos
        self.skip_images = False

        #: Sample every Nth frame from videos
        #:
        #: Mutually exclusive with time_sample
        self.frame_sample = None

        #: Sample frames every N seconds from videos
        #:
        #: Mutually exclusive with frame_sample. Defaults to None so that callers can
        #: set frame_sample without also having to clear this field (setting both
        #: raises ValueError in run_md_and_speciesnet()). If neither is set,
        #: run_md_and_speciesnet() uses DEFAULT_SECONDS_PER_VIDEO_FRAME.
        self.time_sample = None

        #: Enable additional debug output
        self.verbose = False

# ...class RunMDSpeciesNetOptions
1280
+
1281
+
1282
+ #%% Main function
1283
+
1284
def run_md_and_speciesnet(options):
    """
    Main entry point, runs MegaDetector and SpeciesNet on a folder. See
    RunMDSpeciesNetOptions for available arguments.

    Args:
        options (RunMDSpeciesNetOptions): options controlling MD and SN inference. If
            neither options.frame_sample nor options.time_sample is set,
            options.time_sample is set in place to DEFAULT_SECONDS_PER_VIDEO_FRAME.

    Raises:
        ValueError: if the source folder is missing or doesn't exist, no output file
            is specified, or the options are inconsistent
    """

    # Set global verbose flag
    global verbose
    verbose = options.verbose

    # Also set the run_detector_batch verbose flag
    run_detector_batch.verbose = verbose

    # Validate arguments up front, before any (potentially long-running) inference
    if (options.source is None) or (not os.path.isdir(options.source)):
        raise ValueError(
            'Source folder does not exist: {}'.format(options.source))

    # Without this check, a missing output file would only be detected when writing
    # results, after detection and classification have already run
    if options.output_file is None:
        raise ValueError('No output file specified')

    if (options.admin1_region is not None) and (options.country is None):
        raise ValueError('--admin1_region requires --country to be specified')

    if options.skip_images and options.skip_video:
        raise ValueError('Cannot skip both images and videos')

    if (options.frame_sample is not None) and (options.time_sample is not None):
        raise ValueError('--frame_sample and --time_sample are mutually exclusive')
    if (options.frame_sample is None) and (options.time_sample is None):
        options.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME

    # Set up intermediate file folder
    if options.intermediate_file_folder:
        temp_folder = options.intermediate_file_folder
        os.makedirs(temp_folder, exist_ok=True)
    else:
        temp_folder = make_temp_folder(subfolder='run_md_and_speciesnet')

    start_time = time.time()

    print('Processing folder: {}'.format(options.source))
    print('Output file: {}'.format(options.output_file))
    print('Intermediate files: {}'.format(temp_folder))

    # Determine detector output file path
    if options.detections_file is not None:
        detector_output_file = options.detections_file
        if VALIDATE_DETECTION_FILE:
            print('Using existing detections file: {}'.format(detector_output_file))
            validation_options = ValidateBatchResultsOptions()
            validation_options.check_image_existence = True
            validation_options.relative_path_base = options.source
            validation_options.raise_errors = True
            validate_batch_results(detector_output_file,options=validation_options)
            print('Validated detections file')
        else:
            print('Bypassing validation of {}'.format(options.detections_file))
    else:
        detector_output_file = os.path.join(temp_folder, 'detector_output.json')

        # Run MegaDetector
        _run_detection_step(
            source_folder=options.source,
            detector_output_file=detector_output_file,
            detector_model=options.detector_model,
            detector_batch_size=options.detector_batch_size,
            detection_confidence_threshold=options.detection_confidence_threshold_for_output,
            detector_worker_threads=options.loader_workers,
            skip_images=options.skip_images,
            skip_video=options.skip_video,
            frame_sample=options.frame_sample,
            time_sample=options.time_sample
        )

    # Run SpeciesNet
    _run_classification_step(
        detector_results_file=detector_output_file,
        merged_results_file=options.output_file,
        source_folder=options.source,
        classifier_model=options.classification_model,
        classifier_batch_size=options.classifier_batch_size,
        classifier_worker_threads=options.loader_workers,
        detection_confidence_threshold=options.detection_confidence_threshold_for_classification,
        enable_rollup=(not options.norollup),
        country=options.country,
        admin1_region=options.admin1_region,
    )

    elapsed_time = time.time() - start_time
    print(
        'Processing complete in {}'.format(humanfriendly.format_timespan(elapsed_time)))
    print('Results written to: {}'.format(options.output_file))

    # Clean up intermediate files if requested. Files are only removed from the
    # auto-created temp folder, and a user-supplied detections file is never removed.
    if (not options.keep_intermediate_files) and \
       (not options.intermediate_file_folder) and \
       (not options.detections_file):
        try:
            os.remove(detector_output_file)
        except Exception as e:
            print('Warning: error removing temporary output file {}: {}'.format(
                detector_output_file, str(e)))

# ...def run_md_and_speciesnet(...)
1389
+
1390
+
1391
+ #%% Command-line driver
1392
+
1393
def main():
    """
    Command-line driver for run_md_and_speciesnet.py.

    Parses command-line arguments, copies them onto a RunMDSpeciesNetOptions object
    with args_to_object(), and passes that object to run_md_and_speciesnet().
    """

    # Bail out early if the speciesnet package hasn't been imported; during a sphinx
    # documentation build we only print the warning
    if 'speciesnet' not in sys.modules:
        print('It looks like the speciesnet package is not available, try "pip install speciesnet"')
        if not is_sphinx_build():
            sys.exit(-1)

    arg_parser = argparse.ArgumentParser(
        description='Run MegaDetector and SpeciesNet on a folder of images/videos',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    add_arg = arg_parser.add_argument

    # Positional (required) arguments
    add_arg('source', help='Folder containing images and/or videos to process')
    add_arg('output_file', help='Output file for results (JSON format)')

    # Models and batch sizes
    add_arg('--detector_model', default=DEFAULT_DETECTOR_MODEL,
            help='MegaDetector model identifier')
    add_arg('--classification_model', default=DEFAULT_CLASSIFIER_MODEL,
            help='SpeciesNet classifier model identifier')
    add_arg('--detector_batch_size', type=int, default=DEFAULT_DETECTOR_BATCH_SIZE,
            help='Batch size for MegaDetector inference')
    add_arg('--classifier_batch_size', type=int, default=DEFAULT_CLASSIFIER_BATCH_SIZE,
            help='Batch size for SpeciesNet classification')
    add_arg('--loader_workers', type=int, default=DEFAULT_LOADER_WORKERS,
            help='Number of worker threads for preprocessing')

    # Confidence thresholds
    add_arg('--detection_confidence_threshold_for_classification', type=float,
            default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
            help='Classify detections above this threshold')
    add_arg('--detection_confidence_threshold_for_output', type=float,
            default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT,
            help='Include detections above this threshold in the output')

    # Intermediate files
    add_arg('--intermediate_file_folder', default=None,
            help='Folder for intermediate files (default: system temp)')
    add_arg('--keep_intermediate_files', action='store_true',
            help='Keep intermediate files (e.g. detection-only results file)')

    # Rollup and geofencing
    add_arg('--norollup', action='store_true',
            help='Disable taxonomic rollup')
    add_arg('--country', default=None,
            help='Country code (ISO 3166-1 alpha-3) for geofencing')
    add_arg('--admin1_region', '--state', default=None,
            help='Admin1 region/state code for geofencing')

    # Input selection
    add_arg('--detections_file', default=None,
            help='Path to existing MegaDetector output file (skips detection step)')
    add_arg('--skip_video', action='store_true',
            help='Ignore videos, only process images')
    add_arg('--skip_images', action='store_true',
            help='Ignore images, only process videos')

    # Video frame sampling
    add_arg('--frame_sample', type=int, default=None,
            help='Sample every Nth frame from videos (mutually exclusive with --time_sample)')
    time_sample_help = \
        'Sample frames every N seconds from videos (default {})'.format(
            DEFAULT_SECONDS_PER_VIDEO_FRAME) + ' (mutually exclusive with --frame_sample)'
    add_arg('--time_sample', type=float, default=None, help=time_sample_help)

    add_arg('--verbose', action='store_true',
            help='Enable additional debug output')

    # With no arguments at all, show usage instead of an argument error
    if not sys.argv[1:]:
        arg_parser.print_help()
        arg_parser.exit()

    parsed_args = arg_parser.parse_args()

    md_sn_options = RunMDSpeciesNetOptions()
    args_to_object(parsed_args, md_sn_options)

    run_md_and_speciesnet(md_sn_options)

# ...def main(...)
1491
+
1492
+
1493
# Run the command-line driver only when this file is executed as a script (not on import)
if __name__ == '__main__':
    main()