megadetector 10.0.2__py3-none-any.whl → 10.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (30):
  1. megadetector/data_management/animl_to_md.py +158 -0
  2. megadetector/data_management/zamba_to_md.py +188 -0
  3. megadetector/detection/process_video.py +165 -946
  4. megadetector/detection/pytorch_detector.py +575 -276
  5. megadetector/detection/run_detector_batch.py +629 -202
  6. megadetector/detection/run_md_and_speciesnet.py +1319 -0
  7. megadetector/detection/video_utils.py +243 -107
  8. megadetector/postprocessing/classification_postprocessing.py +12 -1
  9. megadetector/postprocessing/combine_batch_outputs.py +2 -0
  10. megadetector/postprocessing/compare_batch_results.py +21 -2
  11. megadetector/postprocessing/merge_detections.py +16 -12
  12. megadetector/postprocessing/separate_detections_into_folders.py +1 -1
  13. megadetector/postprocessing/subset_json_detector_output.py +1 -3
  14. megadetector/postprocessing/validate_batch_results.py +25 -2
  15. megadetector/tests/__init__.py +0 -0
  16. megadetector/tests/test_nms_synthetic.py +335 -0
  17. megadetector/utils/ct_utils.py +69 -5
  18. megadetector/utils/extract_frames_from_video.py +303 -0
  19. megadetector/utils/md_tests.py +583 -524
  20. megadetector/utils/path_utils.py +4 -15
  21. megadetector/utils/wi_utils.py +20 -4
  22. megadetector/visualization/visualization_utils.py +1 -1
  23. megadetector/visualization/visualize_db.py +8 -22
  24. megadetector/visualization/visualize_detector_output.py +7 -5
  25. megadetector/visualization/visualize_video_output.py +607 -0
  26. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/METADATA +134 -135
  27. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/RECORD +30 -23
  28. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/top_level.txt +0 -0
  30. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1319 @@
1
+ """
2
+
3
+ run_md_and_speciesnet.py
4
+
5
+ Script to run MegaDetector and SpeciesNet on a folder of images and/or videos.
6
+ Runs MD first, then runs SpeciesNet on every above-threshold crop.
7
+
8
+ """
9
+
10
+ #%% Constants, imports, environment
11
+
12
+ import argparse
13
+ import json
14
+ import multiprocessing
15
+ import os
16
+ import sys
17
+ import time
18
+
19
+ from tqdm import tqdm
20
+ from multiprocessing import JoinableQueue, Process, Queue
21
+
22
+ import humanfriendly
23
+
24
+ from megadetector.detection import run_detector_batch
25
+ from megadetector.detection.video_utils import find_videos, run_callback_on_frames, is_video_file
26
+ from megadetector.detection.run_detector_batch import load_and_run_detector_batch
27
+ from megadetector.detection.run_detector import DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
28
+ from megadetector.detection.run_detector import CONF_DIGITS
29
+ from megadetector.detection.run_detector_batch import write_results_to_file
30
+ from megadetector.utils.ct_utils import round_float
31
+ from megadetector.utils.ct_utils import write_json
32
+ from megadetector.utils.ct_utils import make_temp_folder
33
+ from megadetector.utils.ct_utils import is_list_sorted
34
+ from megadetector.utils.ct_utils import is_sphinx_build
35
+ from megadetector.utils import path_utils
36
+ from megadetector.visualization import visualization_utils as vis_utils
37
+ from megadetector.postprocessing.validate_batch_results import \
38
+ validate_batch_results, ValidateBatchResultsOptions
39
+ from megadetector.detection.process_video import \
40
+ process_videos, ProcessVideoOptions
41
+ from megadetector.postprocessing.combine_batch_outputs import combine_batch_output_files
42
+
43
+ # We aren't taking an explicit dependency on the speciesnet package yet,
44
+ # so we wrap this in a try/except so sphinx can still document this module.
45
+ try:
46
+ from speciesnet import SpeciesNetClassifier
47
+ from speciesnet.utils import BBox
48
+ from speciesnet.ensemble import SpeciesNetEnsemble
49
+ from speciesnet.geofence_utils import roll_up_labels_to_first_matching_level
50
+ from speciesnet.geofence_utils import geofence_animal_classification
51
+ except Exception:
52
+ pass
53
+
54
+
55
#%% Constants

DEFAULT_DETECTOR_MODEL = 'MDV5A'
DEFAULT_CLASSIFIER_MODEL = 'kaggle:google/speciesnet/pyTorch/v4.0.1a'

# Default threshold for classification: detections with confidence below this
# value are not cropped or sent to the classifier
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION = 0.1

# Presumably the default confidence threshold for including detections in the
# detector output file; its use is outside this section (confirm against callers)
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD

DEFAULT_DETECTOR_BATCH_SIZE = 1
DEFAULT_CLASSIFIER_BATCH_SIZE = 8
DEFAULT_LOADER_WORKERS = 4
MAX_QUEUE_SIZE_IMAGES_PER_WORKER = 10

# Presumably the default interval (in seconds) for time-based video frame
# sampling; confirm against the command-line handling
DEFAULT_SECONDS_PER_VIDEO_FRAME = 1.0

# Backwards-compatible alias for the original, misspelled constant name; existing
# references keep working while new code can use the correct spelling
DEAFULT_SECONDS_PER_VIDEO_FRAME = DEFAULT_SECONDS_PER_VIDEO_FRAME

# Max number of classification scores to include per detection
DEFAULT_TOP_N_SCORES = 2

# Unless --norollup is specified, roll up taxonomic levels until the
# cumulative confidence is above this value
ROLLUP_TARGET_CONFIDENCE = 0.5

