megadetector 10.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (147) hide show
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +701 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +563 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +192 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +665 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +984 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2172 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1604 -0
  81. megadetector/detection/run_tiled_inference.py +1044 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1943 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2140 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +231 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2872 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1766 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1973 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +498 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.15.dist-info/METADATA +115 -0
  144. megadetector-10.0.15.dist-info/RECORD +147 -0
  145. megadetector-10.0.15.dist-info/WHEEL +5 -0
  146. megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1604 @@
1
+ """
2
+
3
+ run_md_and_speciesnet.py
4
+
5
+ Script to run MegaDetector and SpeciesNet on a folder of images and/or videos.
6
+ Runs MD first, then runs SpeciesNet on every above-threshold crop.
7
+
8
+ """
9
+
10
+ #%% Constants, imports, environment
11
+
12
+ import argparse
13
+ import json
14
+ import multiprocessing
15
+ import os
16
+ import sys
17
+ import time
18
+
19
+ from tqdm import tqdm
20
+ from multiprocessing import JoinableQueue, Process, Queue
21
+ from threading import Thread
22
+
23
+ import humanfriendly
24
+
25
+ from megadetector.detection import run_detector_batch
26
+ from megadetector.detection.video_utils import find_videos, run_callback_on_frames, is_video_file
27
+ from megadetector.detection.run_detector_batch import load_and_run_detector_batch
28
+ from megadetector.detection.run_detector import DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
29
+ from megadetector.detection.run_detector import CONF_DIGITS
30
+ from megadetector.detection.run_detector_batch import write_results_to_file
31
+ from megadetector.utils.ct_utils import round_float
32
+ from megadetector.utils.ct_utils import write_json
33
+ from megadetector.utils.ct_utils import make_temp_folder
34
+ from megadetector.utils.ct_utils import is_list_sorted
35
+ from megadetector.utils.ct_utils import is_sphinx_build
36
+ from megadetector.utils.ct_utils import args_to_object
37
+ from megadetector.utils.path_utils import find_images
38
+ from megadetector.utils.path_utils import test_file_write
39
+ from megadetector.visualization import visualization_utils as vis_utils
40
+ from megadetector.postprocessing.validate_batch_results import \
41
+ validate_batch_results, ValidateBatchResultsOptions
42
+ from megadetector.detection.process_video import \
43
+ process_videos, ProcessVideoOptions
44
+ from megadetector.postprocessing.combine_batch_outputs import combine_batch_output_files
45
+
46
+ # We aren't taking an explicit dependency on the speciesnet package yet,
47
+ # so we wrap this in a try/except so sphinx can still document this module.
48
+ try:
49
+ from speciesnet import SpeciesNetClassifier
50
+ from speciesnet.utils import BBox
51
+ from speciesnet.ensemble import SpeciesNetEnsemble
52
+ from speciesnet.geofence_utils import roll_up_labels_to_first_matching_level
53
+ from speciesnet.geofence_utils import geofence_animal_classification
54
+ except Exception:
55
+ pass
56
+
57
+
58
+ #%% Constants
59
+
60
# Default for RunMDSpeciesNetOptions.detector_model (MegaDetector model identifier)
DEFAULT_DETECTOR_MODEL = 'MDV5A'

# Default for RunMDSpeciesNetOptions.classification_model (SpeciesNet model identifier)
DEFAULT_CLASSIFIER_MODEL = 'kaggle:google/speciesnet/pyTorch/v4.0.1a'

# Default for RunMDSpeciesNetOptions.detection_confidence_threshold_for_classification
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION = 0.1

# Default for RunMDSpeciesNetOptions.detection_confidence_threshold_for_output; this
# re-uses the default output threshold from run_detector.
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD

# Default batch sizes for MegaDetector inference and SpeciesNet classification
DEFAULT_DETECTOR_BATCH_SIZE = 1
DEFAULT_CLASSIFIER_BATCH_SIZE = 8

# Default number of preprocessing (loader) workers
DEFAULT_LOADER_WORKERS = 4

# Default worker type for parallelization; should be "thread" or "process"
DEFAULT_WORKER_TYPE = 'thread'

# This determines the maximum number of image filenames that can be assigned to
# each of the producer workers before blocking.  The actual size of the queue
# will be MAX_IMAGE_QUEUE_SIZE_PER_WORKER * n_workers.  This is only used for
# the classification step.
MAX_IMAGE_QUEUE_SIZE_PER_WORKER = 30

# This determines the maximum number of crops that can accumulate in the queue
# used to communicate between the producers (which read and crop images) and the
# consumer (which runs the classifier).  This is only used for the classification step.
MAX_BATCH_QUEUE_SIZE = 300

# Default interval (in seconds) between the frames we process when processing video.
# This is only used for the detection step.
DEFAULT_SECONDS_PER_VIDEO_FRAME = 1.0

# Max number of classification scores to include per detection
DEFAULT_TOP_N_SCORES = 2

# Unless --norollup is specified, roll up taxonomic levels until the
# cumulative confidence is above this value.  Only relevant when
# geofencing is disabled, otherwise the default speciesnet library
# constants are used.
DEFAULT_ROLLUP_TARGET_CONFIDENCE = 0.65

# When the caller supplies an existing MD results file, should we validate it before
# starting classification?
VALIDATE_DETECTION_FILE = False

# Module-level debug flag; the worker functions in this module print extra
# progress information when this is True.
verbose = False
98
+
99
+
100
+ #%% Main options class
101
+
102
class RunMDSpeciesNetOptions:
    """
    Class controlling the behavior of run_md_and_speciesnet()
    """

    def __init__(self):

        #: Folder containing images and/or videos to process
        self.source = None

        #: Output file for results (JSON format)
        self.output_file = None

        #: What to do if the output file exists ('overwrite', 'error', 'skip')
        self.overwrite_handling = 'overwrite'

        #: MegaDetector model identifier (MDv5a, MDv5b, MDv1000-redwood, etc.)
        self.detector_model = DEFAULT_DETECTOR_MODEL

        #: SpeciesNet classifier model identifier (e.g. kaggle:google/speciesnet/pyTorch/v4.0.1a)
        self.classification_model = DEFAULT_CLASSIFIER_MODEL

        #: Batch size for MegaDetector inference
        self.detector_batch_size = DEFAULT_DETECTOR_BATCH_SIZE

        #: Batch size for SpeciesNet classification
        self.classifier_batch_size = DEFAULT_CLASSIFIER_BATCH_SIZE

        #: Number of workers used for preprocessing (see worker_type)
        self.loader_workers = DEFAULT_LOADER_WORKERS

        #: Classify detections above this threshold
        self.detection_confidence_threshold_for_classification = \
            DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION

        #: Include detections above this threshold in the output
        self.detection_confidence_threshold_for_output = \
            DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT

        #: Folder for intermediate files (default: system temp)
        self.intermediate_file_folder = None

        #: Keep intermediate files (e.g. detection-only results file)
        self.keep_intermediate_files = False

        #: Disable taxonomic rollup
        self.norollup = False

        #: Target confidence threshold for taxonomic rollup
        self.rollup_target_confidence = DEFAULT_ROLLUP_TARGET_CONFIDENCE

        #: Country code (ISO 3166-1 alpha-3) for geofencing (default None, no geofencing)
        self.country = None

        #: Admin1 region/state code for geofencing
        self.admin1_region = None

        #: Path to existing MegaDetector output file (skips detection step)
        self.detections_file = None

        #: Ignore videos, only process images
        self.skip_video = False

        #: Ignore images, only process videos
        self.skip_images = False

        #: Sample every Nth frame from videos
        #:
        #: Mutually exclusive with time_sample
        self.frame_sample = None

        #: Sample frames every N seconds from videos
        #:
        #: Mutually exclusive with frame_sample
        #:
        #: This is the sampling mode used by default, since frame_sample
        #: defaults to None.
        self.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME

        #: Enable additional debug output
        self.verbose = False

        #: Worker type for parallelization; should be "thread" or "process"
        self.worker_type = DEFAULT_WORKER_TYPE

        #: Include raw (pre-rollup/geofence) classification scores in output
        self.include_raw_classifications = False

