megadetector-10.0.2-py3-none-any.whl → megadetector-10.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -14,7 +14,6 @@ import math
  import zipfile
  import tempfile
  import shutil
- import traceback
  import uuid
  import json
  import inspect
@@ -24,11 +23,12 @@ import torch
  import numpy as np

  from megadetector.detection.run_detector import \
- CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, \
+ CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, FAILURE_IMAGE_OPEN, \
  get_detector_version_from_model_file, \
  known_models
  from megadetector.utils.ct_utils import parse_bool_string
  from megadetector.utils import ct_utils
+ import torchvision

  # We support a few ways of accessing the YOLOv5 dependencies:
  #
@@ -175,7 +175,7 @@ def _initialize_yolo_imports_for_model(model_file,
  return model_type


- def _clean_yolo_imports(verbose=False,aggressive_cleanup=False):
+ def _clean_yolo_imports(verbose=False, aggressive_cleanup=False):
  """
  Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
  of another YOLO library version. The reason we jump through all these hoops, rather than
@@ -454,14 +454,14 @@ def _initialize_yolo_imports(model_type='yolov5',
  try:

  # import pre- and post-processing functions from the YOLOv5 repo
- from utils.general import non_max_suppression, xyxy2xywh # noqa
- from utils.augmentations import letterbox # noqa
+ from utils.general import non_max_suppression, xyxy2xywh # type: ignore
+ from utils.augmentations import letterbox # type: ignore

  # scale_coords() is scale_boxes() in some YOLOv5 versions
  try:
- from utils.general import scale_coords # noqa
+ from utils.general import scale_coords # type: ignore
  except ImportError:
- from utils.general import scale_boxes as scale_coords
+ from utils.general import scale_boxes as scale_coords # type: ignore
  utils_imported = True
  imported_file = sys.modules[scale_coords.__module__].__file__
  if verbose:
@@ -482,6 +482,121 @@ def _initialize_yolo_imports(model_type='yolov5',
  # ...def _initialize_yolo_imports(...)


+ #%% NMS
+
+ def nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+ """
+ Non-maximum suppression (a wrapper around torchvision.ops.nms())
+
+ Args:
+ prediction (torch.Tensor): Model predictions with shape [batch_size, num_anchors, num_classes + 5]
+ Format: [x_center, y_center, width, height, objectness, class1_conf, class2_conf, ...]
+ Coordinates are normalized to input image size.
+ conf_thres (float): Confidence threshold for filtering detections
+ iou_thres (float): IoU threshold for NMS
+ max_det (int): Maximum number of detections per image
+
+ Returns:
+ list: List of tensors, one per image in batch. Each tensor has shape [N, 6] where:
+ - N is the number of detections for that image
+ - Columns are [x1, y1, x2, y2, confidence, class_id]
+ - Coordinates are in absolute pixels relative to input image size
+ - class_id is the integer class index (0-based)
+ """
+
+ batch_size = prediction.shape[0]
+ num_classes = prediction.shape[2] - 5 # noqa
+ output = []
+
+ # Process each image in the batch
+ for img_idx in range(batch_size):
+
+ x = prediction[img_idx] # Shape: [num_anchors, num_classes + 5]
+
+ # Filter by objectness confidence
+ obj_conf = x[:, 4]
+ valid_detections = obj_conf > conf_thres
+ x = x[valid_detections]
+
+ if x.shape[0] == 0:
+ # No detections for this image
+ output.append(torch.zeros((0, 6), device=prediction.device))
+ continue
+
+ # Convert box coordinates from [x_center, y_center, w, h] to [x1, y1, x2, y2]
+ box = x[:, :4].clone()
+ box[:, 0] = x[:, 0] - x[:, 2] / 2.0 # x1 = center_x - width/2
+ box[:, 1] = x[:, 1] - x[:, 3] / 2.0 # y1 = center_y - height/2
+ box[:, 2] = x[:, 0] + x[:, 2] / 2.0 # x2 = center_x + width/2
+ box[:, 3] = x[:, 1] + x[:, 3] / 2.0 # y2 = center_y + height/2
+
+ # Get class predictions: multiply objectness by class probabilities
+ class_conf = x[:, 5:] * x[:, 4:5] # shape: [N, num_classes]
+
+ # For each detection, take the class with highest confidence (single-label)
+ best_class_conf, best_class_idx = class_conf.max(1, keepdim=True)
+
+ # Filter by class confidence threshold
+ conf_mask = best_class_conf.view(-1) > conf_thres
+ if conf_mask.sum() == 0:
+ # No detections pass confidence threshold
+ output.append(torch.zeros((0, 6), device=prediction.device))
+ continue
+
+ box = box[conf_mask]
+ best_class_conf = best_class_conf[conf_mask]
+ best_class_idx = best_class_idx[conf_mask]
+
+ # Prepare for NMS: group detections by class
+ unique_classes = best_class_idx.unique()
+ final_detections = []
+
+ for class_id in unique_classes:
+
+ class_mask = (best_class_idx == class_id).view(-1)
+ class_boxes = box[class_mask]
+ class_scores = best_class_conf[class_mask].view(-1)
+
+ if class_boxes.shape[0] == 0:
+ continue
+
+ # Apply NMS for this class
+ keep_indices = torchvision.ops.nms(class_boxes, class_scores, iou_thres)
+
+ if len(keep_indices) > 0:
+ kept_boxes = class_boxes[keep_indices]
+ kept_scores = class_scores[keep_indices]
+ kept_classes = torch.full((len(keep_indices), 1), class_id.item(),
+ device=prediction.device, dtype=torch.float)
+
+ # Combine: [x1, y1, x2, y2, conf, class]
+ class_detections = torch.cat([kept_boxes, kept_scores.unsqueeze(1), kept_classes], 1)
+ final_detections.append(class_detections)
+
+ # ...for each category
+
+ if final_detections:
+
+ # Combine all classes and sort by confidence
+ all_detections = torch.cat(final_detections, 0)
+ conf_sort_indices = all_detections[:, 4].argsort(descending=True)
+ all_detections = all_detections[conf_sort_indices]
+
+ # Limit to max_det
+ if all_detections.shape[0] > max_det:
+ all_detections = all_detections[:max_det]
+
+ output.append(all_detections)
+ else:
+ output.append(torch.zeros((0, 6), device=prediction.device))
+
+ # ...for each image in the batch
+
+ return output
+
+ # ...def nms(...)
+
+
  #%% Model metadata functions

  def add_metadata_to_megadetector_model_file(model_file_in,
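
For orientation, here is a minimal sketch of how the new nms() wrapper might be exercised on a dummy prediction tensor. The tensor contents are made up, and the import path is an assumption (the function appears to live in the same module as PTDetector, likely megadetector.detection.pytorch_detector):

    import torch
    # from megadetector.detection.pytorch_detector import nms  # assumed location

    # Dummy predictions: batch of 2 images, 100 candidate boxes each, 3 classes
    # -> 8 columns per box: cx, cy, w, h, objectness, class1, class2, class3
    dummy_pred = torch.rand(2, 100, 8)

    per_image = nms(dummy_pred, conf_thres=0.25, iou_thres=0.45, max_det=300)

    assert len(per_image) == 2
    for det in per_image:
        # Each tensor is [N, 6]: x1, y1, x2, y2, confidence, class_id
        print(det.shape)
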
@@ -593,6 +708,8 @@ def read_metadata_from_megadetector_model_file(model_file,

  return d

+ # ...with zipfile.Zipfile(...)
+
  # ...def read_metadata_from_megadetector_model_file(...)


@@ -606,10 +723,15 @@ require_non_default_compatibility_mode = False

  class PTDetector:
  """
- Class that runs a PyTorch-based MegaDetector model.
+ Class that runs a PyTorch-based MegaDetector model. Also used as a preprocessor
+ for images that will later be run through an instance of PTDetector.
  """

  def __init__(self, model_path, detector_options=None, verbose=False):
+ """
+ PTDetector constructor. If detector_options['preprocess_only'] exists and is
+ True, this instance is being used as a preprocessor, so we don't load model weights.
+ """

  if verbose:
  print('Initializing PTDetector (verbose)')
@@ -637,6 +759,8 @@ class PTDetector:
  else:
  compatibility_mode = detector_options['compatibility_mode']

+ # This is a global option used only during testing, to make sure I'm hitting
+ # the cases where we are not using "classic" preprocessing.
  if require_non_default_compatibility_mode:

  print('### DEBUG: requiring non-default compatibility mode ###')
@@ -652,14 +776,16 @@ class PTDetector:
  if verbose or (not preprocess_only):
  print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))

- model_metadata = read_metadata_from_megadetector_model_file(model_path)
+ self.model_metadata = read_metadata_from_megadetector_model_file(model_path)

- #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
- #: aspect ratio".
- if model_metadata is not None and 'image_size' in model_metadata:
- self.default_image_size = model_metadata['image_size']
+ #: Image size passed to the letterbox() function; 1280 means "1280 on the long side,
+ #: preserving aspect ratio".
+ if self.model_metadata is not None and 'image_size' in self.model_metadata:
+ self.default_image_size = self.model_metadata['image_size']
  print('Loaded image size {} from model metadata'.format(self.default_image_size))
  else:
+ # This is not the default for most YOLO models, but most of the time, if someone
+ # is loading a model here that does not have metadata, it's MDv5[ab].0.0
  print('No image size available in model metadata, defaulting to 1280')
  self.default_image_size = 1280

@@ -682,12 +808,27 @@ class PTDetector:
  #: "classic".
  self.compatibility_mode = compatibility_mode

- #: Stride size passed to YOLOv5's letterbox() function
+ #: Stride size passed to the YOLO letterbox() function
  self.letterbox_stride = 32

- if 'classic' in self.compatibility_mode:
+ # This is a convenient heuristic to determine the stride size without actually loading
+ # the model: the only models in the YOLO family with a stride size of 64 are the
+ # YOLOv5*6 and YOLOv5*6u models, which are 1280px models.
+ #
+ # See:
+ #
+ # github.com/ultralytics/ultralytics/issues/21544
+ #
+ # Note to self, though, if I decide later to require loading the model on preprocessing
+ # workers so I can more reliably choose a stride, this is the right way to determine the
+ # stride:
+ #
+ # self.letterbox_stride = int(self.model.stride.max())
+ if self.default_image_size == 1280:
  self.letterbox_stride = 64

+ print('Using model stride: {}'.format(self.letterbox_stride))
+
  #: Use half-precision inference... fixed by the model, generally don't mess with this
  self.half_precision = False

@@ -699,6 +840,7 @@ class PTDetector:
  self.device = torch.device('cuda:0')
  try:
  if torch.backends.mps.is_built and torch.backends.mps.is_available():
+ print('Using MPS device')
  self.device = 'mps'
  except AttributeError:
  pass
@@ -714,7 +856,7 @@ class PTDetector:
  # to be the same, so doing that externally doesn't seem *that* rude.
  if "Can't get attribute 'DetectionModel'" in str(e):
  print('Forward-compatibility issue detected, patching')
- from models import yolo
+ from models import yolo # type: ignore
  yolo.DetectionModel = yolo.Model
  self.model = PTDetector._load_model(model_path,
  device=self.device,
@@ -737,9 +879,11 @@ class PTDetector:
  # There are two very slightly different ways to load the model, (1) using the
  # map_location=device parameter to torch.load and (2) calling .to(device) after
  # loading the model. The former is what we did for a zillion years, but is not
- # supported on Apple silicon at of 2029.09. Switching to the latter causes
+ # supported on Apple silicon at of 2024.09. Switching to the latter causes
  # very slight changes to the output, which always make me nervous, so I'm not
- # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
+ # doing a wholesale swap just yet. Instead, when running in "classic" compatibility
+ # mode, we'll only use map_location on M1 hardware, where at least at some point
+ # there was not a choice.
  if 'classic' in compatibility_mode:
  use_map_location = (device != 'mps')
  else:
@@ -783,287 +927,428 @@ class PTDetector:
  # ...def _load_model(...)


- def generate_detections_one_image(self,
- img_original,
- image_id='unknown',
- detection_threshold=0.00001,
- image_size=None,
- skip_image_resizing=False,
- augment=False,
- preprocess_only=False,
- verbose=False):
+ def preprocess_image(self,
+ img_original,
+ image_id='unknown',
+ image_size=None,
+ verbose=False):
  """
- Applies the detector to an image.
+ Prepare an image for detection, including scaling and letterboxing.

  Args:
- img_original (Image): the PIL Image object (or numpy array) on which we should run the
- detector, with EXIF rotation already handled
+ img_original (Image or np.array): the image on which we should run the detector, with
+ EXIF rotation already handled
  image_id (str, optional): a path to identify the image; will be in the "file" field
  of the output object
  detection_threshold (float, optional): only detections above this confidence threshold
  will be included in the return value
- image_size (int, optional): image size to use for inference, only mess with this if
- (a) you're using a model other than MegaDetector or (b) you know what you're getting into
- skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
- external resizing), only mess with this if (a) you're using a model other than MegaDetector
- or (b) you know what you're getting into
- augment (bool, optional): enable (implementation-specific) image augmentation
- preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
+ image_size (int, optional): image size (long side) to use for inference, or None to
+ use the default size specified at the time the model was loaded
  verbose (bool, optional): enable additional debug output

  Returns:
- dict: a dictionary with the following fields:
- - 'file' (filename, always present)
- - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
- - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
- - 'failure' (a failure string, or None if everything went fine)
+ dict: dict with fields:
+ - file (filename)
+ - img (the preprocessed np.array)
+ - img_original (the input image before preprocessing, as an np.array)
+ - img_original_pil (the input image before preprocessing, as a PIL Image)
+ - target_shape (the 2D shape to which the image was resized during preprocessing)
+ - scaling_shape (the 2D original size, for normalizing coordinates later)
+ - letterbox_ratio (letterbox parameter used for normalizing coordinates later)
+ - letterbox_pad (letterbox parameter used for normalizing coordinates later)
  """

+ # Prepare return dict
  result = {'file': image_id }
- detections = []
- max_conf = 0.0

- if preprocess_only:
- assert 'classic' in self.compatibility_mode, \
- 'Standalone preprocessing only supported in "classic" mode'
- assert not skip_image_resizing, \
- 'skip_image_resizing and preprocess_only are exclusive'
+ # Store the PIL version of the original image, the caller may want to use
+ # it for metadata extraction later.
+ img_original_pil = None

- if detection_threshold is None:
+ # If we were given a PIL image, rather than a numpy array
+ if not isinstance(img_original,np.ndarray):
+ img_original_pil = img_original
+ img_original = np.asarray(img_original)

- detection_threshold = 0
+ # PIL images are RGB already
+ # img_original = img_original[:, :, ::-1]

- try:
+ # Save the original shape for scaling boxes later
+ scaling_shape = img_original.shape

- # If the caller wants us to skip all the resizing operations...
- if skip_image_resizing:
-
- if isinstance(img_original,dict):
- image_info = img_original
- img = image_info['img_processed']
- scaling_shape = image_info['scaling_shape']
- letterbox_pad = image_info['letterbox_pad']
- letterbox_ratio = image_info['letterbox_ratio']
- img_original = image_info['img_original']
- img_original_pil = image_info['img_original_pil']
- else:
- img = img_original
+ # If the caller is requesting a specific target size...
+ if image_size is not None:

- else:
+ assert isinstance(image_size,int)
+
+ if not self.printed_image_size_warning:
+ print('Using user-supplied image size {}'.format(image_size))
+ self.printed_image_size_warning = True

- img_original_pil = None
- # If we were given a PIL image
+ # Otherwise resize to self.default_image_size
+ else:

- if not isinstance(img_original,np.ndarray):
- img_original_pil = img_original
- img_original = np.asarray(img_original)
+ image_size = self.default_image_size
+ self.printed_image_size_warning = False

- # PIL images are RGB already
- # img_original = img_original[:, :, ::-1]
+ # ...if the caller has specified an image size

- # Save the original shape for scaling boxes later
- scaling_shape = img_original.shape
+ # In "classic mode", we only do the letterboxing resize, we don't do an
+ # additional initial resizing operation
+ if 'classic' in self.compatibility_mode:

- # If the caller is requesting a specific target size...
- if image_size is not None:
+ resize_ratio = 1.0

- assert isinstance(image_size,int)
+ # Resize the image so the long side matches the target image size. This is not
+ # letterboxing (i.e., padding) yet, just resizing.
+ else:

- if not self.printed_image_size_warning:
- print('Using user-supplied image size {}'.format(image_size))
- self.printed_image_size_warning = True
+ use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)

- # Otherwise resize to self.default_image_size
+ h,w = img_original.shape[:2]
+ resize_ratio = image_size / max(h,w)
+
+ # Only resize if we have to
+ if resize_ratio != 1:
+
+ # Match what yolov5 does: use linear interpolation for upsizing;
+ # area interpolation for downsizing
+ if resize_ratio > 1:
+ interpolation_method = cv2.INTER_LINEAR
  else:
+ interpolation_method = cv2.INTER_AREA

- image_size = self.default_image_size
- self.printed_image_size_warning = False
+ if use_ceil_for_resize:
+ target_w = math.ceil(w * resize_ratio)
+ target_h = math.ceil(h * resize_ratio)
+ else:
+ target_w = int(w * resize_ratio)
+ target_h = int(h * resize_ratio)

- # ...if the caller has specified an image size
+ img_original = cv2.resize(
+ img_original, (target_w, target_h),
+ interpolation=interpolation_method)

- # In "classic mode", we only do the letterboxing resize, we don't do an
- # additional initial resizing operation
- if 'classic' in self.compatibility_mode:
+ if 'classic' in self.compatibility_mode:

- resize_ratio = 1.0
+ letterbox_auto = True
+ letterbox_scaleup = True
+ target_shape = image_size

- # Resize the image so the long side matches the target image size. This is not
- # letterboxing (i.e., padding) yet, just resizing.
- else:
+ else:

- use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
+ letterbox_auto = False
+ letterbox_scaleup = False

- h,w = img_original.shape[:2]
- resize_ratio = image_size / max(h,w)
+ # The padding to apply as a fraction of the stride size
+ pad = 0.5

- # Only resize if we have to
- if resize_ratio != 1:
+ # Resize to a multiple of the model stride
+ #
+ # This is how we would determine the stride if we knew the model had been loaded:
+ #
+ # model_stride = int(self.model.stride.max())
+ #
+ # ...but because we do this on preprocessing workers now, we try to avoid loading the model
+ # just for preprocessing, and we assume the stride was determined at the time the PTDetector
+ # object was created.
+ try:
+ model_stride = int(self.model.stride.max())
+ if model_stride != self.letterbox_stride:
+ print('*** Warning: model stride is {}, stride at construction time was {} ***'.format(
+ model_stride,self.letterbox_stride
+ ))
+ except Exception:
+ pass

- # Match what yolov5 does: use linear interpolation for upsizing;
- # area interpolation for downsizing
- if resize_ratio > 1:
- interpolation_method = cv2.INTER_LINEAR
- else:
- interpolation_method = cv2.INTER_AREA
+ model_stride = self.letterbox_stride
+ max_dimension = max(img_original.shape)
+ normalized_shape = [img_original.shape[0] / max_dimension,
+ img_original.shape[1] / max_dimension]
+ target_shape = np.ceil(((np.array(normalized_shape) * image_size) / model_stride) + \
+ pad).astype(int) * model_stride
+
+ # Now we letterbox, which is just padding, since we've already resized
+ img,letterbox_ratio,letterbox_pad = letterbox(img_original,
+ new_shape=target_shape,
+ stride=self.letterbox_stride,
+ auto=letterbox_auto,
+ scaleFill=False,
+ scaleup=letterbox_scaleup)
+
+ result['img_processed'] = img
+ result['img_original'] = img_original
+ result['img_original_pil'] = img_original_pil
+ result['target_shape'] = target_shape
+ result['scaling_shape'] = scaling_shape
+ result['letterbox_ratio'] = letterbox_ratio
+ result['letterbox_pad'] = letterbox_pad
+ return result

- if use_ceil_for_resize:
- target_w = math.ceil(w * resize_ratio)
- target_h = math.ceil(h * resize_ratio)
- else:
- target_w = int(w * resize_ratio)
- target_h = int(h * resize_ratio)
+ # ...def preprocess_image(...)

- img_original = cv2.resize(
- img_original, (target_w, target_h),
- interpolation=interpolation_method)

- if 'classic' in self.compatibility_mode:
+ def generate_detections_one_batch(self,
+ img_original,
+ image_id=None,
+ detection_threshold=0.00001,
+ image_size=None,
+ augment=False,
+ verbose=False):
+ """
+ Run a detector on a batch of images.
+
+ Args:
+ img_original (list): list of images (Image, np.array, or dict) on which we should run the detector, with
+ EXIF rotation already handled, or dicts representing preprocessed images with associated
+ letterbox parameters
+ image_id (list or None): list of paths to identify the images; will be in the "file" field
+ of the output objects. Will be ignored when img_original contains preprocessed dicts.
+ detection_threshold (float, optional): only detections above this confidence threshold
+ will be included in the return value
+ image_size (int, optional): image size (long side) to use for inference, or None to
+ use the default size specified at the time the model was loaded
+ augment (bool, optional): enable (implementation-specific) image augmentation
+ verbose (bool, optional): enable additional debug output
+
+ Returns:
+ list: a list of dictionaries, each with the following fields:
+ - 'file' (filename, always present)
+ - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+ - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+ - 'failure' (a failure string, or None if everything went fine)
+ """

- letterbox_auto = True
- letterbox_scaleup = True
- target_shape = image_size
+ # Validate inputs
+ if not isinstance(img_original, list):
+ raise ValueError('img_original must be a list for batch processing')

+ if verbose:
+ print('generate_detections_one_batch: processing a batch of size {}'.format(len(img_original)))
+
+ if len(img_original) == 0:
+ return []
+
+ # Check input consistency
+ if isinstance(img_original[0], dict):
+ # All items in img_original should be preprocessed dicts
+ for i, img in enumerate(img_original):
+ if not isinstance(img, dict):
+ raise ValueError(f'Mixed input types in batch: item {i} is not a dict, but item 0 is a dict')
+ else:
+ # All items in img_original should be PIL/numpy images, and image_id should be a list of strings
+ if image_id is None:
+ raise ValueError('image_id must be a list when img_original contains PIL/numpy images')
+ if not isinstance(image_id, list):
+ raise ValueError('image_id must be a list for batch processing')
+ if len(image_id) != len(img_original):
+ raise ValueError(
+ 'Length mismatch: img_original has {} items, image_id has {} items'.format(
+ len(img_original),len(image_id)))
+ for i_img, img in enumerate(img_original):
+ if isinstance(img, dict):
+ raise ValueError(
+ 'Mixed input types in batch: item {} is a dict, but item 0 is not a dict'.format(
+ i_img))
+
+ if detection_threshold is None:
+ detection_threshold = 0.0
+
+ batch_size = len(img_original)
+ results = [None] * batch_size
+
+ # Preprocess all images, handling failures
+ preprocessed_images = []
+ preprocessing_failed_indices = set()
+
+ for i_img, img in enumerate(img_original):
+
+ try:
+ if isinstance(img, dict):
+ # Already preprocessed
+ image_info = img
+ current_image_id = image_info['file']
  else:
+ # Need to preprocess
+ current_image_id = image_id[i_img]
+ image_info = self.preprocess_image(
+ img_original=img,
+ image_id=current_image_id,
+ image_size=image_size,
+ verbose=verbose)
+
+ preprocessed_images.append((i_img, image_info, current_image_id))
+
+ except Exception as e:
+ if verbose:
+ print('Preprocessing failed for image {}: {}'.format(
+ image_id[i_img] if image_id else f'index_{i_img}', str(e)))
+
+ preprocessing_failed_indices.add(i_img)
+ current_image_id = image_id[i_img] if image_id else f'index_{i_img}'
+ results[i_img] = {
+ 'file': current_image_id,
+ 'detections': None,
+ 'failure': FAILURE_IMAGE_OPEN
+ }
+
+ # ...for each image in this batch
+
+ # Group preprocessed images by actual processed image shape for batching
+ shape_groups = {}
+ for original_idx, image_info, current_image_id in preprocessed_images:
+ # Use the actual processed image shape for grouping, not target_shape
+ actual_shape = tuple(image_info['img_processed'].shape)
+ if actual_shape not in shape_groups:
+ shape_groups[actual_shape] = []
+ shape_groups[actual_shape].append((original_idx, image_info, current_image_id))
+
+ if verbose and len(shape_groups) > 1:
+ print('generate_detections_one_batch: batch of size {} split into {} shape-group batches'.\
+ format(len(preprocessed_images), len(shape_groups)))
+
+ # Process each shape group as a batch
+ for target_shape, group_items in shape_groups.items():
+ try:
+ self._process_batch_group(group_items, results, detection_threshold, augment, verbose)
+ except Exception as e:
+ # If inference fails for the entire batch, mark all images in this batch as failed
+ if verbose:
+ print('Batch inference failed for shape {}: {}'.format(target_shape, str(e)))
+
+ for original_idx, image_info, current_image_id in group_items:
+ results[original_idx] = {
+ 'file': current_image_id,
+ 'detections': None,
+ 'failure': FAILURE_INFER
+ }

- letterbox_auto = False
- letterbox_scaleup = False
+ return results

- # The padding to apply as a fraction of the stride size
- pad = 0.5
+ def _process_batch_group(self, group_items, results, detection_threshold, augment, verbose):
+ """
+ Process a group of images with the same target shape as a single batch.

- model_stride = int(self.model.stride.max())
+ Args:
+ group_items (list): List of (original_idx, image_info, current_image_id) tuples
+ results (list): Results list to populate (modified in place)
+ detection_threshold (float): Detection confidence threshold
+ augment (bool): Enable augmentation
+ verbose (bool): Enable verbose output

- max_dimension = max(img_original.shape)
- normalized_shape = [img_original.shape[0] / max_dimension,
- img_original.shape[1] / max_dimension]
- target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
- pad).astype(int) * model_stride
+ Returns:
+ list of dict: list of dictionaries the same length as group_items, with fields 'file',
+ 'detections', 'max_detection_conf'.
+ """

- # Now we letterbox, which is just padding, since we've already resized.
- img,letterbox_ratio,letterbox_pad = letterbox(img_original,
- new_shape=target_shape,
- stride=self.letterbox_stride,
- auto=letterbox_auto,
- scaleFill=False,
- scaleup=letterbox_scaleup)
+ if len(group_items) == 0:
+ return

- if preprocess_only:
+ # Extract batch data
+ batch_images = []
+ batch_metadata = []

- assert 'file' in result
- result['img_processed'] = img
- result['img_original'] = img_original
- result['img_original_pil'] = img_original_pil
- result['target_shape'] = target_shape
- result['scaling_shape'] = scaling_shape
- result['letterbox_ratio'] = letterbox_ratio
- result['letterbox_pad'] = letterbox_pad
- return result
+ # For each image in this batch...
+ for original_idx, image_info, current_image_id in group_items:

- # ...are we doing resizing here, or were images already resized?
+ img = image_info['img_processed']

- # Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
- # so we don't need to mess with the color channels.
- #
- # TODO, this could be moved into the preprocessing loop
-
- img = img.transpose((2, 0, 1)) # [::-1]
- img = np.ascontiguousarray(img)
- img = torch.from_numpy(img)
- img = img.to(self.device)
- img = img.half() if self.half_precision else img.float()
- img /= 255
-
- # In practice this is always true
- if len(img.shape) == 3:
- img = torch.unsqueeze(img, 0)
-
- # Run the model
- pred = self.model(img,augment=augment)[0]
-
- if 'classic' in self.compatibility_mode:
- nms_conf_thres = detection_threshold
- nms_iou_thres = 0.45
- nms_agnostic = False
- nms_multi_label = False
- else:
- nms_conf_thres = detection_threshold # 0.01
- nms_iou_thres = 0.6
- nms_agnostic = False
- nms_multi_label = True
+ # Convert HWC to CHW and prepare tensor
+ img_tensor = img.transpose((2, 0, 1))
+ img_tensor = np.ascontiguousarray(img_tensor)
+ img_tensor = torch.from_numpy(img_tensor)
+ batch_images.append(img_tensor)

- # As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
- #
- # Send predictions back to the CPU for NMS.
- if self.device == 'mps':
- pred_nms = pred.cpu()
- else:
- pred_nms = pred
+ metadata = {
+ 'original_idx': original_idx,
+ 'current_image_id': current_image_id,
+ 'scaling_shape': image_info['scaling_shape'],
+ 'letterbox_pad': image_info['letterbox_pad'],
+ 'img_original': image_info['img_original']
+ }
+ batch_metadata.append(metadata)

- # NMS
- pred = non_max_suppression(prediction=pred_nms,
- conf_thres=nms_conf_thres,
- iou_thres=nms_iou_thres,
- agnostic=nms_agnostic,
- multi_label=nms_multi_label)
+ # ...for each image in this batch

- # In practice this is [w,h,w,h] of the original image
- gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
+ # Stack images into a batch tensor
+ batch_tensor = torch.stack(batch_images)

- if 'classic' in self.compatibility_mode:
+ if verbose:
+ if batch_tensor.shape[0] > 1:
+ print('_process_batch_group: processing a batch of size {}'.format(batch_tensor.shape[0]))

- ratio = None
- ratio_pad = None
+ batch_tensor = batch_tensor.float()
+ batch_tensor /= 255.0

- else:
+ batch_tensor = batch_tensor.to(self.device)
+ if self.half_precision:
+ batch_tensor = batch_tensor.half()

- # letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
- #
- # ratio is a 2-tuple specifying the scaling that was applied to each dimension.
- #
- # The scale_boxes function expects a 2-tuple with these things combined.
- ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
- ratio_pad = (ratio, letterbox_pad)
+ # Run the model on the batch
+ pred = self.model(batch_tensor, augment=augment)[0]

- # This is a loop over detection batches, which will always be length 1 in our case,
- # since we're not doing batch inference.
- #
- # det = pred[0]
- #
- # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
- # in descending order by confidence.
- #
- # Columns are:
- #
- # x0,y0,x1,y1,confidence,class
- #
- # At this point, these are *non*-normalized values, referring to the size at which we
- # ran inference (img.shape).
- for det in pred:
+ # Configure NMS parameters
+ if 'classic' in self.compatibility_mode:
+ nms_iou_thres = 0.45
+ else:
+ nms_iou_thres = 0.6

- if len(det) == 0:
- continue
+ pred = nms(prediction=pred,
+ conf_thres=detection_threshold,
+ iou_thres=nms_iou_thres)

- # Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
- if 'classic' in self.compatibility_mode:
+ # For posterity, the ultralytics implementation
+ if False:
+ pred = non_max_suppression(prediction=pred,
+ conf_thres=detection_threshold,
+ iou_thres=nms_iou_thres,
+ agnostic=False,
+ multi_label=False)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(batch_metadata), \
+ print('Mismatch between prediction length {} and batch size {}'.format(
+ len(pred),len(batch_metadata)))

- det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
+ # Process each image's detections
+ for i_image, det in enumerate(pred):

+ metadata = batch_metadata[i_image]
+ original_idx = metadata['original_idx']
+ current_image_id = metadata['current_image_id']
+ scaling_shape = metadata['scaling_shape']
+ letterbox_pad = metadata['letterbox_pad']
+ img_original = metadata['img_original']
+
+ detections = []
+ max_conf = 0.0
+
+ if len(det) > 0:
+
+ # Prepare scaling parameters
+ gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
+
+ if 'classic' in self.compatibility_mode:
+ ratio = None
+ ratio_pad = None
  else:
- # After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
- # original pixel dimension of the image, followed by the class and confidence
- det[:, :4] = scale_coords(img.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()
+ ratio = (img_original.shape[0]/scaling_shape[0],
+ img_original.shape[1]/scaling_shape[1])
+ ratio_pad = (ratio, letterbox_pad)

- # Loop over detections
- for *xyxy, conf, cls in reversed(det):
+ # Rescale boxes
+ if 'classic' in self.compatibility_mode:
+ det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], img_original.shape).round()
+ else:
+ det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()

+ # Process each detection
+ for *xyxy, conf, cls in reversed(det):
  if conf < detection_threshold:
  continue

- # Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
+ # Convert to YOLO format then to MD format
  xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
-
- # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
- # left/top/w/h (i.e., MD format)
  api_box = ct_utils.convert_yolo_to_xywh(xywh)

  if 'classic' in self.compatibility_mode:
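
The non-classic preprocessing path above pads the resized image up to a multiple of the model stride, via target_shape = ceil(normalized_shape * image_size / stride + pad) * stride. A small worked sketch of that arithmetic, under the assumption that only the two spatial dimensions matter (the real code takes max() over the full array shape, including the channel dimension):

    import numpy as np

    def padded_target_shape(h, w, image_size=1280, stride=64, pad=0.5):
        # Normalize so the long side is 1.0, scale to the inference size,
        # then round each dimension up to a multiple of the stride
        max_dim = max(h, w)
        normalized = np.array([h / max_dim, w / max_dim])
        return np.ceil(((normalized * image_size) / stride) + pad).astype(int) * stride

    # A 1536x2048 image at 1280px: the stride choice changes the padded shape
    print(padded_target_shape(1536, 2048, stride=64))  # [1024 1344]
    print(padded_target_shape(1536, 2048, stride=32))  # [ 992 1312]
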
@@ -1074,8 +1359,6 @@ class PTDetector:
  conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)

  if not self.use_model_native_classes:
- # The MegaDetector output format's categories start at 1, but all YOLO-based
- # MD models have category numbers starting at 0.
  cls = int(cls.tolist()) + 1
  if cls not in (1, 2, 3):
  raise KeyError(f'{cls} is not a valid class.')
@@ -1089,47 +1372,73 @@ class PTDetector:
  })
  max_conf = max(max_conf, conf)

- # ...for each detection in this batch
-
- # ...for each detection batch (always one iteration)
+ # ...for each detection

- # ...try
+ # ...if there are > 0 detections

- except Exception as e:
+ # Store result for this image
+ results[original_idx] = {
+ 'file': current_image_id,
+ 'detections': detections,
+ 'max_detection_conf': max_conf
+ }

- result['failure'] = FAILURE_INFER
- print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
- # traceback.print_exc(e)
- print(traceback.format_exc())
+ # ...for each image

- result['max_detection_conf'] = max_conf
- result['detections'] = detections
+ # ...def _process_batch_group(...)

- return result
-
- # ...def generate_detections_one_image(...)
-
- # ...class PTDetector
-
-
- #%% Command-line driver
-
- # For testing only... you don't really want to run this module directly.
-
- if __name__ == '__main__':
+ def generate_detections_one_image(self,
+ img_original,
+ image_id='unknown',
+ detection_threshold=0.00001,
+ image_size=None,
+ augment=False,
+ verbose=False):
+ """
+ Run a detector on an image (wrapper around batch function).

- pass
+ Args:
+ img_original (Image, np.array, or dict): the image on which we should run the detector, with
+ EXIF rotation already handled, or a dict representing a preprocessed image with associated
+ letterbox parameters
+ image_id (str, optional): a path to identify the image; will be in the "file" field
+ of the output object
+ detection_threshold (float, optional): only detections above this confidence threshold
+ will be included in the return value
+ image_size (int, optional): image size (long side) to use for inference, or None to
+ use the default size specified at the time the model was loaded
+ augment (bool, optional): enable (implementation-specific) image augmentation
+ verbose (bool, optional): enable additional debug output

- #%%
+ Returns:
+ dict: a dictionary with the following fields:
+ - 'file' (filename, always present)
+ - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+ - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+ - 'failure' (a failure string, or None if everything went fine)
+ """

- import os #noqa
- from megadetector.visualization import visualization_utils as vis_utils
+ # Prepare batch inputs
+ if isinstance(img_original, dict):
+ batch_results = self.generate_detections_one_batch(
+ img_original=[img_original],
+ image_id=None,
+ detection_threshold=detection_threshold,
+ image_size=image_size,
+ augment=augment,
+ verbose=verbose)
+ else:
+ batch_results = self.generate_detections_one_batch(
+ img_original=[img_original],
+ image_id=[image_id],
+ detection_threshold=detection_threshold,
+ image_size=image_size,
+ augment=augment,
+ verbose=verbose)

- model_file = os.environ['MDV5A']
- im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')
+ # Return the single result
+ return batch_results[0]

- detector = PTDetector(model_file)
- image = vis_utils.load_image(im_file)
+ # ...def generate_detections_one_image(...)

- res = detector.generate_detections_one_image(image, im_file, detection_threshold=0.00001)
- print(res)
+ # ...class PTDetector
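
Taken together, these changes split inference into an explicit preprocess/batch pipeline, with generate_detections_one_image() kept as a thin wrapper over the batch path. A hedged usage sketch, based only on the signatures visible in this diff (the model path and file names are placeholders, and EXIF rotation is assumed to be handled before images reach the detector):

    from PIL import Image

    detector = PTDetector('md_v5a.0.0.pt')  # placeholder model path

    files = ['image1.jpg', 'image2.jpg']

    # Stage 1: preprocessing (scaling + letterboxing), one dict per image
    preprocessed = [detector.preprocess_image(Image.open(fn), image_id=fn) for fn in files]

    # Stage 2: batched inference; image_id is ignored because each preprocessed
    # dict already carries its 'file' field
    results = detector.generate_detections_one_batch(preprocessed, image_id=None,
                                                     detection_threshold=0.2)

    # Single-image path, now routed through the same batch machinery
    single = detector.generate_detections_one_image(Image.open(files[0]),
                                                    image_id=files[0],
                                                    detection_threshold=0.2)
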