# Module-wide flag that enables extra progress printing in the worker functions
verbose = False
75
+
76
+
77
+ #%% Support classes
78
+
79
class CropMetadata:
    """
    Metadata for a crop extracted from an image (or video-frame) detection.

    Instances travel alongside each crop (or failure record) from the crop
    producers to the classification consumer.
    """

    def __init__(self,
                 image_file: str,
                 detection_index: int,
                 bbox: list[float],
                 original_width: int,
                 original_height: int):
        """
        Args:
            image_file (str): path to the original image (or video) file, as it
                appears in the detector results
            detection_index (int): index of this detection in the file's detection
                list; -1 indicates a failure that affects the whole file
            bbox (list[float]): normalized bounding box [x_min, y_min, width, height];
                empty for whole-file failures
            original_width (int): width of the original image/frame (0 for
                whole-file failures)
            original_height (int): height of the original image/frame (0 for
                whole-file failures)
        """

        self.image_file = image_file
        self.detection_index = detection_index
        self.bbox = bbox
        self.original_width = original_width
        self.original_height = original_height

    def __repr__(self):
        """
        Return a constructor-style representation, useful when debugging the
        items flowing through the producer/consumer queues.
        """

        return ('{}(image_file={!r}, detection_index={!r}, bbox={!r}, '
                'original_width={!r}, original_height={!r})').format(
                    type(self).__name__,
                    self.image_file,
                    self.detection_index,
                    self.bbox,
                    self.original_width,
                    self.original_height)
104
+
105
+
106
class CropBatch:
    """
    A batch of crops with their metadata, accumulated for classification.

    [crops] and [metadata] are parallel lists: metadata[i] describes crops[i].
    """

    def __init__(self):
        # Preprocessed images and their matching CropMetadata objects, kept
        # index-aligned (two distinct list objects)
        self.crops, self.metadata = [], []

    def add_crop(self, crop_data, metadata):
        """
        Append one crop, together with its metadata, to the batch.

        Args:
            crop_data (PreprocessedImage): preprocessed image data from
                SpeciesNetClassifier.preprocess()
            metadata (CropMetadata): metadata describing this crop
        """

        self.crops += [crop_data]
        self.metadata += [metadata]

    def __len__(self):
        """
        Return the number of crops currently held in the batch.
        """

        return len(self.crops)
131
+
132
+
133
+ #%% Support functions for classification
134
+
135
def _process_image_detections(file_path: str,
                              absolute_file_path: str,
                              detection_results: dict,
                              classifier: 'SpeciesNetClassifier',
                              detection_confidence_threshold: float,
                              batch_queue: Queue):
    """
    Crop and preprocess every sufficiently confident detection in one image,
    pushing each outcome onto [batch_queue].

    Items placed on [batch_queue] are 3-tuples:

    - ('crop', preprocessed_crop, CropMetadata) for each detection whose crop
      was preprocessed successfully
    - ('failure', message, CropMetadata) when preprocessing a detection raises;
      if the image itself cannot be loaded, a single failure with
      detection_index=-1 is sent and no detections are processed

    Args:
        file_path (str): relative path to the image file
        absolute_file_path (str): absolute path to the image file
        detection_results (dict): detection results for this image
        classifier (SpeciesNetClassifier): classifier instance for preprocessing
        detection_confidence_threshold (float): detections with confidence below
            this value are skipped
        batch_queue (Queue): queue to send crops to
    """

    all_detections = detection_results['detections']

    try:
        pil_image = vis_utils.load_image(absolute_file_path)
        image_w, image_h = pil_image.size
    except Exception as load_error:
        print('Warning: failed to load image {}: {}'.format(file_path, str(load_error)))
        # detection_index=-1 marks a failure of the whole image
        batch_queue.put(('failure',
                         'Failed to load image: {}'.format(str(load_error)),
                         CropMetadata(image_file=file_path,
                                      detection_index=-1,
                                      bbox=[],
                                      original_width=0,
                                      original_height=0)))
        return

    for i_det, det in enumerate(all_detections):

        # Guard clause: only classify detections at or above the threshold
        if det['conf'] < detection_confidence_threshold:
            continue

        box = det['bbox']
        assert len(box) == 4

        # MD boxes are normalized [x_min, y_min, width, height]
        sn_box = BBox(xmin=box[0], ymin=box[1], width=box[2], height=box[3])

        crop_info = CropMetadata(image_file=file_path,
                                 detection_index=i_det,
                                 bbox=box,
                                 original_width=image_w,
                                 original_height=image_h)

        try:
            crop = classifier.preprocess(pil_image, bboxes=[sn_box], resize=True)
            if crop is None:
                continue
            # Hand each crop to the consumer as soon as it is ready
            batch_queue.put(('crop', crop, crop_info))
        except Exception as crop_error:
            print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
                file_path, i_det, str(crop_error)))
            batch_queue.put(('failure',
                             'Failed to preprocess crop: {}'.format(str(crop_error)),
                             crop_info))

    # ...for each detection in this image