# ...class RunMDSpeciesNetOptions
191
+
192
+
193
+ #%% Support classes
194
+
195
class CropMetadata:
    """
    Bookkeeping record that ties one crop back to the detection (and the
    source file) it was extracted from.
    """

    def __init__(self,
                 image_file: str,
                 detection_index: int,
                 bbox: list[float],
                 original_width: int,
                 original_height: int):
        """
        Args:
            image_file (str): path to the file this crop came from
            detection_index (int): position of the detection within the file's
                list of detections
            bbox (List[float]): normalized box, as [x_min, y_min, width, height]
            original_width (int): width of the image the crop was taken from
            original_height (int): height of the image the crop was taken from
        """

        # Where the crop came from
        self.image_file = image_file
        self.detection_index = detection_index

        # Geometry of the crop and of its source image
        self.bbox = bbox
        self.original_height = original_height
        self.original_width = original_width
220
+
221
+
222
class CropBatch:
    """
    Accumulates crops, and the metadata describing each one, so they can be
    classified together.
    """

    def __init__(self):

        #: CropMetadata objects; entry i describes self.crops[i]
        self.metadata = []

        #: Preprocessed images, in the order they were added
        self.crops = []

    def __len__(self):
        # The batch size is the number of crops collected so far
        return len(self.crops)

    def add_crop(self, crop_data, metadata):
        """
        Append one crop and its metadata to the batch.

        Args:
            crop_data (PreprocessedImage): preprocessed image data from
                SpeciesNetClassifier.preprocess()
            metadata (CropMetadata): metadata for this crop
        """

        self.metadata.append(metadata)
        self.crops.append(crop_data)
