megadetector 10.0.2-py3-none-any.whl → 10.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (30)
  1. megadetector/data_management/animl_to_md.py +158 -0
  2. megadetector/data_management/zamba_to_md.py +188 -0
  3. megadetector/detection/process_video.py +165 -946
  4. megadetector/detection/pytorch_detector.py +575 -276
  5. megadetector/detection/run_detector_batch.py +629 -202
  6. megadetector/detection/run_md_and_speciesnet.py +1319 -0
  7. megadetector/detection/video_utils.py +243 -107
  8. megadetector/postprocessing/classification_postprocessing.py +12 -1
  9. megadetector/postprocessing/combine_batch_outputs.py +2 -0
  10. megadetector/postprocessing/compare_batch_results.py +21 -2
  11. megadetector/postprocessing/merge_detections.py +16 -12
  12. megadetector/postprocessing/separate_detections_into_folders.py +1 -1
  13. megadetector/postprocessing/subset_json_detector_output.py +1 -3
  14. megadetector/postprocessing/validate_batch_results.py +25 -2
  15. megadetector/tests/__init__.py +0 -0
  16. megadetector/tests/test_nms_synthetic.py +335 -0
  17. megadetector/utils/ct_utils.py +69 -5
  18. megadetector/utils/extract_frames_from_video.py +303 -0
  19. megadetector/utils/md_tests.py +583 -524
  20. megadetector/utils/path_utils.py +4 -15
  21. megadetector/utils/wi_utils.py +20 -4
  22. megadetector/visualization/visualization_utils.py +1 -1
  23. megadetector/visualization/visualize_db.py +8 -22
  24. megadetector/visualization/visualize_detector_output.py +7 -5
  25. megadetector/visualization/visualize_video_output.py +607 -0
  26. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/METADATA +134 -135
  27. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/RECORD +30 -23
  28. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/top_level.txt +0 -0
  30. {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/WHEEL +0 -0
megadetector/detection/pytorch_detector.py

@@ -14,7 +14,6 @@ import math
  import zipfile
  import tempfile
  import shutil
- import traceback
  import uuid
  import json
  import inspect
@@ -24,11 +23,13 @@ import torch
  import numpy as np

  from megadetector.detection.run_detector import \
- CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, \
+ CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, FAILURE_IMAGE_OPEN, \
  get_detector_version_from_model_file, \
  known_models
  from megadetector.utils.ct_utils import parse_bool_string
+ from megadetector.utils.ct_utils import is_running_in_gha
  from megadetector.utils import ct_utils
+ import torchvision

  # We support a few ways of accessing the YOLOv5 dependencies:
  #
@@ -175,7 +176,7 @@ def _initialize_yolo_imports_for_model(model_file,
  return model_type


- def _clean_yolo_imports(verbose=False,aggressive_cleanup=False):
+ def _clean_yolo_imports(verbose=False, aggressive_cleanup=False):
  """
  Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
  of another YOLO library version. The reason we jump through all these hoops, rather than
@@ -454,14 +455,14 @@ def _initialize_yolo_imports(model_type='yolov5',
  try:

  # import pre- and post-processing functions from the YOLOv5 repo
- from utils.general import non_max_suppression, xyxy2xywh # noqa
- from utils.augmentations import letterbox # noqa
+ from utils.general import non_max_suppression, xyxy2xywh # type: ignore
+ from utils.augmentations import letterbox # type: ignore

  # scale_coords() is scale_boxes() in some YOLOv5 versions
  try:
- from utils.general import scale_coords # noqa
+ from utils.general import scale_coords # type: ignore
  except ImportError:
- from utils.general import scale_boxes as scale_coords
+ from utils.general import scale_boxes as scale_coords # type: ignore
  utils_imported = True
  imported_file = sys.modules[scale_coords.__module__].__file__
  if verbose:
@@ -482,6 +483,121 @@ def _initialize_yolo_imports(model_type='yolov5',
  # ...def _initialize_yolo_imports(...)


+ #%% NMS
+
+ def nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+ """
+ Non-maximum suppression (a wrapper around torchvision.ops.nms())
+
+ Args:
+ prediction (torch.Tensor): Model predictions with shape [batch_size, num_anchors, num_classes + 5]
+ Format: [x_center, y_center, width, height, objectness, class1_conf, class2_conf, ...]
+ Coordinates are normalized to input image size.
+ conf_thres (float): Confidence threshold for filtering detections
+ iou_thres (float): IoU threshold for NMS
+ max_det (int): Maximum number of detections per image
+
+ Returns:
+ list: List of tensors, one per image in batch. Each tensor has shape [N, 6] where:
+ - N is the number of detections for that image
+ - Columns are [x1, y1, x2, y2, confidence, class_id]
+ - Coordinates are in absolute pixels relative to input image size
+ - class_id is the integer class index (0-based)
+ """
+
+ batch_size = prediction.shape[0]
+ num_classes = prediction.shape[2] - 5 # noqa
+ output = []
+
+ # Process each image in the batch
+ for img_idx in range(batch_size):
+
+ x = prediction[img_idx] # Shape: [num_anchors, num_classes + 5]
+
+ # Filter by objectness confidence
+ obj_conf = x[:, 4]
+ valid_detections = obj_conf > conf_thres
+ x = x[valid_detections]
+
+ if x.shape[0] == 0:
+ # No detections for this image
+ output.append(torch.zeros((0, 6), device=prediction.device))
+ continue
+
+ # Convert box coordinates from [x_center, y_center, w, h] to [x1, y1, x2, y2]
+ box = x[:, :4].clone()
+ box[:, 0] = x[:, 0] - x[:, 2] / 2.0 # x1 = center_x - width/2
+ box[:, 1] = x[:, 1] - x[:, 3] / 2.0 # y1 = center_y - height/2
+ box[:, 2] = x[:, 0] + x[:, 2] / 2.0 # x2 = center_x + width/2
+ box[:, 3] = x[:, 1] + x[:, 3] / 2.0 # y2 = center_y + height/2
+
+ # Get class predictions: multiply objectness by class probabilities
+ class_conf = x[:, 5:] * x[:, 4:5] # shape: [N, num_classes]
+
+ # For each detection, take the class with highest confidence (single-label)
+ best_class_conf, best_class_idx = class_conf.max(1, keepdim=True)
+
+ # Filter by class confidence threshold
+ conf_mask = best_class_conf.view(-1) > conf_thres
+ if conf_mask.sum() == 0:
+ # No detections pass confidence threshold
+ output.append(torch.zeros((0, 6), device=prediction.device))
+ continue
+
+ box = box[conf_mask]
+ best_class_conf = best_class_conf[conf_mask]
+ best_class_idx = best_class_idx[conf_mask]
+
+ # Prepare for NMS: group detections by class
+ unique_classes = best_class_idx.unique()
+ final_detections = []
+
+ for class_id in unique_classes:
+
+ class_mask = (best_class_idx == class_id).view(-1)
+ class_boxes = box[class_mask]
+ class_scores = best_class_conf[class_mask].view(-1)
+
+ if class_boxes.shape[0] == 0:
+ continue
+
+ # Apply NMS for this class
+ keep_indices = torchvision.ops.nms(class_boxes, class_scores, iou_thres)
+
+ if len(keep_indices) > 0:
+ kept_boxes = class_boxes[keep_indices]
+ kept_scores = class_scores[keep_indices]
+ kept_classes = torch.full((len(keep_indices), 1), class_id.item(),
+ device=prediction.device, dtype=torch.float)
+
+ # Combine: [x1, y1, x2, y2, conf, class]
+ class_detections = torch.cat([kept_boxes, kept_scores.unsqueeze(1), kept_classes], 1)
+ final_detections.append(class_detections)
+
+ # ...for each category
+
+ if final_detections:
+
+ # Combine all classes and sort by confidence
+ all_detections = torch.cat(final_detections, 0)
+ conf_sort_indices = all_detections[:, 4].argsort(descending=True)
+ all_detections = all_detections[conf_sort_indices]
+
+ # Limit to max_det
+ if all_detections.shape[0] > max_det:
+ all_detections = all_detections[:max_det]
+
+ output.append(all_detections)
+ else:
+ output.append(torch.zeros((0, 6), device=prediction.device))
+
+ # ...for each image in the batch
+
+ return output
+
+ # ...def nms(...)
+
+
  #%% Model metadata functions

  def add_metadata_to_megadetector_model_file(model_file_in,
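
The nms() wrapper added in the hunk above can be exercised in isolation. The sketch below is illustrative only (synthetic tensors, arbitrary thresholds) and assumes megadetector 10.0.4, torch, and torchvision are installed; the input layout and the [x1, y1, x2, y2, confidence, class_id] output follow the docstring above.

    import torch
    from megadetector.detection.pytorch_detector import nms

    # Synthetic YOLO-style predictions: [batch, anchors, 5 + num_classes], where the last
    # dimension is [x_center, y_center, width, height, objectness, class confidences...]
    batch_size, num_anchors, num_classes = 2, 100, 3
    prediction = torch.rand(batch_size, num_anchors, 5 + num_classes)
    prediction[..., :4] *= 640  # put the box coordinates roughly in pixel units

    detections_per_image = nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300)

    assert len(detections_per_image) == batch_size
    for det in detections_per_image:
        # Each row is [x1, y1, x2, y2, confidence, class_id], sorted by descending confidence
        print(det.shape)
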
@@ -593,6 +709,8 @@ def read_metadata_from_megadetector_model_file(model_file,

  return d

+ # ...with zipfile.Zipfile(...)
+
  # ...def read_metadata_from_megadetector_model_file(...)

@@ -606,10 +724,15 @@ require_non_default_compatibility_mode = False

  class PTDetector:
  """
- Class that runs a PyTorch-based MegaDetector model.
+ Class that runs a PyTorch-based MegaDetector model. Also used as a preprocessor
+ for images that will later be run through an instance of PTDetector.
  """

  def __init__(self, model_path, detector_options=None, verbose=False):
+ """
+ PTDetector constructor. If detector_options['preprocess_only'] exists and is
+ True, this instance is being used as a preprocessor, so we don't load model weights.
+ """

  if verbose:
  print('Initializing PTDetector (verbose)')
@@ -637,6 +760,8 @@ class PTDetector:
  else:
  compatibility_mode = detector_options['compatibility_mode']

+ # This is a global option used only during testing, to make sure I'm hitting
+ # the cases where we are not using "classic" preprocessing.
  if require_non_default_compatibility_mode:

  print('### DEBUG: requiring non-default compatibility mode ###')
@@ -652,14 +777,16 @@ class PTDetector:
  if verbose or (not preprocess_only):
  print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))

- model_metadata = read_metadata_from_megadetector_model_file(model_path)
+ self.model_metadata = read_metadata_from_megadetector_model_file(model_path)

- #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
- #: aspect ratio".
- if model_metadata is not None and 'image_size' in model_metadata:
- self.default_image_size = model_metadata['image_size']
+ #: Image size passed to the letterbox() function; 1280 means "1280 on the long side,
+ #: preserving aspect ratio".
+ if self.model_metadata is not None and 'image_size' in self.model_metadata:
+ self.default_image_size = self.model_metadata['image_size']
  print('Loaded image size {} from model metadata'.format(self.default_image_size))
  else:
+ # This is not the default for most YOLO models, but most of the time, if someone
+ # is loading a model here that does not have metadata, it's MDv5[ab].0.0
  print('No image size available in model metadata, defaulting to 1280')
  self.default_image_size = 1280

@@ -682,12 +809,27 @@ class PTDetector:
  #: "classic".
  self.compatibility_mode = compatibility_mode

- #: Stride size passed to YOLOv5's letterbox() function
+ #: Stride size passed to the YOLO letterbox() function
  self.letterbox_stride = 32

- if 'classic' in self.compatibility_mode:
+ # This is a convenient heuristic to determine the stride size without actually loading
+ # the model: the only models in the YOLO family with a stride size of 64 are the
+ # YOLOv5*6 and YOLOv5*6u models, which are 1280px models.
+ #
+ # See:
+ #
+ # github.com/ultralytics/ultralytics/issues/21544
+ #
+ # Note to self, though, if I decide later to require loading the model on preprocessing
+ # workers so I can more reliably choose a stride, this is the right way to determine the
+ # stride:
+ #
+ # self.letterbox_stride = int(self.model.stride.max())
+ if self.default_image_size == 1280:
  self.letterbox_stride = 64

+ print('Using model stride: {}'.format(self.letterbox_stride))
+
  #: Use half-precision inference... fixed by the model, generally don't mess with this
  self.half_precision = False

@@ -699,9 +841,16 @@ class PTDetector:
  self.device = torch.device('cuda:0')
  try:
  if torch.backends.mps.is_built and torch.backends.mps.is_available():
- self.device = 'mps'
+ # MPS inference fails on GitHub runners as of 2025.08. This is
+ # independent of model size. So, we disable MPS when running in GHA.
+ if is_running_in_gha():
+ print('GitHub actions detected, bypassing MPS backend')
+ else:
+ print('Using MPS device')
+ self.device = 'mps'
  except AttributeError:
  pass
+
  try:
  self.model = PTDetector._load_model(model_path,
  device=self.device,
@@ -714,7 +863,7 @@ class PTDetector:
  # to be the same, so doing that externally doesn't seem *that* rude.
  if "Can't get attribute 'DetectionModel'" in str(e):
  print('Forward-compatibility issue detected, patching')
- from models import yolo
+ from models import yolo # type: ignore
  yolo.DetectionModel = yolo.Model
  self.model = PTDetector._load_model(model_path,
  device=self.device,
@@ -734,16 +883,10 @@ class PTDetector:
  if verbose:
  print(f'Using PyTorch version {torch.__version__}')

- # There are two very slightly different ways to load the model, (1) using the
- # map_location=device parameter to torch.load and (2) calling .to(device) after
- # loading the model. The former is what we did for a zillion years, but is not
- # supported on Apple silicon at of 2029.09. Switching to the latter causes
- # very slight changes to the output, which always make me nervous, so I'm not
- # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
- if 'classic' in compatibility_mode:
- use_map_location = (device != 'mps')
- else:
- use_map_location = False
+ # I get quirky errors when loading YOLOv5 models on MPS hardware using
+ # map_location, but this is the recommended method, so I'm using it everywhere
+ # other than MPS devices.
+ use_map_location = (device != 'mps')

  if use_map_location:
  try:
@@ -773,297 +916,429 @@ class PTDetector:
  if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
  m.recompute_scale_factor = None

- if use_map_location:
- model = checkpoint['model'].float().fuse().eval()
- else:
- model = checkpoint['model'].float().fuse().eval().to(device)
+ # Calling .to(device) should no longer be necessary now that we're using map_location=device
+ # model = checkpoint['model'].float().fuse().eval().to(device)
+ model = checkpoint['model'].float().fuse().eval()

  return model

  # ...def _load_model(...)


- def generate_detections_one_image(self,
- img_original,
- image_id='unknown',
- detection_threshold=0.00001,
- image_size=None,
- skip_image_resizing=False,
- augment=False,
- preprocess_only=False,
- verbose=False):
+ def preprocess_image(self,
+ img_original,
+ image_id='unknown',
+ image_size=None,
+ verbose=False):
  """
- Applies the detector to an image.
+ Prepare an image for detection, including scaling and letterboxing.

  Args:
- img_original (Image): the PIL Image object (or numpy array) on which we should run the
- detector, with EXIF rotation already handled
+ img_original (Image or np.array): the image on which we should run the detector, with
+ EXIF rotation already handled
  image_id (str, optional): a path to identify the image; will be in the "file" field
  of the output object
  detection_threshold (float, optional): only detections above this confidence threshold
  will be included in the return value
- image_size (int, optional): image size to use for inference, only mess with this if
- (a) you're using a model other than MegaDetector or (b) you know what you're getting into
- skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
- external resizing), only mess with this if (a) you're using a model other than MegaDetector
- or (b) you know what you're getting into
- augment (bool, optional): enable (implementation-specific) image augmentation
- preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
+ image_size (int, optional): image size (long side) to use for inference, or None to
+ use the default size specified at the time the model was loaded
  verbose (bool, optional): enable additional debug output

  Returns:
- dict: a dictionary with the following fields:
- - 'file' (filename, always present)
- - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
- - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
- - 'failure' (a failure string, or None if everything went fine)
+ dict: dict with fields:
+ - file (filename)
+ - img (the preprocessed np.array)
+ - img_original (the input image before preprocessing, as an np.array)
+ - img_original_pil (the input image before preprocessing, as a PIL Image)
+ - target_shape (the 2D shape to which the image was resized during preprocessing)
+ - scaling_shape (the 2D original size, for normalizing coordinates later)
+ - letterbox_ratio (letterbox parameter used for normalizing coordinates later)
+ - letterbox_pad (letterbox parameter used for normalizing coordinates later)
  """

+ # Prepare return dict
  result = {'file': image_id }
- detections = []
- max_conf = 0.0

- if preprocess_only:
- assert 'classic' in self.compatibility_mode, \
- 'Standalone preprocessing only supported in "classic" mode'
- assert not skip_image_resizing, \
- 'skip_image_resizing and preprocess_only are exclusive'
+ # Store the PIL version of the original image, the caller may want to use
+ # it for metadata extraction later.
+ img_original_pil = None

- if detection_threshold is None:
+ # If we were given a PIL image, rather than a numpy array
+ if not isinstance(img_original,np.ndarray):
+ img_original_pil = img_original
+ img_original = np.asarray(img_original)

- detection_threshold = 0
+ # PIL images are RGB already
+ # img_original = img_original[:, :, ::-1]

- try:
+ # Save the original shape for scaling boxes later
+ scaling_shape = img_original.shape

- # If the caller wants us to skip all the resizing operations...
- if skip_image_resizing:
-
- if isinstance(img_original,dict):
- image_info = img_original
- img = image_info['img_processed']
- scaling_shape = image_info['scaling_shape']
- letterbox_pad = image_info['letterbox_pad']
- letterbox_ratio = image_info['letterbox_ratio']
- img_original = image_info['img_original']
- img_original_pil = image_info['img_original_pil']
- else:
- img = img_original
+ # If the caller is requesting a specific target size...
+ if image_size is not None:

- else:
+ assert isinstance(image_size,int)

- img_original_pil = None
- # If we were given a PIL image
+ if not self.printed_image_size_warning:
+ print('Using user-supplied image size {}'.format(image_size))
+ self.printed_image_size_warning = True

- if not isinstance(img_original,np.ndarray):
- img_original_pil = img_original
- img_original = np.asarray(img_original)
+ # Otherwise resize to self.default_image_size
+ else:
+
+ image_size = self.default_image_size
+ self.printed_image_size_warning = False
+
+ # ...if the caller has specified an image size

- # PIL images are RGB already
- # img_original = img_original[:, :, ::-1]
+ # In "classic mode", we only do the letterboxing resize, we don't do an
+ # additional initial resizing operation
+ if 'classic' in self.compatibility_mode:
+
+ resize_ratio = 1.0

- # Save the original shape for scaling boxes later
- scaling_shape = img_original.shape
+ # Resize the image so the long side matches the target image size. This is not
+ # letterboxing (i.e., padding) yet, just resizing.
+ else:

- # If the caller is requesting a specific target size...
- if image_size is not None:
+ use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)

- assert isinstance(image_size,int)
+ h,w = img_original.shape[:2]
+ resize_ratio = image_size / max(h,w)

- if not self.printed_image_size_warning:
- print('Using user-supplied image size {}'.format(image_size))
- self.printed_image_size_warning = True
+ # Only resize if we have to
+ if resize_ratio != 1:

- # Otherwise resize to self.default_image_size
+ # Match what yolov5 does: use linear interpolation for upsizing;
+ # area interpolation for downsizing
+ if resize_ratio > 1:
+ interpolation_method = cv2.INTER_LINEAR
  else:
+ interpolation_method = cv2.INTER_AREA

- image_size = self.default_image_size
- self.printed_image_size_warning = False
+ if use_ceil_for_resize:
+ target_w = math.ceil(w * resize_ratio)
+ target_h = math.ceil(h * resize_ratio)
+ else:
+ target_w = int(w * resize_ratio)
+ target_h = int(h * resize_ratio)

- # ...if the caller has specified an image size
+ img_original = cv2.resize(
+ img_original, (target_w, target_h),
+ interpolation=interpolation_method)

- # In "classic mode", we only do the letterboxing resize, we don't do an
- # additional initial resizing operation
- if 'classic' in self.compatibility_mode:
+ if 'classic' in self.compatibility_mode:

- resize_ratio = 1.0
+ letterbox_auto = True
+ letterbox_scaleup = True
+ target_shape = image_size

- # Resize the image so the long side matches the target image size. This is not
- # letterboxing (i.e., padding) yet, just resizing.
- else:
+ else:

- use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
+ letterbox_auto = False
+ letterbox_scaleup = False

- h,w = img_original.shape[:2]
- resize_ratio = image_size / max(h,w)
+ # The padding to apply as a fraction of the stride size
+ pad = 0.5

- # Only resize if we have to
- if resize_ratio != 1:
+ # Resize to a multiple of the model stride
+ #
+ # This is how we would determine the stride if we knew the model had been loaded:
+ #
+ # model_stride = int(self.model.stride.max())
+ #
+ # ...but because we do this on preprocessing workers now, we try to avoid loading the model
+ # just for preprocessing, and we assume the stride was determined at the time the PTDetector
+ # object was created.
+ try:
+ model_stride = int(self.model.stride.max())
+ if model_stride != self.letterbox_stride:
+ print('*** Warning: model stride is {}, stride at construction time was {} ***'.format(
+ model_stride,self.letterbox_stride
+ ))
+ except Exception:
+ pass

- # Match what yolov5 does: use linear interpolation for upsizing;
- # area interpolation for downsizing
- if resize_ratio > 1:
- interpolation_method = cv2.INTER_LINEAR
- else:
- interpolation_method = cv2.INTER_AREA
+ model_stride = self.letterbox_stride
+ max_dimension = max(img_original.shape)
+ normalized_shape = [img_original.shape[0] / max_dimension,
+ img_original.shape[1] / max_dimension]
+ target_shape = np.ceil(((np.array(normalized_shape) * image_size) / model_stride) + \
+ pad).astype(int) * model_stride
+
+ # Now we letterbox, which is just padding, since we've already resized
+ img,letterbox_ratio,letterbox_pad = letterbox(img_original,
+ new_shape=target_shape,
+ stride=self.letterbox_stride,
+ auto=letterbox_auto,
+ scaleFill=False,
+ scaleup=letterbox_scaleup)
+
+ result['img_processed'] = img
+ result['img_original'] = img_original
+ result['img_original_pil'] = img_original_pil
+ result['target_shape'] = target_shape
+ result['scaling_shape'] = scaling_shape
+ result['letterbox_ratio'] = letterbox_ratio
+ result['letterbox_pad'] = letterbox_pad
+ return result

- if use_ceil_for_resize:
- target_w = math.ceil(w * resize_ratio)
- target_h = math.ceil(h * resize_ratio)
- else:
- target_w = int(w * resize_ratio)
- target_h = int(h * resize_ratio)
+ # ...def preprocess_image(...)

- img_original = cv2.resize(
- img_original, (target_w, target_h),
- interpolation=interpolation_method)

- if 'classic' in self.compatibility_mode:
+ def generate_detections_one_batch(self,
+ img_original,
+ image_id=None,
+ detection_threshold=0.00001,
+ image_size=None,
+ augment=False,
+ verbose=False):
+ """
+ Run a detector on a batch of images.
+
+ Args:
+ img_original (list): list of images (Image, np.array, or dict) on which we should run the detector, with
+ EXIF rotation already handled, or dicts representing preprocessed images with associated
+ letterbox parameters
+ image_id (list or None): list of paths to identify the images; will be in the "file" field
+ of the output objects. Will be ignored when img_original contains preprocessed dicts.
+ detection_threshold (float, optional): only detections above this confidence threshold
+ will be included in the return value
+ image_size (int, optional): image size (long side) to use for inference, or None to
+ use the default size specified at the time the model was loaded
+ augment (bool, optional): enable (implementation-specific) image augmentation
+ verbose (bool, optional): enable additional debug output
+
+ Returns:
+ list: a list of dictionaries, each with the following fields:
+ - 'file' (filename, always present)
+ - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+ - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+ - 'failure' (a failure string, or None if everything went fine)
+ """
+
+ # Validate inputs
+ if not isinstance(img_original, list):
+ raise ValueError('img_original must be a list for batch processing')
+
+ if len(img_original) == 0:
+ return []
+
+ # Check input consistency
+ if isinstance(img_original[0], dict):
+ # All items in img_original should be preprocessed dicts
+ for i, img in enumerate(img_original):
+ if not isinstance(img, dict):
+ raise ValueError(f'Mixed input types in batch: item {i} is not a dict, but item 0 is a dict')
+ else:
+ # All items in img_original should be PIL/numpy images, and image_id should be a list of strings
+ if image_id is None:
+ raise ValueError('image_id must be a list when img_original contains PIL/numpy images')
+ if not isinstance(image_id, list):
+ raise ValueError('image_id must be a list for batch processing')
+ if len(image_id) != len(img_original):
+ raise ValueError(
+ 'Length mismatch: img_original has {} items, image_id has {} items'.format(
+ len(img_original),len(image_id)))
+ for i_img, img in enumerate(img_original):
+ if isinstance(img, dict):
+ raise ValueError(
+ 'Mixed input types in batch: item {} is a dict, but item 0 is not a dict'.format(
+ i_img))
+
+ if detection_threshold is None:
+ detection_threshold = 0.0

- letterbox_auto = True
- letterbox_scaleup = True
- target_shape = image_size
+ batch_size = len(img_original)
+ results = [None] * batch_size

+ # Preprocess all images, handling failures
+ preprocessed_images = []
+ preprocessing_failed_indices = set()
+
+ for i_img, img in enumerate(img_original):
+
+ try:
+ if isinstance(img, dict):
+ # Already preprocessed
+ image_info = img
+ current_image_id = image_info['file']
  else:
+ # Need to preprocess
+ current_image_id = image_id[i_img]
+ image_info = self.preprocess_image(
+ img_original=img,
+ image_id=current_image_id,
+ image_size=image_size,
+ verbose=verbose)

- letterbox_auto = False
- letterbox_scaleup = False
+ preprocessed_images.append((i_img, image_info, current_image_id))

- # The padding to apply as a fraction of the stride size
- pad = 0.5
+ except Exception as e:
+ print('Warning: preprocessing failed for image {}: {}'.format(
+ image_id[i_img] if image_id else f'index_{i_img}', str(e)))
+
+ preprocessing_failed_indices.add(i_img)
+ current_image_id = image_id[i_img] if image_id else f'index_{i_img}'
+ results[i_img] = {
+ 'file': current_image_id,
+ 'detections': None,
+ 'failure': FAILURE_IMAGE_OPEN
+ }
+
+ # ...for each image in this batch
+
+ # Group preprocessed images by actual processed image shape for batching
+ shape_groups = {}
+ for original_idx, image_info, current_image_id in preprocessed_images:
+ # Use the actual processed image shape for grouping, not target_shape
+ actual_shape = tuple(image_info['img_processed'].shape)
+ if actual_shape not in shape_groups:
+ shape_groups[actual_shape] = []
+ shape_groups[actual_shape].append((original_idx, image_info, current_image_id))
+
+ # Process each shape group as a batch
+ for target_shape, group_items in shape_groups.items():

- model_stride = int(self.model.stride.max())
+ try:
+ self._process_batch_group(group_items, results, detection_threshold, augment, verbose)
+ except Exception as e:
+ # If inference fails for the entire batch, mark all images in this batch as failed
+ print('Warning: batch inference failed for shape {}: {}'.format(target_shape, str(e)))

- max_dimension = max(img_original.shape)
- normalized_shape = [img_original.shape[0] / max_dimension,
- img_original.shape[1] / max_dimension]
- target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
- pad).astype(int) * model_stride
+ for original_idx, image_info, current_image_id in group_items:
+ results[original_idx] = {
+ 'file': current_image_id,
+ 'detections': None,
+ 'failure': FAILURE_INFER
+ }

- # Now we letterbox, which is just padding, since we've already resized.
- img,letterbox_ratio,letterbox_pad = letterbox(img_original,
- new_shape=target_shape,
- stride=self.letterbox_stride,
- auto=letterbox_auto,
- scaleFill=False,
- scaleup=letterbox_scaleup)
+ # ...for each shape group
+ return results

- if preprocess_only:
+ # ...def generate_detections_one_batch(...)

- assert 'file' in result
- result['img_processed'] = img
- result['img_original'] = img_original
- result['img_original_pil'] = img_original_pil
- result['target_shape'] = target_shape
- result['scaling_shape'] = scaling_shape
- result['letterbox_ratio'] = letterbox_ratio
- result['letterbox_pad'] = letterbox_pad
- return result

- # ...are we doing resizing here, or were images already resized?
+ def _process_batch_group(self, group_items, results, detection_threshold, augment, verbose):
+ """
+ Process a group of images with the same target shape as a single batch.

- # Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
- # so we don't need to mess with the color channels.
- #
- # TODO, this could be moved into the preprocessing loop
-
- img = img.transpose((2, 0, 1)) # [::-1]
- img = np.ascontiguousarray(img)
- img = torch.from_numpy(img)
- img = img.to(self.device)
- img = img.half() if self.half_precision else img.float()
- img /= 255
-
- # In practice this is always true
- if len(img.shape) == 3:
- img = torch.unsqueeze(img, 0)
-
- # Run the model
- pred = self.model(img,augment=augment)[0]
-
- if 'classic' in self.compatibility_mode:
- nms_conf_thres = detection_threshold
- nms_iou_thres = 0.45
- nms_agnostic = False
- nms_multi_label = False
- else:
- nms_conf_thres = detection_threshold # 0.01
- nms_iou_thres = 0.6
- nms_agnostic = False
- nms_multi_label = True
+ Args:
+ group_items (list): List of (original_idx, image_info, current_image_id) tuples
+ results (list): Results list to populate (modified in place)
+ detection_threshold (float): Detection confidence threshold
+ augment (bool): Enable augmentation
+ verbose (bool): Enable verbose output

- # As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
- #
- # Send predictions back to the CPU for NMS.
- if self.device == 'mps':
- pred_nms = pred.cpu()
- else:
- pred_nms = pred
+ Returns:
+ list of dict: list of dictionaries the same length as group_items, with fields 'file',
+ 'detections', 'max_detection_conf'.
+ """

- # NMS
- pred = non_max_suppression(prediction=pred_nms,
- conf_thres=nms_conf_thres,
- iou_thres=nms_iou_thres,
- agnostic=nms_agnostic,
- multi_label=nms_multi_label)
+ if len(group_items) == 0:
+ return

- # In practice this is [w,h,w,h] of the original image
- gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
+ # Extract batch data
+ batch_images = []
+ batch_metadata = []

- if 'classic' in self.compatibility_mode:
+ # For each image in this batch...
+ for original_idx, image_info, current_image_id in group_items:

- ratio = None
- ratio_pad = None
+ img = image_info['img_processed']

- else:
+ # Convert HWC to CHW and prepare tensor
+ img_tensor = img.transpose((2, 0, 1))
+ img_tensor = np.ascontiguousarray(img_tensor)
+ img_tensor = torch.from_numpy(img_tensor)
+ batch_images.append(img_tensor)

- # letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
- #
- # ratio is a 2-tuple specifying the scaling that was applied to each dimension.
- #
- # The scale_boxes function expects a 2-tuple with these things combined.
- ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
- ratio_pad = (ratio, letterbox_pad)
+ metadata = {
+ 'original_idx': original_idx,
+ 'current_image_id': current_image_id,
+ 'scaling_shape': image_info['scaling_shape'],
+ 'letterbox_pad': image_info['letterbox_pad'],
+ 'img_original': image_info['img_original']
+ }
+ batch_metadata.append(metadata)

- # This is a loop over detection batches, which will always be length 1 in our case,
- # since we're not doing batch inference.
- #
- # det = pred[0]
- #
- # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
- # in descending order by confidence.
- #
- # Columns are:
- #
- # x0,y0,x1,y1,confidence,class
- #
- # At this point, these are *non*-normalized values, referring to the size at which we
- # ran inference (img.shape).
- for det in pred:
+ # ...for each image in this batch

- if len(det) == 0:
- continue
+ # Stack images into a batch tensor
+ batch_tensor = torch.stack(batch_images)

- # Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
- if 'classic' in self.compatibility_mode:
+ batch_tensor = batch_tensor.float()
+ batch_tensor /= 255.0
+
+ batch_tensor = batch_tensor.to(self.device)
+ if self.half_precision:
+ batch_tensor = batch_tensor.half()
+
+ # Run the model on the batch
+ pred = self.model(batch_tensor, augment=augment)[0]
+
+ # Configure NMS parameters
+ if 'classic' in self.compatibility_mode:
+ nms_iou_thres = 0.45
+ else:
+ nms_iou_thres = 0.6
+
+ pred = nms(prediction=pred,
+ conf_thres=detection_threshold,
+ iou_thres=nms_iou_thres)
+
+ # For posterity, the ultralytics implementation
+ if False:
+ pred = non_max_suppression(prediction=pred,
+ conf_thres=detection_threshold,
+ iou_thres=nms_iou_thres,
+ agnostic=False,
+ multi_label=False)

- det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
+ assert isinstance(pred, list)
+ assert len(pred) == len(batch_metadata), \
+ print('Mismatch between prediction length {} and batch size {}'.format(
+ len(pred),len(batch_metadata)))

+ # Process each image's detections
+ for i_image, det in enumerate(pred):
+
+ metadata = batch_metadata[i_image]
+ original_idx = metadata['original_idx']
+ current_image_id = metadata['current_image_id']
+ scaling_shape = metadata['scaling_shape']
+ letterbox_pad = metadata['letterbox_pad']
+ img_original = metadata['img_original']
+
+ detections = []
+ max_conf = 0.0
+
+ if len(det) > 0:
+
+ # Prepare scaling parameters
+ gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
+
+ if 'classic' in self.compatibility_mode:
+ ratio = None
+ ratio_pad = None
  else:
- # After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
- # original pixel dimension of the image, followed by the class and confidence
- det[:, :4] = scale_coords(img.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()
+ ratio = (img_original.shape[0]/scaling_shape[0],
+ img_original.shape[1]/scaling_shape[1])
+ ratio_pad = (ratio, letterbox_pad)

- # Loop over detections
- for *xyxy, conf, cls in reversed(det):
+ # Rescale boxes
+ if 'classic' in self.compatibility_mode:
+ det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], img_original.shape).round()
+ else:
+ det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()

+ # Process each detection
+ for *xyxy, conf, cls in reversed(det):
  if conf < detection_threshold:
  continue

- # Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
+ # Convert to YOLO format then to MD format
  xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
-
- # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
- # left/top/w/h (i.e., MD format)
  api_box = ct_utils.convert_yolo_to_xywh(xywh)

  if 'classic' in self.compatibility_mode:
@@ -1074,8 +1349,6 @@ class PTDetector:
  conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)

  if not self.use_model_native_classes:
- # The MegaDetector output format's categories start at 1, but all YOLO-based
- # MD models have category numbers starting at 0.
  cls = int(cls.tolist()) + 1
  if cls not in (1, 2, 3):
  raise KeyError(f'{cls} is not a valid class.')
@@ -1089,47 +1362,73 @@ class PTDetector:
  })
  max_conf = max(max_conf, conf)

- # ...for each detection in this batch
-
- # ...for each detection batch (always one iteration)
+ # ...for each detection

- # ...try
+ # ...if there are > 0 detections

- except Exception as e:
-
- result['failure'] = FAILURE_INFER
- print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
- # traceback.print_exc(e)
- print(traceback.format_exc())
-
- result['max_detection_conf'] = max_conf
- result['detections'] = detections
-
- return result
-
- # ...def generate_detections_one_image(...)
-
- # ...class PTDetector
+ # Store result for this image
+ results[original_idx] = {
+ 'file': current_image_id,
+ 'detections': detections,
+ 'max_detection_conf': max_conf
+ }

+ # ...for each image

- #%% Command-line driver
+ # ...def _process_batch_group(...)

- # For testing only... you don't really want to run this module directly.
-
- if __name__ == '__main__':
+ def generate_detections_one_image(self,
+ img_original,
+ image_id='unknown',
+ detection_threshold=0.00001,
+ image_size=None,
+ augment=False,
+ verbose=False):
+ """
+ Run a detector on an image (wrapper around batch function).

- pass
+ Args:
+ img_original (Image, np.array, or dict): the image on which we should run the detector, with
+ EXIF rotation already handled, or a dict representing a preprocessed image with associated
+ letterbox parameters
+ image_id (str, optional): a path to identify the image; will be in the "file" field
+ of the output object
+ detection_threshold (float, optional): only detections above this confidence threshold
+ will be included in the return value
+ image_size (int, optional): image size (long side) to use for inference, or None to
+ use the default size specified at the time the model was loaded
+ augment (bool, optional): enable (implementation-specific) image augmentation
+ verbose (bool, optional): enable additional debug output

- #%%
+ Returns:
+ dict: a dictionary with the following fields:
+ - 'file' (filename, always present)
+ - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+ - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+ - 'failure' (a failure string, or None if everything went fine)
+ """

- import os #noqa
- from megadetector.visualization import visualization_utils as vis_utils
+ # Prepare batch inputs
+ if isinstance(img_original, dict):
+ batch_results = self.generate_detections_one_batch(
+ img_original=[img_original],
+ image_id=None,
+ detection_threshold=detection_threshold,
+ image_size=image_size,
+ augment=augment,
+ verbose=verbose)
+ else:
+ batch_results = self.generate_detections_one_batch(
+ img_original=[img_original],
+ image_id=[image_id],
+ detection_threshold=detection_threshold,
+ image_size=image_size,
+ augment=augment,
+ verbose=verbose)

- model_file = os.environ['MDV5A']
- im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')
+ # Return the single result
+ return batch_results[0]

- detector = PTDetector(model_file)
- image = vis_utils.load_image(im_file)
+ # ...def generate_detections_one_image(...)

- res = detector.generate_detections_one_image(image, im_file, detection_threshold=0.00001)
- print(res)
+ # ...class PTDetector
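
Taken together, these hunks replace the old single-image inference path in PTDetector with a preprocess-then-batch pipeline. A minimal usage sketch, modeled loosely on the __main__ driver removed above; the MDV5A environment variable and the image paths are placeholders, and the call signatures are as shown in the diff:

    import os
    from megadetector.detection.pytorch_detector import PTDetector
    from megadetector.visualization import visualization_utils as vis_utils

    model_file = os.environ['MDV5A']                              # path to a MegaDetector .pt file (placeholder)
    image_files = ['camera01/img001.jpg', 'camera01/img002.jpg']  # placeholder image paths

    detector = PTDetector(model_file)
    images = [vis_utils.load_image(fn) for fn in image_files]

    # New batch entry point: one result dict per image, with 'file', 'detections', and
    # 'max_detection_conf' on success, or a 'failure' string if preprocessing/inference failed
    results = detector.generate_detections_one_batch(img_original=images,
                                                      image_id=image_files,
                                                      detection_threshold=0.00001)

    # The single-image API is now a thin wrapper around the batch path
    single_result = detector.generate_detections_one_image(images[0], image_files[0],
                                                            detection_threshold=0.00001)
    print(single_result)
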