# ...def _process_image_detections(...)
232
+
233
+
234
+ def _process_video_detections(file_path: str,
235
+ absolute_file_path: str,
236
+ detection_results: dict,
237
+ classifier: 'SpeciesNetClassifier',
238
+ detection_confidence_threshold: float,
239
+ batch_queue: Queue):
240
+ """
241
+ Process detections from a single video.
242
+
243
+ Args:
244
+ file_path (str): relative path to the video file
245
+ absolute_file_path (str): absolute path to the video file
246
+ detection_results (dict): detection results for this video
247
+ classifier (SpeciesNetClassifier): classifier instance for preprocessing
248
+ detection_confidence_threshold (float): classify detections above this threshold
249
+ batch_queue (Queue): queue to send crops to
250
+ """
251
+
252
+ detections = detection_results['detections']
253
+
254
+ # Find frames with above-threshold detections
255
+ frames_with_detections = set()
256
+ frame_to_detections = {}
257
+
258
+ for detection_index, detection in enumerate(detections):
259
+ conf = detection['conf']
260
+ if conf < detection_confidence_threshold:
261
+ continue
262
+
263
+ frame_number = detection['frame_number']
264
+ frames_with_detections.add(frame_number)
265
+
266
+ if frame_number not in frame_to_detections:
267
+ frame_to_detections[frame_number] = []
268
+ frame_to_detections[frame_number].append((detection_index, detection))
269
+
270
+ if len(frames_with_detections) == 0:
271
+ return
272
+
273
+ frames_to_process = sorted(list(frames_with_detections))
274
+
275
+ # Define callback for processing each frame
276
+ def frame_callback(frame_array, frame_id):
277
+ """
278
+ Callback to process a single frame.
279
+
280
+ Args:
281
+ frame_array (numpy.ndarray): frame data in PIL format
282
+ frame_id (str): frame identifier like "frame0006.jpg"
283
+ """
284
+
285
+ # Extract frame number from frame_id (e.g., "frame0006.jpg" -> 6)
286
+ import re
287
+ match = re.match(r'frame(\d+)\.jpg', frame_id)
288
+ if not match:
289
+ print('Warning: could not parse frame number from {}'.format(frame_id))
290
+ return
291
+ frame_number = int(match.group(1))
292
+
293
+ if frame_number not in frame_to_detections:
294
+ return
295
+
296
+ # Convert numpy array to PIL Image
297
+ from PIL import Image
298
+ if frame_array.dtype != 'uint8':
299
+ frame_array = (frame_array * 255).astype('uint8')
300
+ frame_image = Image.fromarray(frame_array)
301
+ original_width, original_height = frame_image.size
302
+
303
+ # Process each detection in this frame
304
+ for detection_index, detection in frame_to_detections[frame_number]:
305
+
306
+ bbox = detection['bbox']
307
+ assert len(bbox) == 4
308
+
309
+ # Convert normalized bbox to BBox object for SpeciesNet
310
+ speciesnet_bbox = BBox(
311
+ xmin=bbox[0],
312
+ ymin=bbox[1],
313
+ width=bbox[2],
314
+ height=bbox[3]
315
+ )
316
+
317
+ # Preprocess the crop
318
+ try:
319
+
320
+ preprocessed_crop = classifier.preprocess(
321
+ frame_image,
322
+ bboxes=[speciesnet_bbox],
323
+ resize=True
324
+ )
325
+
326
+ if preprocessed_crop is not None:
327
+ metadata = CropMetadata(
328
+ image_file=file_path,
329
+ detection_index=detection_index,
330
+ bbox=bbox,
331
+ original_width=original_width,
332
+ original_height=original_height
333
+ )
334
+
335
+ # Send individual crop immediately to consumer
336
+ batch_queue.put(('crop', preprocessed_crop, metadata))
337
+
338
+ except Exception as e:
339
+
340
+ print('Warning: failed to preprocess crop from {}, detection {}: {}'.format(
341
+ file_path, detection_index, str(e)))
342
+
343
+ # Send failure information to consumer
344
+ failure_metadata = CropMetadata(
345
+ image_file=file_path,
346
+ detection_index=detection_index,
347
+ bbox=bbox,
348
+ original_width=original_width,
349
+ original_height=original_height
350
+ )
351
+ batch_queue.put(('failure',
352
+ 'Failed to preprocess crop: {}'.format(str(e)),
353
+ failure_metadata))
354
+
355
+ # ...try/except
356
+
357
+ # ...for each detection
358
+
359
+ # ...def frame_callback(...)
360
+
361
+ # Process the video frames
362
+ try:
363
+ run_callback_on_frames(
364
+ input_video_file=absolute_file_path,
365
+ frame_callback=frame_callback,
366
+ frames_to_process=frames_to_process,
367
+ verbose=verbose
368
+ )
369
+ except Exception as e:
370
+ print('Warning: failed to process video {}: {}'.format(file_path, str(e)))
371
+
372
+ # Send failure information to consumer for the whole video
373
+ failure_metadata = CropMetadata(
374
+ image_file=file_path,
375
+ detection_index=-1, # -1 indicates whole-file failure
376
+ bbox=[],
377
+ original_width=0,
378
+ original_height=0
379
+ )
380
+ batch_queue.put(('failure',
381
+ 'Failed to process video: {}'.format(str(e)),
382
+ failure_metadata))
383
+ # ...try/except
384
+
385
+ # ...def _process_video_detections(...)
386
+
387
+
388
def _crop_producer_func(image_queue: JoinableQueue,
                        batch_queue: Queue,
                        classifier_model: str,
                        detection_confidence_threshold: float,
                        source_folder: str,
                        producer_id: int = -1):
    """
    Producer function for classification workers.

    Reads images and videos from [image_queue], crops detections above a threshold,
    preprocesses them, and sends individual crops to [batch_queue].
    See the documentation of _crop_consumer_func for the format of the
    tuples placed on [batch_queue].

    Once the classifier has been loaded, every item pulled from [image_queue] is
    acknowledged with task_done() exactly once, even if processing it raises. An
    unexpected error while processing one file is reported on [batch_queue] as a
    whole-file failure (detection_index -1) instead of terminating this worker.
    Otherwise the consumer, which counts one None sentinel per producer, would
    wait forever for this producer's sentinel.

    Args:
        image_queue (JoinableQueue): queue containing detection_results dicts (for
            both images and videos); a None item tells this producer to stop
        batch_queue (Queue): queue to put individual crops into; a None sentinel is
            put on this queue when this producer stops
        classifier_model (str): classifier model identifier to load in this process
        detection_confidence_threshold (float): classify detections at or above
            this threshold
        source_folder (str): source folder to resolve relative paths
        producer_id (int, optional): identifier for this producer worker (used in
            log messages)
    """

    if verbose:
        print('Classification producer starting: ID {}'.format(producer_id))

    # Load classifier; this is just being used as a preprocessor, so we force device=cpu.
    #
    # There are a number of reasons loading the model might fail; note to self: *don't*
    # catch Exceptions here. This should be a catastrophic failure that stops the whole
    # process.
    classifier = SpeciesNetClassifier(classifier_model, device='cpu')
    if verbose:
        print('Classification producer {}: loaded classifier'.format(producer_id))

    while True:

        # Pull an image of detection results from the queue
        detection_results = image_queue.get()

        # Pulling None from the queue indicates that this producer is done
        if detection_results is None:
            image_queue.task_done()
            break

        file_path = None

        try:

            file_path = detection_results['file']

            # Skip files that failed at the detection stage
            if 'failure' in detection_results:
                continue

            # Skip files with no detections
            detections = detection_results['detections']
            if len(detections) == 0:
                continue

            absolute_file_path = os.path.join(source_folder, file_path)

            # Images and videos share the same processing interface
            if is_video_file(file_path):
                process_fn = _process_video_detections
            else:
                process_fn = _process_image_detections

            process_fn(
                file_path=file_path,
                absolute_file_path=absolute_file_path,
                detection_results=detection_results,
                classifier=classifier,
                detection_confidence_threshold=detection_confidence_threshold,
                batch_queue=batch_queue
            )

        except Exception as e:

            # Don't let one malformed file kill this worker; the consumer is
            # waiting for this producer's sentinel
            print('Warning: classification producer {} failed to process {}: {}'.format(
                producer_id, file_path, str(e)))

            if file_path is not None:
                failure_metadata = CropMetadata(
                    image_file=file_path,
                    detection_index=-1,  # -1 indicates whole-file failure
                    bbox=[],
                    original_width=0,
                    original_height=0
                )
                batch_queue.put(('failure',
                                 'Failed to extract crops: {}'.format(str(e)),
                                 failure_metadata))

        finally:

            # Runs for normal completion, 'continue', and exceptions alike
            image_queue.task_done()

    # ...while(we still have items to process)

    # Send sentinel to indicate this producer is done
    batch_queue.put(None)

    if verbose:
        print('Classification producer {} finished'.format(producer_id))