247
+
248
+
249
+ #%% Support functions for classification
250
+
251
+ def _process_image_detections(file_path: str,
252
+ absolute_file_path: str,
253
+ detection_results: dict,
254
+ classifier: 'SpeciesNetClassifier',
255
+ detection_confidence_threshold: float,
256
+ batch_queue: Queue):
257
+ """
258
+ Process detections from a single image.
259
+
260
+ Args:
261
+ file_path (str): relative path to the image file
262
+ absolute_file_path (str): absolute path to the image file
263
+ detection_results (dict): detection results for this image
264
+ classifier (SpeciesNetClassifier): classifier instance for preprocessing
265
+ detection_confidence_threshold (float): classify detections above this threshold
266
+ batch_queue (Queue): queue to send crops to
267
+ """
268
+
269
+ detections = detection_results['detections']
270
+
271
+ # Don't bother loading images that have no above-threshold detections
272
+ detections_above_threshold = \
273
+ [d for d in detections if d['conf'] >= detection_confidence_threshold]
274
+ if len(detections_above_threshold) == 0:
275
+ return
276
+
277
+ # Load the image
278
+ try:
279
+ image = vis_utils.load_image(absolute_file_path)
280
+ original_width, original_height = image.size
281
+ except Exception as e:
282
+ print('Warning: failed to load image {}: {}'.format(file_path, str(e)))
283
+
284
+ # Send failure information to consumer
285
+ failure_metadata = CropMetadata(
286
+ image_file=file_path,
287
+ detection_index=-1, # -1 indicates whole-image failure
288
+ bbox=[],
289
+ original_width=0,
290
+ original_height=0
291
+ )
292
+ batch_queue.put(('failure',
293
+ 'Failed to load image: {}'.format(str(e)),
294
+ failure_metadata))
295
+ return
296
+
297
+ # Process each detection above threshold
298
+ #
299
+ # detection_index needs to index into the original list of detections
300
+ # (this is how classification results will be associated with detections
301
+ # later), so iterate over "detections" here, rather than
302
+ # "detections_above_threshold".
303
+ for detection_index, detection in enumerate(detections):
304
+
305
+ conf = detection['conf']
306
+ if conf < detection_confidence_threshold:
307
+ continue
308
+
309
+ bbox = detection['bbox']
310
+ assert len(bbox) == 4
311
+
312
+ # Convert normalized bbox to BBox object for SpeciesNet
313
+ speciesnet_bbox = BBox(
314
+ xmin=bbox[0],
315
+ ymin=bbox[1],
316
+ width=bbox[2],
317
+ height=bbox[3]
318
+ )
319
+
320
+ # Preprocess the crop
321
+ try:
322
+
323
+ preprocessed_crop = classifier.preprocess(
324
+ image,
325
+ bboxes=[speciesnet_bbox],
326
+ resize=True
327
+ )
328
+
329
+ if preprocessed_crop is not None:
330
+
331
+ metadata = CropMetadata(
332
+ image_file=file_path,
333
+ detection_index=detection_index,
334
+ bbox=bbox,
335
+ original_width=original_width,
336
+ original_height=original_height
337
+ )
338
+
339
+ # Send individual crop to the consumer
340
+ batch_queue.put(('crop', preprocessed_crop, metadata))
341
+
342
+ except Exception as e:
343
+
344
+ print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
345
+ file_path, detection_index, str(e)))
346
+
347
+ # Send failure information to consumer
348
+ failure_metadata = CropMetadata(
349
+ image_file=file_path,
350
+ detection_index=detection_index,
351
+ bbox=bbox,
352
+ original_width=original_width,
353
+ original_height=original_height
354
+ )
355
+ batch_queue.put(('failure',
356
+ 'Failed to preprocess crop: {}'.format(str(e)),
357
+ failure_metadata))
358
+
359
+ # ...try/except
360
+
361
+ # ...for each detection in this image
362
+
363
+ # ...def _process_image_detections(...)
364
+
365
+
366
+ def _process_video_detections(file_path: str,
367
+ absolute_file_path: str,
368
+ detection_results: dict,
369
+ classifier: 'SpeciesNetClassifier',
370
+ detection_confidence_threshold: float,
371
+ batch_queue: Queue):
372
+ """
373
+ Process detections from a single video.
374
+
375
+ Args:
376
+ file_path (str): relative path to the video file
377
+ absolute_file_path (str): absolute path to the video file
378
+ detection_results (dict): detection results for this video
379
+ classifier (SpeciesNetClassifier): classifier instance for preprocessing
380
+ detection_confidence_threshold (float): classify detections above this threshold
381
+ batch_queue (Queue): queue to send crops to
382
+ """
383
+
384
+ detections = detection_results['detections']
385
+
386
+ # Find frames with above-threshold detections
387
+ frames_with_detections = set()
388
+ frame_to_detections = {}
389
+
390
+ for detection_index, detection in enumerate(detections):
391
+
392
+ conf = detection['conf']
393
+ if conf < detection_confidence_threshold:
394
+ continue
395
+
396
+ frame_number = detection['frame_number']
397
+ frames_with_detections.add(frame_number)
398
+
399
+ if frame_number not in frame_to_detections:
400
+ frame_to_detections[frame_number] = []
401
+ frame_to_detections[frame_number].append((detection_index, detection))
402
+
403
+ # ...for each detection in this video
404
+
405
+ if len(frames_with_detections) == 0:
406
+ return
407
+
408
+ frames_to_process = sorted(list(frames_with_detections))
409
+
410
+ # Define callback for processing each frame
411
+ def frame_callback(frame_array, frame_id):
412
+ """
413
+ Callback to process a single frame.
414
+
415
+ Args:
416
+ frame_array (numpy.ndarray): frame data in PIL format
417
+ frame_id (str): frame identifier like "frame0006.jpg"
418
+ """
419
+
420
+ # Extract frame number from frame_id (e.g., "frame0006.jpg" -> 6)
421
+ import re
422
+ match = re.match(r'frame(\d+)\.jpg', frame_id)
423
+ if not match:
424
+ print('Warning: could not parse frame number from {}'.format(frame_id))
425
+ return
426
+ frame_number = int(match.group(1))
427
+
428
+ # Only process frames for which we have detection results
429
+ if frame_number not in frame_to_detections:
430
+ return
431
+
432
+ # Convert numpy array to PIL Image
433
+ from PIL import Image
434
+ if frame_array.dtype != 'uint8':
435
+ frame_array = (frame_array * 255).astype('uint8')
436
+ frame_image = Image.fromarray(frame_array)
437
+ original_width, original_height = frame_image.size
438
+
439
+ # Process each detection in this frame
440
+ for detection_index, detection in frame_to_detections[frame_number]:
441
+
442
+ bbox = detection['bbox']
443
+ assert len(bbox) == 4
444
+
445
+ # Convert normalized bbox to BBox object for SpeciesNet
446
+ speciesnet_bbox = BBox(
447
+ xmin=bbox[0],
448
+ ymin=bbox[1],
449
+ width=bbox[2],
450
+ height=bbox[3]
451
+ )
452
+
453
+ # Preprocess the crop
454
+ try:
455
+
456
+ preprocessed_crop = classifier.preprocess(
457
+ frame_image,
458
+ bboxes=[speciesnet_bbox],
459
+ resize=True
460
+ )
461
+
462
+ if preprocessed_crop is not None:
463
+ metadata = CropMetadata(
464
+ image_file=file_path,
465
+ detection_index=detection_index,
466
+ bbox=bbox,
467
+ original_width=original_width,
468
+ original_height=original_height
469
+ )
470
+
471
+ # Send individual crop immediately to consumer
472
+ batch_queue.put(('crop', preprocessed_crop, metadata))
473
+
474
+ except Exception as e:
475
+
476
+ print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
477
+ file_path, detection_index, str(e)))
478
+
479
+ # Send failure information to consumer
480
+ failure_metadata = CropMetadata(
481
+ image_file=file_path,
482
+ detection_index=detection_index,
483
+ bbox=bbox,
484
+ original_width=original_width,
485
+ original_height=original_height
486
+ )
487
+ batch_queue.put(('failure',
488
+ 'Failed to preprocess crop: {}'.format(str(e)),
489
+ failure_metadata))
490
+
491
+ # ...try/except
492
+
493
+ # ...for each detection
494
+
495
+ # ...def frame_callback(...)
496
+
497
+ # Process the video frames
498
+ try:
499
+
500
+ run_callback_on_frames(
501
+ input_video_file=absolute_file_path,
502
+ frame_callback=frame_callback,
503
+ frames_to_process=frames_to_process,
504
+ verbose=verbose
505
+ )
506
+
507
+ except Exception as e:
508
+
509
+ print('Warning: failed to process video {}: {}'.format(file_path, str(e)))
510
+
511
+ # Send failure information to consumer for the whole video
512
+ failure_metadata = CropMetadata(
513
+ image_file=file_path,
514
+ detection_index=-1, # -1 indicates whole-file failure
515
+ bbox=[],
516
+ original_width=0,
517
+ original_height=0
518
+ )
519
+ batch_queue.put(('failure',
520
+ 'Failed to process video: {}'.format(str(e)),
521
+ failure_metadata))
522
+ # ...try/except
523
+
524
+ # ...def _process_video_detections(...)
525
+
526
+
527
def _crop_producer_func(image_queue: JoinableQueue,
                        batch_queue: Queue,
                        classifier_model: str,
                        detection_confidence_threshold: float,
                        source_folder: str,
                        producer_id: int = -1,
                        preloaded_classifier: 'SpeciesNetClassifier' = None):
    """
    Producer function for classification workers.

    Reads images and videos from [image_queue], crops detections above a threshold,
    preprocesses them, and sends individual crops to [batch_queue].
    See the documentation of _crop_consumer_func for the format of the
    tuples placed on batch_queue.

    Every item pulled from [image_queue] is acknowledged with task_done(), even
    if processing that item raises.  An unexpected error while processing a file
    is reported to the consumer as a whole-file failure (detection_index == -1)
    rather than terminating this worker, so the final None sentinel is still
    placed on [batch_queue].

    Args:
        image_queue (JoinableQueue): queue containing detection_results dicts (for both images and videos)
        batch_queue (Queue): queue to put individual crops into
        classifier_model (str): classifier model identifier to load in this process
        detection_confidence_threshold (float): classify detections above this threshold
        source_folder (str): source folder to resolve relative paths
        producer_id (int, optional): identifier for this producer worker
        preloaded_classifier (SpeciesNetClassifier, optional): pre-loaded classifier instance
            (for thread-based workers, to avoid loading models in threads)
    """

    if verbose:
        print('Classification producer starting: ID {}'.format(producer_id))

    # Load classifier; this is just being used as a preprocessor, so we force device=cpu.
    #
    # When using threads, we pre-load the classifier in the main thread to avoid PyTorch FX
    # issues with loading models in worker threads.
    if preloaded_classifier is not None:
        classifier = preloaded_classifier
        if verbose:
            print('Classification producer {}: using pre-loaded classifier'.format(producer_id))
    else:
        # There are a number of reasons loading the model might fail; note to self: *don't*
        # catch Exceptions here.  This should be a catastrophic failure that stops the whole
        # process.
        classifier = SpeciesNetClassifier(classifier_model, device='cpu')
        if verbose:
            print('Classification producer {}: loaded classifier'.format(producer_id))

    while True:

        # Pull one file's detection results from the queue
        detection_results = image_queue.get()

        # The "finally" clause guarantees exactly one task_done() per get(),
        # whichever way we leave this iteration (break, continue, or exception).
        try:

            # Pulling None from the queue indicates that this producer is done
            if detection_results is None:
                break

            file_path = detection_results['file']

            # Skip files that failed at the detection stage
            if 'failure' in detection_results:
                continue

            # Skip files with no detections (including a null detections list)
            if not detection_results['detections']:
                continue

            # Determine if this is an image or video
            absolute_file_path = os.path.join(source_folder, file_path)
            if is_video_file(file_path):
                process_fn = _process_video_detections
            else:
                process_fn = _process_image_detections

            try:

                process_fn(
                    file_path=file_path,
                    absolute_file_path=absolute_file_path,
                    detection_results=detection_results,
                    classifier=classifier,
                    detection_confidence_threshold=detection_confidence_threshold,
                    batch_queue=batch_queue
                )

            except Exception as e:

                # The per-file functions report the failures they anticipate; this
                # catches anything else (e.g. a malformed bbox), so one bad file
                # can't take down the worker and stall the consumer.
                print('Warning: unexpected error processing {}: {}'.format(
                    file_path, str(e)))

                failure_metadata = CropMetadata(
                    image_file=file_path,
                    detection_index=-1, # -1 indicates whole-file failure
                    bbox=[],
                    original_width=0,
                    original_height=0
                )
                batch_queue.put(('failure',
                                 'Failed to process file: {}'.format(str(e)),
                                 failure_metadata))

            # ...try/except

        finally:

            image_queue.task_done()

    # ...while(we still have items to process)

    # Send sentinel to indicate this producer is done
    batch_queue.put(None)

    if verbose:
        print('Classification producer {} finished'.format(producer_id))