# ...def _crop_producer_func(...)
482
+
483
+
484
def _crop_consumer_func(batch_queue: Queue,
                        results_queue: Queue,
                        classifier_model: str,
                        batch_size: int,
                        num_producers: int,
                        enable_rollup: bool,
                        country: str = None,
                        admin1_region: str = None):
    """
    Consumer function for classification inference.

    Pulls individual crops from [batch_queue], assembles them into batches,
    runs inference, and puts a single results dict into [results_queue] once
    [num_producers] None sentinels have been received. If the classifier fails
    to load, an empty dict is put into [results_queue] instead.

    Args:
        batch_queue (Queue): queue containing individual crop tuples or failures.
            Items on this queue are either None (to indicate that a producer finished)
            or tuples formatted as (type,image,metadata). [type] is a string (either
            "crop" or "failure"), [image] is a PreprocessedImage (for "crop") or an
            error message (for "failure"), and [metadata] is a CropMetadata object.
        results_queue (Queue): queue to put classification results into; receives
            one dict mapping image_file -> {detection_index -> result}
        classifier_model (str): classifier model identifier to load
        batch_size (int): batch size for inference
        num_producers (int): number of producer workers
        enable_rollup (bool): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
    """

    if verbose:
        print('Classification consumer starting')

    # Load classifier
    try:
        classifier = SpeciesNetClassifier(classifier_model)
        if verbose:
            print('Classification consumer: loaded classifier')
    except Exception as e:
        print('Classification consumer: failed to load classifier: {}'.format(str(e)))
        results_queue.put({})
        return

    all_results = {}  # image_file -> {detection_index -> classification_result}
    current_batch = CropBatch()
    producers_finished = 0

    # Load ensemble metadata if rollup/geofencing is enabled
    taxonomy_map = {}
    geofence_map = {}

    # enable_rollup is a bool, so test its truth value; only load the ensemble
    # when rollup or geofencing will actually use it
    if enable_rollup or (country is not None):

        # Note to self: there are a number of reasons loading the ensemble
        # could fail here; don't catch this exception, this should be a
        # catastrophic failure.
        ensemble = SpeciesNetEnsemble(
            classifier_model, geofence=(country is not None))
        taxonomy_map = ensemble.taxonomy_map
        geofence_map = ensemble.geofence_map

    # ...if we need to load ensemble components

    while True:

        # Pull an item from the queue
        item = batch_queue.get()

        # This indicates that a producer worker finished
        if item is None:

            producers_finished += 1
            if producers_finished == num_producers:
                # Process any remaining images
                if len(current_batch) > 0:
                    _process_classification_batch(
                        current_batch, classifier, all_results,
                        enable_rollup, taxonomy_map, geofence_map,
                        country, admin1_region
                    )
                break
            continue

        # ...if a producer finished

        # If we got here, we know we have a crop to process, or
        # a failure to record.
        assert isinstance(item, tuple) and len(item) == 3
        item_type, data, metadata = item

        file_results = all_results.setdefault(metadata.image_file, {})

        # We should never be processing the same detection twice
        assert metadata.detection_index not in file_results

        if item_type == 'failure':

            file_results[metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(data)
            }

        else:

            assert item_type == 'crop'
            current_batch.add_crop(data, metadata)
            assert len(current_batch) <= batch_size

            # Process batch if necessary
            if len(current_batch) == batch_size:
                _process_classification_batch(
                    current_batch, classifier, all_results,
                    enable_rollup, taxonomy_map, geofence_map,
                    country, admin1_region
                )
                current_batch = CropBatch()

        # ...was this item a failure or a crop?

    # ...while (we have items to process)

    results_queue.put(all_results)

    if verbose:
        print('Classification consumer finished')

# ...def _crop_consumer_func(...)
610
+
611
+
612
def _process_classification_batch(batch: CropBatch,
                                  classifier: 'SpeciesNetClassifier',
                                  all_results: dict,
                                  enable_rollup: bool,
                                  taxonomy_map: dict,
                                  geofence_map: dict,
                                  country: str = None,
                                  admin1_region: str = None):
    """
    Run a batch of crops through the classifier.

    For each crop, the classifier's top prediction is optionally geofenced (when
    [country] is set). It is then optionally rolled up the taxonomy (when
    [enable_rollup] is set and geofencing did not already replace the
    prediction). When rollup returns a result, both its label and its score are
    reported, even if the label matches the classifier's top label.

    Args:
        batch (CropBatch): batch of crops to process
        classifier (SpeciesNetClassifier): classifier instance
        all_results (dict): dictionary to store results in, modified in-place with format:
            {image_file: {detection_index: {'classifications': [[class_name, score]],
                                            'raw_classifications': [[class_name, score], ...]}}}
            or {image_file: {detection_index: {'failure': error_message}}}.
            Class names are stored as strings at this stage.
        enable_rollup (bool): whether to apply rollup
        taxonomy_map (dict): taxonomy mapping for rollup
        geofence_map (dict): geofence mapping
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region for geofencing
    """

    if len(batch) == 0:
        print('Warning: _process_classification_batch received empty batch')
        return

    # Per-crop identifiers passed to the classifier and used in log messages
    filepaths = [f"{metadata.image_file}_{metadata.detection_index}"
                 for metadata in batch.metadata]

    # Run batch inference
    try:
        batch_results = classifier.batch_predict(filepaths, batch.crops)
    except Exception as e:
        print('*** Batch classification failed: {} ***'.format(str(e)))
        # Mark all crops in this batch as failed
        for metadata in batch.metadata:
            all_results.setdefault(metadata.image_file, {})[metadata.detection_index] = {
                'failure': 'Failure classification: {}'.format(str(e))
            }
        return

    # Results must line up one-to-one with the crops we submitted
    assert len(batch_results) == len(batch.metadata) == len(filepaths)

    rollup_levels = ['species', 'genus', 'family', 'order', 'class', 'kingdom']

    for metadata, result, crop_id in zip(batch.metadata, batch_results, filepaths):

        assert metadata.image_file in all_results, \
            'File {} not in results dict'.format(metadata.image_file)

        detection_index = metadata.detection_index

        # Handle classification failure
        if 'failures' in result:
            print('*** Classification failure for image: {} ***'.format(crop_id))
            all_results[metadata.image_file][detection_index] = {
                'failure': 'Failure classification: SpeciesNet classifier failed'
            }
            continue

        # Extract classification results; this is a dict with keys "classes"
        # and "scores", each of which points to a list.
        classifications = result['classifications']
        classes = classifications['classes']
        scores = classifications['scores']

        predicted_class = classes[0]
        predicted_score = scores[0]
        classification_was_geofenced = False

        # Possibly apply geofencing
        if country:

            geofenced_class, geofenced_score, prediction_source = \
                geofence_animal_classification(
                    labels=classes,
                    scores=scores,
                    country=country,
                    admin1_region=admin1_region,
                    taxonomy_map=taxonomy_map,
                    geofence_map=geofence_map,
                    enable_geofence=True
                )

            if prediction_source != 'classifier':
                classification_was_geofenced = True
                predicted_class = geofenced_class
                predicted_score = geofenced_score

        # ...if we might need to apply geofencing

        # Possibly apply rollup; this was already done if geofencing was applied
        if enable_rollup and (not classification_was_geofenced):

            rollup_result = roll_up_labels_to_first_matching_level(
                labels=classes,
                scores=scores,
                country=country,
                admin1_region=admin1_region,
                target_taxonomy_levels=rollup_levels,
                non_blank_threshold=ROLLUP_TARGET_CONFIDENCE,
                taxonomy_map=taxonomy_map,
                geofence_map=geofence_map,
                enable_geofence=(country is not None)
            )

            if rollup_result is not None:
                # Always take the score along with the label: rollup works with a
                # cumulative confidence, which can differ from the top raw score
                # even when the rolled-up label equals the top label.
                predicted_class, predicted_score, _ = rollup_result

        # ...if we might need to apply taxonomic rollup

        # For now, we'll store category names as strings; these will be assigned to integer
        # IDs before writing results to file later.
        all_results[metadata.image_file][detection_index] = {
            'classifications': [[predicted_class, predicted_score]],
            # Also report raw model classifications
            'raw_classifications': [[c, s] for c, s in zip(classes, scores)]
        }

    # ...for each result in this batch

# ...def _process_classification_batch(...)
754
+
755
+
756
+ #%% Inference functions
757
+
758
def _run_detection_step(source_folder: str,
                        detector_output_file: str,
                        detector_model: str = DEFAULT_DETECTOR_MODEL,
                        detector_batch_size: int = DEFAULT_DETECTOR_BATCH_SIZE,
                        detection_confidence_threshold: float = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
                        detector_worker_threads: int = DEFAULT_LOADER_WORKERS,
                        skip_images: bool = False,
                        skip_video: bool = False,
                        frame_sample: int = None,
                        time_sample: float = None) -> str:
    """
    Run MegaDetector on all images/videos in [source_folder].

    Images and videos are processed separately. Each kind writes an intermediate
    results file named after [detector_output_file] with an "_images" or
    "_videos" suffix before the extension. If both kinds of media were found,
    the intermediate files are merged into [detector_output_file]; otherwise the
    single intermediate file is renamed to [detector_output_file].

    Args:
        source_folder (str): folder containing images/videos
        detector_output_file (str): output .json file
        detector_model (str, optional): detector model identifier
        detector_batch_size (int, optional): batch size for detection
        detection_confidence_threshold (float, optional): confidence threshold for detections
            (to include in the output file)
        detector_worker_threads (int, optional): number of workers to use for preprocessing
        skip_images (bool, optional): ignore images, only process videos
        skip_video (bool, optional): ignore videos, only process images
        frame_sample (int, optional): sample every Nth frame from videos
        time_sample (float, optional): sample frames every N seconds from videos

    Returns:
        str: the path to which detection results were written ([detector_output_file])

    Raises:
        ValueError: if no images or videos are found, or if videos will be processed
            but neither [frame_sample] nor [time_sample] is specified
    """

    print('Starting detection step...')

    # Find image and video files
    if not skip_images:
        image_files = path_utils.find_images(source_folder, recursive=True,
                                             return_relative_paths=False)
    else:
        image_files = []

    if not skip_video:
        video_files = find_videos(source_folder, recursive=True,
                                  return_relative_paths=False)
    else:
        video_files = []

    if len(image_files) == 0 and len(video_files) == 0:
        raise ValueError(
            'No images or videos found in {}'.format(source_folder))

    # A sampling strategy is only needed when there are videos to process;
    # validate it before spending time on image detection
    if (len(video_files) > 0) and (frame_sample is None) and (time_sample is None):
        raise ValueError('Must specify either frame_sample or time_sample')

    print('Found {} images and {} videos'.format(len(image_files), len(video_files)))

    # Derive intermediate file names by inserting a suffix before the extension.
    # Only the extension is touched, so a ".json" elsewhere in the path is not
    # rewritten, and the two names never collide with each other or with the
    # final output path.
    output_base, output_ext = os.path.splitext(detector_output_file)
    if len(output_ext) == 0:
        output_ext = '.json'
    image_output_file = output_base + '_images' + output_ext
    video_output_file = output_base + '_videos' + output_ext

    files_to_merge = []

    # Process images if necessary
    if len(image_files) > 0:

        print('Running MegaDetector on {} images...'.format(len(image_files)))

        image_results = load_and_run_detector_batch(
            model_file=detector_model,
            image_file_names=image_files,
            checkpoint_path=None,
            confidence_threshold=detection_confidence_threshold,
            checkpoint_frequency=-1,
            results=None,
            n_cores=0,
            use_image_queue=True,
            quiet=True,
            image_size=None,
            batch_size=detector_batch_size,
            include_image_size=False,
            include_image_timestamp=False,
            include_exif_data=False,
            loader_workers=detector_worker_threads,
            preprocess_on_image_queue=True
        )

        # Write image results to an intermediate file
        write_results_to_file(image_results,
                              image_output_file,
                              relative_path_base=source_folder,
                              detector_file=detector_model)

        print('Image detection results written to {}'.format(image_output_file))
        files_to_merge.append(image_output_file)

    # ...if we had images to process

    # Process videos if necessary
    if len(video_files) > 0:

        print('Running MegaDetector on {} videos...'.format(len(video_files)))

        # Set up video processing options
        video_options = ProcessVideoOptions()
        video_options.model_file = detector_model
        video_options.input_video_file = source_folder
        video_options.output_json_file = video_output_file
        video_options.json_confidence_threshold = detection_confidence_threshold
        video_options.frame_sample = frame_sample
        video_options.time_sample = time_sample
        video_options.recursive = True

        # Process videos
        process_videos(video_options)

        print('Video detection results written to {}'.format(video_options.output_json_file))
        files_to_merge.append(video_options.output_json_file)

    # ...if we had videos to process

    # Merge results if we have both images and videos
    if len(files_to_merge) > 1:
        print('Merging image and video detection results...')
        combine_batch_output_files(files_to_merge, detector_output_file)
        print('Merged detection results written to {}'.format(detector_output_file))
    elif len(files_to_merge) == 1:
        # Just rename the single file
        if files_to_merge[0] != detector_output_file:
            if os.path.isfile(detector_output_file):
                print('Detector file {} exists, over-writing'.format(detector_output_file))
                os.remove(detector_output_file)
            os.rename(files_to_merge[0], detector_output_file)
        print('Detection results written to {}'.format(detector_output_file))

    return detector_output_file