# ...def _crop_producer_func(...)
634
+
635
+
636
def _crop_consumer_func(batch_queue: Queue,
                        results_queue: Queue,
                        classifier_model: str,
                        batch_size: int,
                        num_producers: int,
                        enable_rollup: bool,
                        country: str = None,
                        admin1_region: str = None,
                        preloaded_classifier: 'SpeciesNetClassifier' = None,
                        rollup_target_confidence: float = DEFAULT_ROLLUP_TARGET_CONFIDENCE):
    """
    Consumer worker for classification inference.

    Drains [batch_queue] one item at a time, groups crops into batches of [batch_size],
    runs every full batch (and the final partial batch) through the classifier, and then
    puts a single dict holding all results onto [results_queue].

    If no pre-loaded classifier is supplied and the classifier fails to load, an empty
    dict is put onto [results_queue] and the function returns immediately.

    Args:
        batch_queue (Queue): queue containing individual crop tuples or failures.
            Items on this queue are either None (to indicate that a producer finished)
            or tuples formatted as (type,image,metadata).  [type] is a string (either
            "crop" or "failure"), [image] is a PreprocessedImage, and [metadata] is
            a CropMetadata object.
        results_queue (Queue): queue that receives the complete results dict
        classifier_model (str): classifier model identifier to load
        batch_size (int): number of crops per inference batch
        num_producers (int): number of producer workers; the loop ends once this many
            None sentinels have been received
        enable_rollup (bool): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
        preloaded_classifier (SpeciesNetClassifier, optional): pre-loaded classifier instance
            (for thread-based workers, to avoid loading models in threads)
        rollup_target_confidence (float, optional): target confidence threshold for taxonomic
            rollup.  Ignored if enable_rollup is False.
    """

    if verbose:
        print('Classification consumer starting')

    # When using threads, the classifier is pre-loaded in the main thread to avoid PyTorch FX
    # issues with loading models in worker threads; otherwise we load it here.
    classifier = preloaded_classifier
    if classifier is not None:
        if verbose:
            print('Classification consumer: using pre-loaded classifier')
    else:
        try:
            classifier = SpeciesNetClassifier(classifier_model)
            if verbose:
                print('Classification consumer: loaded classifier')
        except Exception as e:
            print('Classification consumer: failed to load classifier: {}'.format(str(e)))
            results_queue.put({})
            return

    # image_file -> {detection_index -> classification_result}
    all_results = {}
    current_batch = CropBatch()
    n_sentinels_received = 0

    # Ensemble metadata is only required when rollup and/or geofencing is enabled
    taxonomy_map = {}
    geofence_map = {}

    if (enable_rollup) or (country is not None):

        # Note to self: there are a number of reasons loading the ensemble
        # could fail here; don't catch this exception, this should be a
        # catastrophic failure.
        ensemble = SpeciesNetEnsemble(
            classifier_model, geofence=(country is not None))
        taxonomy_map = ensemble.taxonomy_map
        geofence_map = ensemble.geofence_map

    # ...if we need to load ensemble components

    def _classify_pending(pending_batch):
        """
        Run [pending_batch] through the classifier, storing results in [all_results].
        """

        _process_classification_batch(
            pending_batch, classifier, all_results,
            enable_rollup, taxonomy_map, geofence_map,
            country, admin1_region, rollup_target_confidence
        )

    while True:

        queue_item = batch_queue.get()

        # A None sentinel means one producer worker has finished
        if queue_item is None:

            n_sentinels_received += 1
            if n_sentinels_received != num_producers:
                continue

            # Every producer is done; flush whatever is left over and stop
            if len(current_batch) > 0:
                _classify_pending(current_batch)
            break

        # ...if a producer finished

        # Anything else is either a crop to classify or a failure to record
        assert isinstance(queue_item, tuple) and len(queue_item) == 3
        item_type, data, metadata = queue_item

        results_this_image = all_results.setdefault(metadata.image_file, {})

        # We should never be processing the same detection twice
        assert metadata.detection_index not in results_this_image

        if item_type == 'failure':
            results_this_image[metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(data)
            }
            continue

        assert item_type == 'crop'
        current_batch.add_crop(data, metadata)
        assert len(current_batch) <= batch_size

        # Run inference as soon as the batch is full, then start a new batch
        if len(current_batch) == batch_size:
            _classify_pending(current_batch)
            current_batch = CropBatch()

    # ...while (we have items to process)

    # Send all the results at once back to the main process
    results_queue.put(all_results)

    if verbose:
        print('Classification consumer finished')