# ...def _run_detection_step(...)
886
+
887
+
888
def _run_classification_step(detector_results_file: str,
                             merged_results_file: str,
                             source_folder: str,
                             classifier_model: str = DEFAULT_CLASSIFIER_MODEL,
                             classifier_batch_size: int = DEFAULT_CLASSIFIER_BATCH_SIZE,
                             classifier_worker_threads: int = DEFAULT_LOADER_WORKERS,
                             detection_confidence_threshold: float = \
                                 DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                             enable_rollup: bool = True,
                             country: str = None,
                             admin1_region: str = None,
                             top_n_scores: int = DEFAULT_TOP_N_SCORES):
    """
    Run SpeciesNet classification on detections from MegaDetector results.

    Images from the detector results are fed to [classifier_worker_threads] producer
    processes, which feed a batch queue read by a single consumer process; the consumer
    returns a dict of per-image, per-detection classification results. Those results are
    merged back into the detector results (with string-int classification category IDs)
    and written to [merged_results_file].

    The multiprocessing start method is temporarily set to 'spawn' while the worker
    processes run, and is restored to its previous value before returning.

    Args:
        detector_results_file (str): path to MegaDetector output .json file
        merged_results_file (str): path to which we should write the merged results
        source_folder (str): source folder for resolving relative paths
        classifier_model (str, optional): classifier model identifier
        classifier_batch_size (int, optional): batch size for classification
        classifier_worker_threads (int, optional): number of worker threads, must be >= 1
        detection_confidence_threshold (float, optional): classify detections above this threshold
        enable_rollup (bool, optional): whether to apply taxonomic rollup
        country (str, optional): country code for geofencing
        admin1_region (str, optional): admin1 region (typically a state code) for geofencing
        top_n_scores (int, optional): maximum number of scores to include for each detection

    Raises:
        ValueError: if [classifier_worker_threads] is less than 1, or if the detector
            results contain no images
    """

    if classifier_worker_threads < 1:
        # With no producers, nothing would ever call task_done() on the image queue,
        # so image_queue.join() below would block forever.
        raise ValueError('classifier_worker_threads must be at least 1, got {}'.format(
            classifier_worker_threads))

    print('Starting SpeciesNet classification step...')

    # Load MegaDetector results
    with open(detector_results_file, 'r') as f:
        detector_results = json.load(f)

    print('Classification step loaded detection results for {} images'.format(
        len(detector_results['images'])))

    images = detector_results['images']
    if len(images) == 0:
        raise ValueError('No images found in detector results')

    print('Using SpeciesNet classifier: {}'.format(classifier_model))

    # Set multiprocessing start method to 'spawn' for CUDA compatibility; remember the
    # previous method so we can restore it rather than leaking this global change.
    original_start_method = multiprocessing.get_start_method()
    if original_start_method != 'spawn':
        multiprocessing.set_start_method('spawn', force=True)
        print('Set multiprocessing start method to spawn (was {})'.format(
            original_start_method))

    # Every process we start, so we can clean up if something goes wrong
    workers = []

    try:

        # Set up multiprocessing queues
        max_queue_size = classifier_worker_threads * MAX_QUEUE_SIZE_IMAGES_PER_WORKER
        image_queue = JoinableQueue(max_queue_size)
        batch_queue = Queue()
        results_queue = Queue()

        # Start producer workers
        producers = []
        for i_worker in range(classifier_worker_threads):
            p = Process(target=_crop_producer_func,
                        args=(image_queue, batch_queue, classifier_model,
                              detection_confidence_threshold, source_folder, i_worker))
            p.start()
            producers.append(p)
            workers.append(p)

        # Start consumer worker
        consumer = Process(target=_crop_consumer_func,
                           args=(batch_queue, results_queue, classifier_model,
                                 classifier_batch_size, classifier_worker_threads,
                                 enable_rollup, country, admin1_region))
        consumer.start()
        workers.append(consumer)

        # This will block every time the queue reaches its maximum depth, so for
        # very small jobs, this will not be a useful progress bar.
        with tqdm(total=len(images),desc='Classification') as pbar:
            for image_data in images:
                image_queue.put(image_data)
                pbar.update()

        # Send sentinel signals to producers
        for _ in range(classifier_worker_threads):
            image_queue.put(None)

        # Wait for all work to complete
        image_queue.join()

        print('Finished waiting for input queue')

        # Read the results before joining the consumer, so the consumer is not left
        # blocked writing a large result to the queue while we wait for it to exit
        classification_results = results_queue.get()

        # Clean up processes
        for p in producers:
            p.join()
        consumer.join()

        print('Finished waiting for workers')

    finally:

        # On the normal path every worker has already been joined; if we got here via
        # an exception, don't leave orphaned worker processes behind.
        for p in workers:
            if p.is_alive():
                p.terminate()
                p.join()

        if original_start_method != 'spawn':
            multiprocessing.set_start_method(original_start_method, force=True)

    class CategoryState:
        """
        Helper class to manage classification category IDs.
        """

        def __init__(self):

            self.next_category_id = 0

            # Maps common name to string-int IDs
            self.common_name_to_id = {}

            # Maps string-ints to common names, as per format standard
            self.classification_categories = {}

            # Maps string-ints to latin taxonomy strings, as per format standard
            self.classification_category_descriptions = {}

        def _get_category_id(self, class_name):
            """
            Get an integer-valued category ID for the 7-token string [class_name],
            creating a new one if necessary.
            """

            # E.g.:
            #
            # "cb553c4e-42c9-4fe0-9bd0-da2d6ed5bfa1;mammalia;carnivora;canidae;urocyon;littoralis;island fox"
            tokens = class_name.split(';')
            assert len(tokens) == 7
            taxonomy_string = ';'.join(tokens[1:6])
            common_name = tokens[6]
            if len(common_name) == 0:
                common_name = taxonomy_string

            if common_name not in self.common_name_to_id:
                self.common_name_to_id[common_name] = str(self.next_category_id)
                self.classification_categories[str(self.next_category_id)] = common_name
                self.classification_category_descriptions[str(self.next_category_id)] = taxonomy_string
                self.next_category_id += 1

            category_id = self.common_name_to_id[common_name]

            return category_id

    # ...class CategoryState

    category_state = CategoryState()

    # Merge classification results back into detector results with proper category IDs
    for image_data in images:

        image_file = image_data['file']

        if ('detections' not in image_data) or (image_data['detections'] is None):
            continue

        detections = image_data['detections']

        if image_file not in classification_results:
            continue

        image_classifications = classification_results[image_file]

        for detection_index, detection in enumerate(detections):

            if detection_index in image_classifications:

                result = image_classifications[detection_index]

                if 'failure' in result:
                    # Add failure to the image, not the detection; treat an existing
                    # None-valued 'failure' field as "no failure recorded yet"
                    if image_data.get('failure') is None:
                        image_data['failure'] = result['failure']
                    else:
                        image_data['failure'] += ';' + result['failure']
                else:

                    # Convert class names to category IDs
                    classification_pairs = []
                    raw_classification_pairs = []

                    scores = [x[1] for x in result['classifications']]
                    assert is_list_sorted(scores, reverse=True)

                    # Only report the requested number of scores per detection
                    if len(result['classifications']) > top_n_scores:
                        result['classifications'] = \
                            result['classifications'][0:top_n_scores]

                    if len(result['raw_classifications']) > top_n_scores:
                        result['raw_classifications'] = \
                            result['raw_classifications'][0:top_n_scores]

                    for class_name, score in result['classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        classification_pairs.append([category_id, score])

                    for class_name, score in result['raw_classifications']:

                        category_id = category_state._get_category_id(class_name)
                        score = round_float(score, precision=CONF_DIGITS)
                        raw_classification_pairs.append([category_id, score])

                    # Add classifications to the detection
                    detection['classifications'] = classification_pairs
                    detection['raw_classifications'] = raw_classification_pairs

                # ...if this classification contains a failure

            # ...if this detection has classification information

        # ...for each detection

    # ...for each image

    # Update metadata in the results
    if 'info' not in detector_results:
        detector_results['info'] = {}

    detector_results['info']['classifier'] = classifier_model
    detector_results['info']['classification_completion_time'] = time.strftime(
        '%Y-%m-%d %H:%M:%S')

    # Add classification category mapping
    detector_results['classification_categories'] = \
        category_state.classification_categories
    detector_results['classification_category_descriptions'] = \
        category_state.classification_category_descriptions

    print('Writing output file')

    # Write results
    write_json(merged_results_file, detector_results)

    if verbose:
        print('Classification results written to {}'.format(merged_results_file))

# ...def _run_classification_step(...)
1127
+
1128
+
1129
+ #%% Command-line driver
1130
+
1131
def main():
    """
    Command-line driver for run_md_and_speciesnet.py

    Runs MegaDetector on the source folder (unless --detections_file is supplied),
    then runs SpeciesNet on the resulting detections and writes the merged results
    to the output file. Intermediate detector output written to an auto-created
    temp folder is removed at the end unless the user asked to keep it.
    """

    if 'speciesnet' not in sys.modules:
        print('It looks like the speciesnet package is not available, try "pip install speciesnet"')
        if not is_sphinx_build():
            sys.exit(-1)

    parser = argparse.ArgumentParser(
        description='Run MegaDetector and SpeciesNet on a folder of images/videos',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required arguments
    parser.add_argument('source',
                        help='Folder containing images and/or videos to process')
    parser.add_argument('output_file',
                        help='Output file for results (JSON format)')

    # Optional arguments
    parser.add_argument('--detector_model',
                        default=DEFAULT_DETECTOR_MODEL,
                        help='MegaDetector model identifier')
    parser.add_argument('--classification_model',
                        default=DEFAULT_CLASSIFIER_MODEL,
                        help='SpeciesNet classifier model identifier')
    parser.add_argument('--detector_batch_size',
                        type=int,
                        default=DEFAULT_DETECTOR_BATCH_SIZE,
                        help='Batch size for MegaDetector inference')
    parser.add_argument('--classifier_batch_size',
                        type=int,
                        default=DEFAULT_CLASSIFIER_BATCH_SIZE,
                        help='Batch size for SpeciesNet classification')
    parser.add_argument('--loader_workers',
                        type=int,
                        default=DEFAULT_LOADER_WORKERS,
                        help='Number of worker threads for preprocessing')
    parser.add_argument('--detection_confidence_threshold_for_classification',
                        type=float,
                        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION,
                        help='Classify detections above this threshold')
    parser.add_argument('--detection_confidence_threshold_for_output',
                        type=float,
                        default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT,
                        help='Include detections above this threshold in the output')
    parser.add_argument('--intermediate_file_folder',
                        default=None,
                        help='Folder for intermediate files (default: system temp)')
    parser.add_argument('--keep_intermediate_files',
                        action='store_true',
                        help='Keep intermediate files for debugging')
    parser.add_argument('--norollup',
                        action='store_true',
                        help='Disable taxonomic rollup')
    parser.add_argument('--country',
                        default=None,
                        help='Country code (ISO 3166-1 alpha-3) for geofencing')
    parser.add_argument('--admin1_region', '--state',
                        default=None,
                        help='Admin1 region/state code for geofencing')
    parser.add_argument('--detections_file',
                        default=None,
                        help='Path to existing MegaDetector output file (skips detection step)')
    parser.add_argument('--skip_video',
                        action='store_true',
                        help='Ignore videos, only process images')
    parser.add_argument('--skip_images',
                        action='store_true',
                        help='Ignore images, only process videos')
    parser.add_argument('--frame_sample',
                        type=int,
                        default=None,
                        help='Sample every Nth frame from videos (mutually exclusive with --time_sample)')
    parser.add_argument('--time_sample',
                        type=float,
                        default=None,
                        help=('Sample frames every N seconds from videos (default {})'.format(
                            DEAFULT_SECONDS_PER_VIDEO_FRAME) +
                            ' (mutually exclusive with --frame_sample)'))
    parser.add_argument('--verbose',
                        action='store_true',
                        help='Enable additional debug output')

    if len(sys.argv[1:]) == 0:
        parser.print_help()
        parser.exit()

    args = parser.parse_args()

    # Set global verbose flag
    global verbose
    verbose = args.verbose

    # Also set the run_detector_batch verbose flag
    run_detector_batch.verbose = verbose

    # Validate arguments
    if not os.path.isdir(args.source):
        raise ValueError(
            'Source folder does not exist: {}'.format(args.source))

    if args.admin1_region and not args.country:
        raise ValueError('--admin1_region requires --country to be specified')

    if args.skip_images and args.skip_video:
        raise ValueError('Cannot skip both images and videos')

    # The classification step needs at least one producer worker (with zero it would
    # block forever); fail now rather than after the (potentially long) detection step.
    if args.loader_workers < 1:
        raise ValueError('--loader_workers must be at least 1')

    if (args.frame_sample is not None) and (args.time_sample is not None):
        raise ValueError('--frame_sample and --time_sample are mutually exclusive')
    if (args.frame_sample is not None) and (args.frame_sample < 1):
        raise ValueError('--frame_sample must be at least 1')
    if (args.time_sample is not None) and (args.time_sample <= 0):
        raise ValueError('--time_sample must be positive')
    if (args.frame_sample is None) and (args.time_sample is None):
        args.time_sample = DEAFULT_SECONDS_PER_VIDEO_FRAME

    # Set up intermediate file folder
    if args.intermediate_file_folder:
        temp_folder = args.intermediate_file_folder
        os.makedirs(temp_folder, exist_ok=True)
    else:
        temp_folder = make_temp_folder(subfolder='run_md_and_speciesnet')

    start_time = time.time()

    print('Processing folder: {}'.format(args.source))
    print('Output file: {}'.format(args.output_file))
    print('Intermediate files: {}'.format(temp_folder))

    # Determine detector output file path
    if args.detections_file:
        detector_output_file = args.detections_file
        print('Using existing detections file: {}'.format(detector_output_file))
        validation_options = ValidateBatchResultsOptions()
        validation_options.check_image_existence = True
        validation_options.relative_path_base = args.source
        validation_options.raise_errors = True
        validate_batch_results(detector_output_file,options=validation_options)
        print('Validated detections file')
    else:
        detector_output_file = os.path.join(temp_folder, 'detector_output.json')

        # Run MegaDetector
        _run_detection_step(
            source_folder=args.source,
            detector_output_file=detector_output_file,
            detector_model=args.detector_model,
            detector_batch_size=args.detector_batch_size,
            detection_confidence_threshold=args.detection_confidence_threshold_for_output,
            detector_worker_threads=args.loader_workers,
            skip_images=args.skip_images,
            skip_video=args.skip_video,
            frame_sample=args.frame_sample,
            time_sample=args.time_sample
        )

    # Run SpeciesNet
    _run_classification_step(
        detector_results_file=detector_output_file,
        merged_results_file=args.output_file,
        source_folder=args.source,
        classifier_model=args.classification_model,
        classifier_batch_size=args.classifier_batch_size,
        classifier_worker_threads=args.loader_workers,
        detection_confidence_threshold=args.detection_confidence_threshold_for_classification,
        enable_rollup=(not args.norollup),
        country=args.country,
        admin1_region=args.admin1_region,
    )

    elapsed_time = time.time() - start_time
    print(
        'Processing complete in {}'.format(humanfriendly.format_timespan(elapsed_time)))
    print('Results written to: {}'.format(args.output_file))

    # Clean up intermediate files if requested
    if (not args.keep_intermediate_files) and \
       (not args.intermediate_file_folder) and \
       (not args.detections_file):

        # When both images and videos are processed, the detection step writes
        # per-media files (named by replacing '.json' with '_images.json' /
        # '_videos.json') and merges them, leaving those files behind; remove them
        # along with the merged detector output.
        intermediate_files = [
            detector_output_file,
            detector_output_file.replace('.json', '_images.json'),
            detector_output_file.replace('.json', '_videos.json'),
        ]
        for intermediate_file in intermediate_files:
            if not os.path.isfile(intermediate_file):
                continue
            try:
                os.remove(intermediate_file)
            except OSError as e:
                print('Warning: error removing temporary output file {}: {}'.format(
                    intermediate_file, str(e)))

# ...def main(...)


if __name__ == '__main__':
    main()