# ...def _crop_consumer_func(...)
776
+
777
+
778
def _process_classification_batch(batch: CropBatch,
                                  classifier: 'SpeciesNetClassifier',
                                  all_results: dict,
                                  enable_rollup: bool,
                                  taxonomy_map: dict,
                                  geofence_map: dict,
                                  country: str = None,
                                  admin1_region: str = None,
                                  rollup_target_confidence: float =
                                  DEFAULT_ROLLUP_TARGET_CONFIDENCE):
    """
    Classify one batch of crops and record the outcome for each crop in [all_results].

    Args:
        batch (CropBatch): batch of crops to process
        classifier (SpeciesNetClassifier): classifier instance
        all_results (dict): dictionary to store results in, modified in-place.  Each
            successful crop is stored as:

            {image_file: {detection_index: {'classifications': [[class_name, score]],
                                            'raw_classifications': [[class_name, score], ...]}}}

            ...and each failed crop is stored as:

            {image_file: {detection_index: {'failure': error_message}}}
        enable_rollup (bool): whether to apply rollup
        taxonomy_map (dict): taxonomy mapping for rollup
        geofence_map (dict): geofence mapping
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
        rollup_target_confidence (float, optional): target confidence threshold for
            taxonomic rollup, ignored if enable_rollup is False
    """

    if len(batch) == 0:
        print('Warning: _process_classification_batch received empty batch')
        return

    # One synthetic identifier per crop, built from the source file and detection index
    crop_ids = ['{}_{}'.format(crop_metadata.image_file, crop_metadata.detection_index)
                for crop_metadata in batch.metadata]

    # Run batch inference; if the whole batch fails, every crop in it is marked as failed
    try:
        batch_results = classifier.batch_predict(crop_ids, batch.crops)
    except Exception as e:
        print('*** Batch classification failed: {} ***'.format(str(e)))
        for crop_metadata in batch.metadata:
            results_this_image = all_results.setdefault(crop_metadata.image_file, {})
            results_this_image[crop_metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(str(e))
            }
        return

    # We expect exactly one result per crop
    assert len(batch_results) == len(batch.metadata)
    assert len(batch_results) == len(crop_ids)

    for i_result, result in enumerate(batch_results):

        crop_id = crop_ids[i_result]
        metadata = batch.metadata[i_result]

        assert metadata.image_file in all_results, \
            'File {} not in results dict'.format(metadata.image_file)

        detection_index = metadata.detection_index

        # Handle classification failure
        if 'failures' in result:
            print('*** Classification failure for image: {} ***'.format(
                crop_id))
            all_results[metadata.image_file][detection_index] = {
                'failure': 'Failure classification: SpeciesNet classifier failed'
            }
            continue

        # result['classifications'] is a dict with keys "classes" and "scores", each of
        # which points to a list.
        classifications = result['classifications']
        classes = classifications['classes']
        scores = classifications['scores']

        # Start from the first class/score pair; geofencing or rollup may replace it
        predicted_class = classes[0]
        predicted_score = scores[0]

        classification_was_geofenced = False

        # Possibly apply geofencing
        if country:

            geofenced_class, geofenced_score, prediction_source = \
                geofence_animal_classification(
                    labels=classes,
                    scores=scores,
                    country=country,
                    admin1_region=admin1_region,
                    taxonomy_map=taxonomy_map,
                    geofence_map=geofence_map,
                    enable_geofence=True
                )

            if prediction_source != 'classifier':
                classification_was_geofenced = True
                predicted_class = geofenced_class
                predicted_score = geofenced_score

        # ...if we might need to apply geofencing

        # Possibly apply rollup; this was already done if geofencing was applied
        if enable_rollup and (not classification_was_geofenced):

            rollup_result = roll_up_labels_to_first_matching_level(
                labels=classes,
                scores=scores,
                country=country,
                admin1_region=admin1_region,
                target_taxonomy_levels=['species','genus','family','order','class','kingdom'],
                non_blank_threshold=rollup_target_confidence,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=(country is not None)
            )

            if rollup_result is not None:
                predicted_class, predicted_score, _ = rollup_result

        # ...if we might need to apply taxonomic rollup

        # Also report raw model classifications
        raw_classifications = [[classes[i_class], scores[i_class]]
                               for i_class in range(0, len(classes))]

        # For now, category names are stored as strings; these will be assigned to integer
        # IDs before writing results to file later.
        all_results[metadata.image_file][detection_index] = {
            'classifications': [[predicted_class, predicted_score]],
            'raw_classifications': raw_classifications
        }

    # ...for each result in this batch

# ...def _process_classification_batch(...)
924
+
925
+
926
+ #%% Inference functions
927
+
928
def _run_detection_step(source_folder: str,
                        detector_output_file: str,
                        detector_model: str = DEFAULT_DETECTOR_MODEL,
                        detector_batch_size: int = DEFAULT_DETECTOR_BATCH_SIZE,
                        detection_confidence_threshold: float = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
                        detector_worker_threads: int = DEFAULT_LOADER_WORKERS,
                        worker_type: str = DEFAULT_WORKER_TYPE,
                        skip_images: bool = False,
                        skip_video: bool = False,
                        frame_sample: int = None,
                        time_sample: float = None) -> str:
    """
    Run MegaDetector on all images/videos in [source_folder].

    Image results and video results are first written to separate intermediate files
    next to [detector_output_file] (same name with "_images" / "_videos" inserted before
    the extension).  If both exist they are merged into [detector_output_file]; if only
    one exists it is renamed to [detector_output_file].

    Args:
        source_folder (str): folder containing images/videos
        detector_output_file (str): output .json file
        detector_model (str, optional): detector model identifier
        detector_batch_size (int, optional): batch size for detection
        detection_confidence_threshold (float, optional): confidence threshold for detections
            (to include in the output file)
        detector_worker_threads (int, optional): number of workers to use for preprocessing
        worker_type (str, optional): type of worker parallelization ("thread" or "process")
        skip_images (bool, optional): ignore images, only process videos
        skip_video (bool, optional): ignore videos, only process images
        frame_sample (int, optional): sample every Nth frame from videos
        time_sample (float, optional): sample frames every N seconds from videos

    Returns:
        str: [detector_output_file], i.e. the path the final detection results were written to

    Raises:
        ValueError: if no images or videos are found in [source_folder]
    """

    print('Starting detection step...')

    # Validate arguments
    assert not (frame_sample is None and time_sample is None), \
        'Must specify either frame_sample or time_sample'

    # Find image and video files
    if not skip_images:
        image_files = find_images(source_folder, recursive=True,
                                  return_relative_paths=False)
    else:
        image_files = []

    if not skip_video:
        video_files = find_videos(source_folder, recursive=True,
                                  return_relative_paths=False)
    else:
        video_files = []

    if len(image_files) == 0 and len(video_files) == 0:
        raise ValueError(
            'No images or videos found in {}'.format(source_folder))

    print('Found {} images and {} videos'.format(len(image_files), len(video_files)))

    # Derive the intermediate file names from the extension only.  A plain
    # str.replace('.json', ...) would also rewrite ".json" inside folder names, and
    # would be a no-op for an output path without ".json", which would make the image
    # and video steps write to the same file.
    output_base, output_ext = os.path.splitext(detector_output_file)
    image_output_file = output_base + '_images' + output_ext
    video_output_file = output_base + '_videos' + output_ext

    files_to_merge = []

    # Process images if necessary
    if len(image_files) > 0:

        print('Running MegaDetector on {} images...'.format(len(image_files)))

        use_threads_for_queue = (worker_type == 'thread')

        image_results = load_and_run_detector_batch(
            model_file=detector_model,
            image_file_names=image_files,
            checkpoint_path=None,
            confidence_threshold=detection_confidence_threshold,
            checkpoint_frequency=-1,
            results=None,
            n_cores=0,
            use_image_queue=True,
            quiet=True,
            image_size=None,
            batch_size=detector_batch_size,
            include_image_size=False,
            include_image_timestamp=False,
            include_exif_tags=None,
            loader_workers=detector_worker_threads,
            preprocess_on_image_queue=True,
            use_threads_for_queue=use_threads_for_queue
        )

        # Write image results to temporary file
        write_results_to_file(image_results,
                              image_output_file,
                              relative_path_base=source_folder,
                              detector_file=detector_model)

        print('Image detection results written to {}'.format(image_output_file))
        files_to_merge.append(image_output_file)

    # ...if we had images to process

    # Process videos if necessary
    if len(video_files) > 0:

        print('Running MegaDetector on {} videos...'.format(len(video_files)))

        # Set up video processing options
        video_options = ProcessVideoOptions()
        video_options.model_file = detector_model
        video_options.input_video_file = source_folder
        video_options.output_json_file = video_output_file
        video_options.json_confidence_threshold = detection_confidence_threshold
        video_options.frame_sample = frame_sample
        video_options.time_sample = time_sample
        video_options.recursive = True

        # Process videos
        process_videos(video_options)

        print('Video detection results written to {}'.format(video_options.output_json_file))
        files_to_merge.append(video_options.output_json_file)

    # ...if we had videos to process

    # Merge results if we have both images and videos
    if len(files_to_merge) > 1:
        print('Merging image and video detection results...')
        combine_batch_output_files(files_to_merge, detector_output_file)
        print('Merged detection results written to {}'.format(detector_output_file))
    elif len(files_to_merge) == 1:
        # Just rename the single file
        if files_to_merge[0] != detector_output_file:
            if os.path.isfile(detector_output_file):
                print('Detector file {} exists, over-writing'.format(detector_output_file))
            # os.replace overwrites an existing destination file in a single step
            os.replace(files_to_merge[0], detector_output_file)
        print('Detection results written to {}'.format(detector_output_file))

    return detector_output_file

# ...def _run_detection_step(...)
1061
+
1062
+
1063
def _run_classification_step(detector_results_file: str,
                             merged_results_file: str,
                             source_folder: str,
                             classifier_model: str = DEFAULT_CLASSIFIER_MODEL,
                             classifier_batch_size: int = DEFAULT_CLASSIFIER_BATCH_SIZE,
                             classifier_worker_threads: int = DEFAULT_LOADER_WORKERS,
                             detection_confidence_threshold: float = \
                                DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                             enable_rollup: bool = True,
                             country: str = None,
                             admin1_region: str = None,
                             top_n_scores: int = DEFAULT_TOP_N_SCORES,
                             worker_type: str = DEFAULT_WORKER_TYPE,
                             include_raw_classifications: bool = False,
                             rollup_target_confidence: float = DEFAULT_ROLLUP_TARGET_CONFIDENCE):
    """
    Run SpeciesNet classification on detections from MegaDetector results, and write the
    detector results, augmented with classifications, to [merged_results_file].

    Work is split across [classifier_worker_threads] producer workers (which turn
    detections into crops) and a single consumer worker (which runs inference).

    Args:
        detector_results_file (str): path to MegaDetector output .json file
        merged_results_file (str): path to which we should write the merged results
        source_folder (str): source folder for resolving relative paths
        classifier_model (str, optional): classifier model identifier
        classifier_batch_size (int, optional): batch size for classification
        classifier_worker_threads (int, optional): number of worker threads
        detection_confidence_threshold (float, optional): classify detections above this threshold
        enable_rollup (bool, optional): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region (typically a state code) for geofencing
        top_n_scores (int, optional): maximum number of scores to include for each detection
        worker_type (str, optional): type of worker parallelization ("thread" or "process")
        include_raw_classifications (bool, optional): include raw (pre-rollup/geofence)
            classification scores in output
        rollup_target_confidence (float, optional): target confidence threshold for taxonomic
            rollup.  Ignored if enable_rollup is False.

    Raises:
        ValueError: if the detector results file contains no images
    """

    print('Starting classification step...')

    # Load MegaDetector results
    print('Reading detection results from {}'.format(detector_results_file))

    with open(detector_results_file, 'r') as f:
        detector_results = json.load(f)

    images = detector_results['images']

    print('Classification step loaded detection results for {} images'.format(
        len(images)))

    if len(images) == 0:
        raise ValueError('No images found in detector results')

    print('Using SpeciesNet classifier: {}'.format(classifier_model))

    use_threads = (worker_type == 'thread')

    # Set multiprocessing start method to 'spawn' for CUDA compatibility
    if worker_type == 'process':
        previous_start_method = multiprocessing.get_start_method()
        if previous_start_method != 'spawn':
            multiprocessing.set_start_method('spawn', force=True)
            print('Set multiprocessing start method to spawn (was {})'.format(
                previous_start_method))

    ## Set up multiprocessing queues

    # This queue receives lists of image filenames (and associated detection results)
    # from the "main" thread (the one you're reading right now).  Items are pulled off
    # of this queue by producer workers (on _crop_producer_func), where the corresponding
    # images are loaded from disk and preprocessed into crops.
    max_pending_images = classifier_worker_threads * MAX_IMAGE_QUEUE_SIZE_PER_WORKER
    image_queue = JoinableQueue(maxsize=max_pending_images)

    # This queue receives cropped images from producers (on _crop_producer_func); those
    # crops are pulled off of this queue by the consumer (on _crop_consumer_func).
    batch_queue = Queue(maxsize=MAX_BATCH_QUEUE_SIZE)

    # This is not really used as a queue, rather it's just used to send all the results
    # at once from the consumer process to the main process (the one you're reading right
    # now).
    results_queue = Queue()

    WorkerClass = Thread if use_threads else Process # noqa

    # When using threads, pre-load classifiers in the main thread to avoid PyTorch FX issues
    # with loading models in worker threads.  When using processes, pass None and let each
    # process load its own classifier.
    producer_classifier = None
    consumer_classifier = None
    if use_threads:
        # Producer classifier (CPU only, used for preprocessing)
        producer_classifier = SpeciesNetClassifier(classifier_model, device='cpu')
        # Consumer classifier (GPU if available, used for inference)
        consumer_classifier = SpeciesNetClassifier(classifier_model)

    ## Start producer workers

    producers = []
    for i_worker in range(classifier_worker_threads):
        producer_args = (image_queue, batch_queue, classifier_model,
                         detection_confidence_threshold, source_folder, i_worker,
                         producer_classifier)
        producer = WorkerClass(target=_crop_producer_func, args=producer_args)
        producer.start()
        producers.append(producer)

    ## Start consumer worker

    consumer_args = (batch_queue, results_queue, classifier_model,
                     classifier_batch_size, classifier_worker_threads,
                     enable_rollup, country, admin1_region, consumer_classifier,
                     rollup_target_confidence)
    consumer = WorkerClass(target=_crop_consumer_func, args=consumer_args)
    consumer.start()

    ## Feed the producers

    # This will block every time the queue reaches its maximum depth, so for
    # very small jobs, this will not be a useful progress bar.
    with tqdm(total=len(images),desc='Classification') as pbar:
        for image_data in images:
            image_queue.put(image_data)
            pbar.update()

    # One None sentinel per producer tells that producer there is no more work
    for _ in range(classifier_worker_threads):
        image_queue.put(None)

    # Wait for all work to complete
    image_queue.join()

    print('Finished waiting for input queue')

    ## Wait for results

    # The consumer sends a single dict: image_file -> {detection_index -> result}
    classification_results = results_queue.get()

    ## Clean up processes

    for producer in producers:
        producer.join()
    consumer.join()

    print('Finished waiting for workers')

    ## Format results and write output

    class CategoryState:
        """
        Helper class to manage classification category IDs.
        """

        def __init__(self):

            self.next_category_id = 0

            # Maps common name to string-int IDs
            self.common_name_to_id = {}

            # Maps string-ints to common names, as per format standard
            self.classification_categories = {}

            # Maps string-ints to latin taxonomy strings, as per format standard
            self.classification_category_descriptions = {}

        def _get_category_id(self, class_name):
            """
            Get a string-int category ID for the 7-token string [class_name],
            creating a new one if necessary.
            """

            # E.g.:
            #
            # "cb553c4e-42c9-4fe0-9bd0-da2d6ed5bfa1;mammalia;carnivora;canidae;urocyon;littoralis;island fox"
            tokens = class_name.split(';')
            assert len(tokens) == 7

            # Fall back to the five-token taxonomy string when there is no common name
            common_name = tokens[6]
            if len(common_name) == 0:
                common_name = ';'.join(tokens[1:6])

            if common_name not in self.common_name_to_id:
                new_id = str(self.next_category_id)
                self.common_name_to_id[common_name] = new_id
                self.classification_categories[new_id] = common_name
                # Store the full seven-token string, rather than the shortened five-token
                # string, for compatibility with what is expected by the
                # classification_postprocessing module.
                self.classification_category_descriptions[new_id] = class_name
                self.next_category_id += 1

            return self.common_name_to_id[common_name]

    # ...class CategoryState

    category_state = CategoryState()

    def _to_category_pairs(named_scores):
        """
        Convert [class_name, score] pairs to [category_id, rounded_score] pairs.
        """

        return [[category_state._get_category_id(class_name),
                 round_float(score, precision=CONF_DIGITS)]
                for class_name, score in named_scores]

    # Merge classification results back into detector results with proper category IDs
    for image_data in images:

        image_file = image_data['file']

        # Nothing to merge for images without detections...
        detections = image_data.get('detections')
        if detections is None:
            continue

        # ...or for images with no classification results
        if image_file not in classification_results:
            continue

        image_classifications = classification_results[image_file]

        for detection_index, detection in enumerate(detections):

            if detection_index not in image_classifications:
                continue

            result = image_classifications[detection_index]

            if 'failure' in result:
                # Add failure to the image, not the detection
                if 'failure' in image_data:
                    image_data['failure'] += ';' + result['failure']
                else:
                    image_data['failure'] = result['failure']
                continue

            scores = [pair[1] for pair in result['classifications']]
            assert is_list_sorted(scores, reverse=True)

            # Only report the requested number of scores per detection
            for result_key in ('classifications', 'raw_classifications'):
                if len(result[result_key]) > top_n_scores:
                    result[result_key] = result[result_key][0:top_n_scores]

            # Convert class names to category IDs.  Raw classifications are always
            # converted (so their categories are always registered), but are only
            # attached to the detection on request.
            classification_pairs = _to_category_pairs(result['classifications'])
            raw_classification_pairs = _to_category_pairs(result['raw_classifications'])

            # Add classifications to the detection
            detection['classifications'] = classification_pairs
            if include_raw_classifications:
                detection['raw_classifications'] = raw_classification_pairs

        # ...for each detection

    # ...for each image

    # Update metadata in the results
    if 'info' not in detector_results:
        detector_results['info'] = {}

    info = detector_results['info']
    info['classifier'] = classifier_model
    info['classification_completion_time'] = time.strftime(
        '%Y-%m-%d %H:%M:%S')

    # Add classification category mapping
    detector_results['classification_categories'] = \
        category_state.classification_categories
    detector_results['classification_category_descriptions'] = \
        category_state.classification_category_descriptions

    print('Writing output file')

    # Write results
    write_json(merged_results_file, detector_results)

    if verbose:
        print('Classification results written to {}'.format(merged_results_file))

# ...def _run_classification_step(...)
1354
+
1355
+
1356
+ #%% Main function
1357
+
1358
def run_md_and_speciesnet(options):
    """
    Main entry point, runs MegaDetector and SpeciesNet on a folder.  See
    RunMDSpeciesNetOptions for available arguments.

    If options.detections_file is supplied, the detection step is skipped and that file
    is used as the input to classification.  If the output file already exists and
    options.overwrite_handling is "skip", nothing is processed.

    Args:
        options (RunMDSpeciesNetOptions): options controlling MD and SN inference

    Raises:
        ValueError: if the source folder does not exist, if the options are inconsistent,
            if the output path is a directory, or if the output file exists and
            options.overwrite_handling is "error"
    """

    # Set global verbose flag
    global verbose
    verbose = options.verbose

    # Also set the run_detector_batch verbose flag
    run_detector_batch.verbose = verbose

    # Validate arguments
    if not os.path.isdir(options.source):
        raise ValueError(
            'Source folder does not exist: {}'.format(options.source))

    if (options.admin1_region is not None) and (options.country is None):
        raise ValueError('--admin1_region requires --country to be specified')

    if options.skip_images and options.skip_video:
        raise ValueError('Cannot skip both images and videos')

    if (options.frame_sample is not None) and (options.time_sample is not None):
        raise ValueError('--frame_sample and --time_sample are mutually exclusive')
    if (options.frame_sample is None) and (options.time_sample is None):
        # Neither sampling option was supplied, fall back to the default time-based sampling
        options.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME

    if options.worker_type not in ('thread','process'):
        raise ValueError('Unknown worker type {}'.format(options.worker_type))

    # Set up intermediate file folder
    if options.intermediate_file_folder:
        temp_folder = options.intermediate_file_folder
        os.makedirs(temp_folder, exist_ok=True)
    else:
        temp_folder = make_temp_folder(subfolder='run_md_and_speciesnet')

    start_time = time.time()

    print('Processing folder: {}'.format(options.source))
    print('Output file: {}'.format(options.output_file))
    print('Intermediate file folder: {}'.format(temp_folder))

    assert options.overwrite_handling in ('overwrite','error','skip'), \
        'Unknown overwrite_handling value {}'.format(options.overwrite_handling)

    # Decide what to do if the output file already exists
    if os.path.isdir(options.output_file):
        raise ValueError('Output file {} exists, but is a directory'.format(
            options.output_file))
    if os.path.isfile(options.output_file):
        if options.overwrite_handling == 'overwrite':
            print('Over-writing existing output file {}'.format(options.output_file))
        elif options.overwrite_handling == 'error':
            raise ValueError('Output file {} exists, and overwrite_handling is "error"'.format(
                options.output_file))
        elif options.overwrite_handling == 'skip':
            print('Bypassing processing: output file {} exists, and overwrite_handling is "skip"'.format(
                options.output_file))
            return

    # Verify that we can create the output file
    test_file_write(options.output_file)

    # Determine detector output file path
    if options.detections_file is not None:

        # Use the caller-supplied detections file instead of running MegaDetector
        detector_output_file = options.detections_file
        if VALIDATE_DETECTION_FILE:
            print('Using existing detections file: {}'.format(detector_output_file))
            validation_options = ValidateBatchResultsOptions()
            validation_options.check_image_existence = True
            validation_options.relative_path_base = options.source
            validation_options.raise_errors = True
            validate_batch_results(detector_output_file,options=validation_options)
            print('Validated detections file')
        else:
            print('Bypassing validation of {}'.format(options.detections_file))

    else:

        detector_output_file = os.path.join(temp_folder, 'detector_output.json')

        # Run MegaDetector
        _run_detection_step(
            source_folder=options.source,
            detector_output_file=detector_output_file,
            detector_model=options.detector_model,
            detector_batch_size=options.detector_batch_size,
            detection_confidence_threshold=options.detection_confidence_threshold_for_output,
            detector_worker_threads=options.loader_workers,
            skip_images=options.skip_images,
            skip_video=options.skip_video,
            frame_sample=options.frame_sample,
            time_sample=options.time_sample,
            worker_type=options.worker_type
        )

    # Run SpeciesNet
    _run_classification_step(
        detector_results_file=detector_output_file,
        merged_results_file=options.output_file,
        source_folder=options.source,
        classifier_model=options.classification_model,
        classifier_batch_size=options.classifier_batch_size,
        classifier_worker_threads=options.loader_workers,
        detection_confidence_threshold=options.detection_confidence_threshold_for_classification,
        enable_rollup=(not options.norollup),
        country=options.country,
        admin1_region=options.admin1_region,
        worker_type=options.worker_type,
        include_raw_classifications=options.include_raw_classifications,
        rollup_target_confidence=options.rollup_target_confidence
    )

    elapsed_time = time.time() - start_time
    print(
        'Processing complete in {}'.format(humanfriendly.format_timespan(elapsed_time)))
    print('Results written to: {}'.format(options.output_file))

    # Clean up intermediate files if requested.  The detector output is only removed when
    # it was generated here, in a temp folder that this function created.
    if (not options.keep_intermediate_files) and \
       (not options.intermediate_file_folder) and \
       (not options.detections_file):
        try:
            os.remove(detector_output_file)
        except Exception as e:
            # Failing to remove a temporary file should not fail the whole run
            print('Warning: error removing temporary output file {}: {}'.format(
                detector_output_file, str(e)))

# ...def run_md_and_speciesnet(...)
1490
+
1491
+
1492
+ #%% Command-line driver
1493
+
1494
def main():
    """
    Console entry point: parse the command line and run the MD + SpeciesNet pipeline.

    Parsed arguments are copied onto a fresh RunMDSpeciesNetOptions instance via
    args_to_object(), which is then handed to run_md_and_speciesnet().

    Exits with status -1 (after printing an install hint) when 'speciesnet' is
    not present in sys.modules, unless is_sphinx_build() reports true.  Prints
    the usage text and exits when invoked with no arguments at all.
    """

    # The check is against sys.modules, i.e. it only passes if something has
    # already imported speciesnet (presumably at module load - confirm against
    # the top of this file).
    speciesnet_loaded = 'speciesnet' in sys.modules
    if not speciesnet_loaded:
        print('It looks like the speciesnet package is not available, try "pip install speciesnet"')
        if not is_sphinx_build():
            sys.exit(-1)

    # Help strings that embed module-level defaults, assembled up front
    rollup_confidence_help = (
        f'Target confidence threshold for taxonomic rollup '
        f'(default {DEFAULT_ROLLUP_TARGET_CONFIDENCE}), only '
        f'used when geofencing is disabled'
    )
    time_sample_help = (
        f'Sample frames every N seconds from videos (default {DEFAULT_SECONDS_PER_VIDEO_FRAME})'
        ' (mutually exclusive with --frame_sample)'
    )

    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Run MegaDetector and SpeciesNet on a folder of images/videos',
    )

    # Positional (mandatory) arguments
    arg_parser.add_argument(
        'source',
        help='Folder containing images and/or videos to process',
    )
    arg_parser.add_argument(
        'output_file',
        help='Output file for results (JSON format)',
    )

    # Model selection
    arg_parser.add_argument(
        '--detector_model',
        help='MegaDetector model identifier',
        default=DEFAULT_DETECTOR_MODEL,
    )
    arg_parser.add_argument(
        '--classification_model',
        help='SpeciesNet classifier model identifier',
        default=DEFAULT_CLASSIFIER_MODEL,
    )

    # Batching / worker configuration
    arg_parser.add_argument(
        '--detector_batch_size',
        help='Batch size for MegaDetector inference',
        type=int,
        default=DEFAULT_DETECTOR_BATCH_SIZE,
    )
    arg_parser.add_argument(
        '--classifier_batch_size',
        help='Batch size for SpeciesNet classification',
        type=int,
        default=DEFAULT_CLASSIFIER_BATCH_SIZE,
    )
    arg_parser.add_argument(
        '--loader_workers',
        help='Number of worker threads for preprocessing',
        type=int,
        default=DEFAULT_LOADER_WORKERS,
    )

    # Detection confidence thresholds
    arg_parser.add_argument(
        '--detection_confidence_threshold_for_classification',
        help='Classify detections above this threshold',
        type=float,
        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
    )
    arg_parser.add_argument(
        '--detection_confidence_threshold_for_output',
        help='Include detections above this threshold in the output',
        type=float,
        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT,
    )

    # Intermediate file handling
    arg_parser.add_argument(
        '--intermediate_file_folder',
        help='Folder for intermediate files (default: system temp)',
        default=None,
    )
    arg_parser.add_argument(
        '--keep_intermediate_files',
        help='Keep intermediate files (e.g. detection-only results file)',
        action='store_true',
    )

    # Taxonomic rollup and geofencing
    arg_parser.add_argument(
        '--norollup',
        help='Disable taxonomic rollup',
        action='store_true',
    )
    arg_parser.add_argument(
        '--rollup_target_confidence',
        help=rollup_confidence_help,
        type=float,
        default=DEFAULT_ROLLUP_TARGET_CONFIDENCE,
    )
    arg_parser.add_argument(
        '--country',
        help='Country code (ISO 3166-1 alpha-3) for geofencing',
        default=None,
    )
    arg_parser.add_argument(
        '--admin1_region', '--state',
        help='Admin1 region/state code for geofencing',
        default=None,
    )

    # Re-use of existing detector output
    arg_parser.add_argument(
        '--detections_file',
        help='Path to existing MegaDetector output file (skips detection step)',
        default=None,
    )

    # Media selection and video frame sampling
    arg_parser.add_argument(
        '--skip_video',
        help='Ignore videos, only process images',
        action='store_true',
    )
    arg_parser.add_argument(
        '--skip_images',
        help='Ignore images, only process videos',
        action='store_true',
    )
    arg_parser.add_argument(
        '--frame_sample',
        help='Sample every Nth frame from videos (mutually exclusive with --time_sample)',
        type=int,
        default=None,
    )
    arg_parser.add_argument(
        '--time_sample',
        help=time_sample_help,
        type=float,
        default=None,
    )

    # Output verbosity / extra output fields
    arg_parser.add_argument(
        '--verbose',
        help='Enable additional debug output',
        action='store_true',
    )
    arg_parser.add_argument(
        '--include_raw_classifications',
        help='Include raw (pre-rollup/geofence) classification scores in output',
        action='store_true',
    )

    # With nothing on the command line, show usage rather than an argparse error
    if not sys.argv[1:]:
        arg_parser.print_help()
        arg_parser.exit()

    parsed_args = arg_parser.parse_args()

    # Transfer the parsed values onto the options object consumed by the pipeline
    pipeline_options = RunMDSpeciesNetOptions()
    args_to_object(parsed_args, pipeline_options)

    run_md_and_speciesnet(pipeline_options)

# ...def main(...)


if __name__ == '__main__':
    main()