megadetector 10.0.2__py3-none-any.whl → 10.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of megadetector might be problematic.
- megadetector/data_management/animl_to_md.py +158 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/process_video.py +165 -946
- megadetector/detection/pytorch_detector.py +575 -276
- megadetector/detection/run_detector_batch.py +629 -202
- megadetector/detection/run_md_and_speciesnet.py +1319 -0
- megadetector/detection/video_utils.py +243 -107
- megadetector/postprocessing/classification_postprocessing.py +12 -1
- megadetector/postprocessing/combine_batch_outputs.py +2 -0
- megadetector/postprocessing/compare_batch_results.py +21 -2
- megadetector/postprocessing/merge_detections.py +16 -12
- megadetector/postprocessing/separate_detections_into_folders.py +1 -1
- megadetector/postprocessing/subset_json_detector_output.py +1 -3
- megadetector/postprocessing/validate_batch_results.py +25 -2
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/ct_utils.py +69 -5
- megadetector/utils/extract_frames_from_video.py +303 -0
- megadetector/utils/md_tests.py +583 -524
- megadetector/utils/path_utils.py +4 -15
- megadetector/utils/wi_utils.py +20 -4
- megadetector/visualization/visualization_utils.py +1 -1
- megadetector/visualization/visualize_db.py +8 -22
- megadetector/visualization/visualize_detector_output.py +7 -5
- megadetector/visualization/visualize_video_output.py +607 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/METADATA +134 -135
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/RECORD +30 -23
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/licenses/LICENSE +0 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/top_level.txt +0 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/WHEEL +0 -0
megadetector/detection/pytorch_detector.py

@@ -14,7 +14,6 @@ import math
 import zipfile
 import tempfile
 import shutil
-import traceback
 import uuid
 import json
 import inspect
@@ -24,11 +23,13 @@ import torch
 import numpy as np

 from megadetector.detection.run_detector import \
-    CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, \
+    CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, FAILURE_IMAGE_OPEN, \
     get_detector_version_from_model_file, \
     known_models
 from megadetector.utils.ct_utils import parse_bool_string
+from megadetector.utils.ct_utils import is_running_in_gha
 from megadetector.utils import ct_utils
+import torchvision

 # We support a few ways of accessing the YOLOv5 dependencies:
 #
@@ -175,7 +176,7 @@ def _initialize_yolo_imports_for_model(model_file,
     return model_type


-def _clean_yolo_imports(verbose=False,aggressive_cleanup=False):
+def _clean_yolo_imports(verbose=False, aggressive_cleanup=False):
     """
     Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
     of another YOLO library version. The reason we jump through all these hoops, rather than
@@ -454,14 +455,14 @@ def _initialize_yolo_imports(model_type='yolov5',
     try:

         # import pre- and post-processing functions from the YOLOv5 repo
-        from utils.general import non_max_suppression, xyxy2xywh #
-        from utils.augmentations import letterbox #
+        from utils.general import non_max_suppression, xyxy2xywh # type: ignore
+        from utils.augmentations import letterbox # type: ignore

         # scale_coords() is scale_boxes() in some YOLOv5 versions
         try:
-            from utils.general import scale_coords #
+            from utils.general import scale_coords # type: ignore
         except ImportError:
-            from utils.general import scale_boxes as scale_coords
+            from utils.general import scale_boxes as scale_coords # type: ignore
         utils_imported = True
         imported_file = sys.modules[scale_coords.__module__].__file__
         if verbose:
@@ -482,6 +483,121 @@ def _initialize_yolo_imports(model_type='yolov5',
 # ...def _initialize_yolo_imports(...)


+#%% NMS
+
+def nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+    """
+    Non-maximum suppression (a wrapper around torchvision.ops.nms())
+
+    Args:
+        prediction (torch.Tensor): Model predictions with shape [batch_size, num_anchors, num_classes + 5]
+            Format: [x_center, y_center, width, height, objectness, class1_conf, class2_conf, ...]
+            Coordinates are normalized to input image size.
+        conf_thres (float): Confidence threshold for filtering detections
+        iou_thres (float): IoU threshold for NMS
+        max_det (int): Maximum number of detections per image
+
+    Returns:
+        list: List of tensors, one per image in batch. Each tensor has shape [N, 6] where:
+            - N is the number of detections for that image
+            - Columns are [x1, y1, x2, y2, confidence, class_id]
+            - Coordinates are in absolute pixels relative to input image size
+            - class_id is the integer class index (0-based)
+    """
+
+    batch_size = prediction.shape[0]
+    num_classes = prediction.shape[2] - 5 # noqa
+    output = []
+
+    # Process each image in the batch
+    for img_idx in range(batch_size):
+
+        x = prediction[img_idx]  # Shape: [num_anchors, num_classes + 5]
+
+        # Filter by objectness confidence
+        obj_conf = x[:, 4]
+        valid_detections = obj_conf > conf_thres
+        x = x[valid_detections]
+
+        if x.shape[0] == 0:
+            # No detections for this image
+            output.append(torch.zeros((0, 6), device=prediction.device))
+            continue
+
+        # Convert box coordinates from [x_center, y_center, w, h] to [x1, y1, x2, y2]
+        box = x[:, :4].clone()
+        box[:, 0] = x[:, 0] - x[:, 2] / 2.0  # x1 = center_x - width/2
+        box[:, 1] = x[:, 1] - x[:, 3] / 2.0  # y1 = center_y - height/2
+        box[:, 2] = x[:, 0] + x[:, 2] / 2.0  # x2 = center_x + width/2
+        box[:, 3] = x[:, 1] + x[:, 3] / 2.0  # y2 = center_y + height/2
+
+        # Get class predictions: multiply objectness by class probabilities
+        class_conf = x[:, 5:] * x[:, 4:5]  # shape: [N, num_classes]
+
+        # For each detection, take the class with highest confidence (single-label)
+        best_class_conf, best_class_idx = class_conf.max(1, keepdim=True)
+
+        # Filter by class confidence threshold
+        conf_mask = best_class_conf.view(-1) > conf_thres
+        if conf_mask.sum() == 0:
+            # No detections pass confidence threshold
+            output.append(torch.zeros((0, 6), device=prediction.device))
+            continue
+
+        box = box[conf_mask]
+        best_class_conf = best_class_conf[conf_mask]
+        best_class_idx = best_class_idx[conf_mask]
+
+        # Prepare for NMS: group detections by class
+        unique_classes = best_class_idx.unique()
+        final_detections = []
+
+        for class_id in unique_classes:
+
+            class_mask = (best_class_idx == class_id).view(-1)
+            class_boxes = box[class_mask]
+            class_scores = best_class_conf[class_mask].view(-1)
+
+            if class_boxes.shape[0] == 0:
+                continue
+
+            # Apply NMS for this class
+            keep_indices = torchvision.ops.nms(class_boxes, class_scores, iou_thres)
+
+            if len(keep_indices) > 0:
+                kept_boxes = class_boxes[keep_indices]
+                kept_scores = class_scores[keep_indices]
+                kept_classes = torch.full((len(keep_indices), 1), class_id.item(),
+                                          device=prediction.device, dtype=torch.float)
+
+                # Combine: [x1, y1, x2, y2, conf, class]
+                class_detections = torch.cat([kept_boxes, kept_scores.unsqueeze(1), kept_classes], 1)
+                final_detections.append(class_detections)
+
+        # ...for each category
+
+        if final_detections:
+
+            # Combine all classes and sort by confidence
+            all_detections = torch.cat(final_detections, 0)
+            conf_sort_indices = all_detections[:, 4].argsort(descending=True)
+            all_detections = all_detections[conf_sort_indices]
+
+            # Limit to max_det
+            if all_detections.shape[0] > max_det:
+                all_detections = all_detections[:max_det]
+
+            output.append(all_detections)
+        else:
+            output.append(torch.zeros((0, 6), device=prediction.device))
+
+    # ...for each image in the batch
+
+    return output
+
+# ...def nms(...)
+
+
 #%% Model metadata functions

 def add_metadata_to_megadetector_model_file(model_file_in,
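As an aside, a minimal sketch of how the new nms() wrapper above might be exercised, assuming it is importable from megadetector.detection.pytorch_detector in 10.0.4; the prediction tensor here is synthetic and purely illustrative:

    import torch

    # Synthetic YOLO-style raw output: 2 images, 100 candidate boxes, 3 classes (5 + 3 columns)
    dummy_prediction = torch.rand(2, 100, 8)

    per_image = nms(dummy_prediction, conf_thres=0.25, iou_thres=0.45, max_det=300)
    assert len(per_image) == 2
    for det in per_image:
        # Each tensor is [N, 6]: x1, y1, x2, y2, confidence, class_id
        print(det.shape)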
@@ -593,6 +709,8 @@ def read_metadata_from_megadetector_model_file(model_file,

         return d

+    # ...with zipfile.Zipfile(...)
+
 # ...def read_metadata_from_megadetector_model_file(...)


@@ -606,10 +724,15 @@ require_non_default_compatibility_mode = False

 class PTDetector:
     """
-    Class that runs a PyTorch-based MegaDetector model.
+    Class that runs a PyTorch-based MegaDetector model. Also used as a preprocessor
+    for images that will later be run through an instance of PTDetector.
     """

     def __init__(self, model_path, detector_options=None, verbose=False):
+        """
+        PTDetector constructor. If detector_options['preprocess_only'] exists and is
+        True, this instance is being used as a preprocessor, so we don't load model weights.
+        """

         if verbose:
             print('Initializing PTDetector (verbose)')
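The new constructor docstring implies a preprocessing-only usage pattern roughly like the following sketch; the option keys come from the diff above, while the model filename is illustrative:

    from megadetector.detection.pytorch_detector import PTDetector

    # Preprocessing-only instance: per the docstring, model weights are not loaded
    preprocessor = PTDetector('md_v5a.0.0.pt', detector_options={'preprocess_only': True})

    # Normal inference instance
    detector = PTDetector('md_v5a.0.0.pt', detector_options={'compatibility_mode': 'classic'})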
@@ -637,6 +760,8 @@ class PTDetector:
         else:
             compatibility_mode = detector_options['compatibility_mode']

+        # This is a global option used only during testing, to make sure I'm hitting
+        # the cases where we are not using "classic" preprocessing.
         if require_non_default_compatibility_mode:

             print('### DEBUG: requiring non-default compatibility mode ###')
@@ -652,14 +777,16 @@ class PTDetector:
         if verbose or (not preprocess_only):
             print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))

-        model_metadata = read_metadata_from_megadetector_model_file(model_path)
+        self.model_metadata = read_metadata_from_megadetector_model_file(model_path)

-        #: Image size passed to the letterbox() function; 1280 means "1280 on the long side,
-        #: aspect ratio".
-        if model_metadata is not None and 'image_size' in model_metadata:
-            self.default_image_size = model_metadata['image_size']
+        #: Image size passed to the letterbox() function; 1280 means "1280 on the long side,
+        #: preserving aspect ratio".
+        if self.model_metadata is not None and 'image_size' in self.model_metadata:
+            self.default_image_size = self.model_metadata['image_size']
             print('Loaded image size {} from model metadata'.format(self.default_image_size))
         else:
+            # This is not the default for most YOLO models, but most of the time, if someone
+            # is loading a model here that does not have metadata, it's MDv5[ab].0.0
             print('No image size available in model metadata, defaulting to 1280')
             self.default_image_size = 1280

@@ -682,12 +809,27 @@ class PTDetector:
         #: "classic".
         self.compatibility_mode = compatibility_mode

-        #: Stride size passed to
+        #: Stride size passed to the YOLO letterbox() function
         self.letterbox_stride = 32

+        # This is a convenient heuristic to determine the stride size without actually loading
+        # the model: the only models in the YOLO family with a stride size of 64 are the
+        # YOLOv5*6 and YOLOv5*6u models, which are 1280px models.
+        #
+        # See:
+        #
+        # github.com/ultralytics/ultralytics/issues/21544
+        #
+        # Note to self, though, if I decide later to require loading the model on preprocessing
+        # workers so I can more reliably choose a stride, this is the right way to determine the
+        # stride:
+        #
+        # self.letterbox_stride = int(self.model.stride.max())
+        if self.default_image_size == 1280:
             self.letterbox_stride = 64

+        print('Using model stride: {}'.format(self.letterbox_stride))
+
         #: Use half-precision inference... fixed by the model, generally don't mess with this
         self.half_precision = False

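Restating the stride heuristic from the comment block above as a standalone sketch (the helper name is hypothetical; the 1280 → 64 mapping is taken from the diff):

    def guess_letterbox_stride(default_image_size):
        # 1280px YOLOv5*6-family models use stride 64; everything else keeps the default of 32
        return 64 if default_image_size == 1280 else 32

    assert guess_letterbox_stride(1280) == 64
    assert guess_letterbox_stride(640) == 32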
@@ -699,9 +841,16 @@ class PTDetector:
             self.device = torch.device('cuda:0')
         try:
             if torch.backends.mps.is_built and torch.backends.mps.is_available():
+                # MPS inference fails on GitHub runners as of 2025.08. This is
+                # independent of model size. So, we disable MPS when running in GHA.
+                if is_running_in_gha():
+                    print('GitHub actions detected, bypassing MPS backend')
+                else:
+                    print('Using MPS device')
+                    self.device = 'mps'
         except AttributeError:
             pass
+
         try:
             self.model = PTDetector._load_model(model_path,
                                                 device=self.device,
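The device-selection order implied by this hunk (CUDA if available, then MPS unless running under GitHub Actions, otherwise CPU) can be sketched in isolation as follows; the helper name is hypothetical, and the CPU/CUDA defaults are assumed from the surrounding, unchanged code:

    import torch

    def select_torch_device(running_in_gha=False):
        if torch.cuda.is_available():
            return torch.device('cuda:0')
        try:
            # MPS is skipped on GitHub Actions runners, mirroring the hunk above
            if torch.backends.mps.is_built and torch.backends.mps.is_available() and not running_in_gha:
                return 'mps'
        except AttributeError:
            pass
        return 'cpu'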
@@ -714,7 +863,7 @@ class PTDetector:
             # to be the same, so doing that externally doesn't seem *that* rude.
             if "Can't get attribute 'DetectionModel'" in str(e):
                 print('Forward-compatibility issue detected, patching')
-                from models import yolo
+                from models import yolo # type: ignore
                 yolo.DetectionModel = yolo.Model
                 self.model = PTDetector._load_model(model_path,
                                                     device=self.device,
@@ -734,16 +883,10 @@ class PTDetector:
         if verbose:
             print(f'Using PyTorch version {torch.__version__}')

-        #
-        # map_location
-        #
-        # very slight changes to the output, which always make me nervous, so I'm not
-        # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
-        if 'classic' in compatibility_mode:
-            use_map_location = (device != 'mps')
-        else:
-            use_map_location = False
+        # I get quirky errors when loading YOLOv5 models on MPS hardware using
+        # map_location, but this is the recommended method, so I'm using it everywhere
+        # other than MPS devices.
+        use_map_location = (device != 'mps')

         if use_map_location:
             try:
@@ -773,297 +916,429 @@ class PTDetector:
             if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
                 m.recompute_scale_factor = None

-        model = checkpoint['model'].float().fuse().eval().to(device)
+        # Calling .to(device) should no longer be necessary now that we're using map_location=device
+        # model = checkpoint['model'].float().fuse().eval().to(device)
+        model = checkpoint['model'].float().fuse().eval()

         return model

     # ...def _load_model(...)


-    def
-                                       skip_image_resizing=False,
-                                       augment=False,
-                                       preprocess_only=False,
-                                       verbose=False):
+    def preprocess_image(self,
+                         img_original,
+                         image_id='unknown',
+                         image_size=None,
+                         verbose=False):
         """
+        Prepare an image for detection, including scaling and letterboxing.

         Args:
-            img_original (Image): the
+            img_original (Image or np.array): the image on which we should run the detector, with
+                EXIF rotation already handled
             image_id (str, optional): a path to identify the image; will be in the "file" field
                 of the output object
             detection_threshold (float, optional): only detections above this confidence threshold
                 will be included in the return value
-            image_size (int, optional): image size to use for inference,
-            skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
-                external resizing), only mess with this if (a) you're using a model other than MegaDetector
-                or (b) you know what you're getting into
-            augment (bool, optional): enable (implementation-specific) image augmentation
-            preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
+            image_size (int, optional): image size (long side) to use for inference, or None to
+                use the default size specified at the time the model was loaded
             verbose (bool, optional): enable additional debug output

         Returns:
-            dict:
+            dict: dict with fields:
+                - file (filename)
+                - img (the preprocessed np.array)
+                - img_original (the input image before preprocessing, as an np.array)
+                - img_original_pil (the input image before preprocessing, as a PIL Image)
+                - target_shape (the 2D shape to which the image was resized during preprocessing)
+                - scaling_shape (the 2D original size, for normalizing coordinates later)
+                - letterbox_ratio (letterbox parameter used for normalizing coordinates later)
+                - letterbox_pad (letterbox parameter used for normalizing coordinates later)
         """

+        # Prepare return dict
         result = {'file': image_id }
-        detections = []
-        max_conf = 0.0

-        assert not skip_image_resizing, \
-            'skip_image_resizing and preprocess_only are exclusive'
+        # Store the PIL version of the original image, the caller may want to use
+        # it for metadata extraction later.
+        img_original_pil = None

+        # If we were given a PIL image, rather than a numpy array
+        if not isinstance(img_original,np.ndarray):
+            img_original_pil = img_original
+            img_original = np.asarray(img_original)

+            # PIL images are RGB already
+            # img_original = img_original[:, :, ::-1]

+        # Save the original shape for scaling boxes later
+        scaling_shape = img_original.shape

-        if isinstance(img_original,dict):
-            image_info = img_original
-            img = image_info['img_processed']
-            scaling_shape = image_info['scaling_shape']
-            letterbox_pad = image_info['letterbox_pad']
-            letterbox_ratio = image_info['letterbox_ratio']
-            img_original = image_info['img_original']
-            img_original_pil = image_info['img_original_pil']
-        else:
-            img = img_original
+        # If the caller is requesting a specific target size...
+        if image_size is not None:

+            assert isinstance(image_size,int)

+            if not self.printed_image_size_warning:
+                print('Using user-supplied image size {}'.format(image_size))
+                self.printed_image_size_warning = True

+        # Otherwise resize to self.default_image_size
+        else:
+
+            image_size = self.default_image_size
+            self.printed_image_size_warning = False
+
+        # ...if the caller has specified an image size

+        # In "classic mode", we only do the letterboxing resize, we don't do an
+        # additional initial resizing operation
+        if 'classic' in self.compatibility_mode:
+
+            resize_ratio = 1.0

+        # Resize the image so the long side matches the target image size. This is not
+        # letterboxing (i.e., padding) yet, just resizing.
+        else:

-            if image_size is not None:
+            use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)

+            h,w = img_original.shape[:2]
+            resize_ratio = image_size / max(h,w)

-                self.printed_image_size_warning = True
+            # Only resize if we have to
+            if resize_ratio != 1:

-            #
+                # Match what yolov5 does: use linear interpolation for upsizing;
+                # area interpolation for downsizing
+                if resize_ratio > 1:
+                    interpolation_method = cv2.INTER_LINEAR
                 else:
+                    interpolation_method = cv2.INTER_AREA

+                if use_ceil_for_resize:
+                    target_w = math.ceil(w * resize_ratio)
+                    target_h = math.ceil(h * resize_ratio)
+                else:
+                    target_w = int(w * resize_ratio)
+                    target_h = int(h * resize_ratio)

+                img_original = cv2.resize(
+                    img_original, (target_w, target_h),
+                    interpolation=interpolation_method)

-            # additional initial resizing operation
-            if 'classic' in self.compatibility_mode:
+        if 'classic' in self.compatibility_mode:

+            letterbox_auto = True
+            letterbox_scaleup = True
+            target_shape = image_size

-            # letterboxing (i.e., padding) yet, just resizing.
-            else:
+        else:

+            letterbox_auto = False
+            letterbox_scaleup = False

+            # The padding to apply as a fraction of the stride size
+            pad = 0.5

+            # Resize to a multiple of the model stride
+            #
+            # This is how we would determine the stride if we knew the model had been loaded:
+            #
+            # model_stride = int(self.model.stride.max())
+            #
+            # ...but because we do this on preprocessing workers now, we try to avoid loading the model
+            # just for preprocessing, and we assume the stride was determined at the time the PTDetector
+            # object was created.
+            try:
+                model_stride = int(self.model.stride.max())
+                if model_stride != self.letterbox_stride:
+                    print('*** Warning: model stride is {}, stride at construction time was {} ***'.format(
+                        model_stride,self.letterbox_stride
+                    ))
+            except Exception:
+                pass

+            model_stride = self.letterbox_stride
+            max_dimension = max(img_original.shape)
+            normalized_shape = [img_original.shape[0] / max_dimension,
+                                img_original.shape[1] / max_dimension]
+            target_shape = np.ceil(((np.array(normalized_shape) * image_size) / model_stride) + \
+                pad).astype(int) * model_stride
+
+        # Now we letterbox, which is just padding, since we've already resized
+        img,letterbox_ratio,letterbox_pad = letterbox(img_original,
+                                                      new_shape=target_shape,
+                                                      stride=self.letterbox_stride,
+                                                      auto=letterbox_auto,
+                                                      scaleFill=False,
+                                                      scaleup=letterbox_scaleup)
+
+        result['img_processed'] = img
+        result['img_original'] = img_original
+        result['img_original_pil'] = img_original_pil
+        result['target_shape'] = target_shape
+        result['scaling_shape'] = scaling_shape
+        result['letterbox_ratio'] = letterbox_ratio
+        result['letterbox_pad'] = letterbox_pad
+        return result

-                target_w = math.ceil(w * resize_ratio)
-                target_h = math.ceil(h * resize_ratio)
-            else:
-                target_w = int(w * resize_ratio)
-                target_h = int(h * resize_ratio)
+    # ...def preprocess_image(...)

-            img_original = cv2.resize(
-                img_original, (target_w, target_h),
-                interpolation=interpolation_method)

+    def generate_detections_one_batch(self,
+                                       img_original,
+                                       image_id=None,
+                                       detection_threshold=0.00001,
+                                       image_size=None,
+                                       augment=False,
+                                       verbose=False):
+        """
+        Run a detector on a batch of images.
+
+        Args:
+            img_original (list): list of images (Image, np.array, or dict) on which we should run the detector, with
+                EXIF rotation already handled, or dicts representing preprocessed images with associated
+                letterbox parameters
+            image_id (list or None): list of paths to identify the images; will be in the "file" field
+                of the output objects. Will be ignored when img_original contains preprocessed dicts.
+            detection_threshold (float, optional): only detections above this confidence threshold
+                will be included in the return value
+            image_size (int, optional): image size (long side) to use for inference, or None to
+                use the default size specified at the time the model was loaded
+            augment (bool, optional): enable (implementation-specific) image augmentation
+            verbose (bool, optional): enable additional debug output
+
+        Returns:
+            list: a list of dictionaries, each with the following fields:
+                - 'file' (filename, always present)
+                - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+                - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+                - 'failure' (a failure string, or None if everything went fine)
+        """
+
+        # Validate inputs
+        if not isinstance(img_original, list):
+            raise ValueError('img_original must be a list for batch processing')
+
+        if len(img_original) == 0:
+            return []
+
+        # Check input consistency
+        if isinstance(img_original[0], dict):
+            # All items in img_original should be preprocessed dicts
+            for i, img in enumerate(img_original):
+                if not isinstance(img, dict):
+                    raise ValueError(f'Mixed input types in batch: item {i} is not a dict, but item 0 is a dict')
+        else:
+            # All items in img_original should be PIL/numpy images, and image_id should be a list of strings
+            if image_id is None:
+                raise ValueError('image_id must be a list when img_original contains PIL/numpy images')
+            if not isinstance(image_id, list):
+                raise ValueError('image_id must be a list for batch processing')
+            if len(image_id) != len(img_original):
+                raise ValueError(
+                    'Length mismatch: img_original has {} items, image_id has {} items'.format(
+                        len(img_original),len(image_id)))
+            for i_img, img in enumerate(img_original):
+                if isinstance(img, dict):
+                    raise ValueError(
+                        'Mixed input types in batch: item {} is a dict, but item 0 is not a dict'.format(
+                            i_img))
+
+        if detection_threshold is None:
+            detection_threshold = 0.0

-            target_shape = image_size
+        batch_size = len(img_original)
+        results = [None] * batch_size

+        # Preprocess all images, handling failures
+        preprocessed_images = []
+        preprocessing_failed_indices = set()
+
+        for i_img, img in enumerate(img_original):
+
+            try:
+                if isinstance(img, dict):
+                    # Already preprocessed
+                    image_info = img
+                    current_image_id = image_info['file']
                 else:
+                    # Need to preprocess
+                    current_image_id = image_id[i_img]
+                    image_info = self.preprocess_image(
+                        img_original=img,
+                        image_id=current_image_id,
+                        image_size=image_size,
+                        verbose=verbose)

-            letterbox_scaleup = False
+                preprocessed_images.append((i_img, image_info, current_image_id))

+            except Exception as e:
+                print('Warning: preprocessing failed for image {}: {}'.format(
+                    image_id[i_img] if image_id else f'index_{i_img}', str(e)))
+
+                preprocessing_failed_indices.add(i_img)
+                current_image_id = image_id[i_img] if image_id else f'index_{i_img}'
+                results[i_img] = {
+                    'file': current_image_id,
+                    'detections': None,
+                    'failure': FAILURE_IMAGE_OPEN
+                }
+
+        # ...for each image in this batch
+
+        # Group preprocessed images by actual processed image shape for batching
+        shape_groups = {}
+        for original_idx, image_info, current_image_id in preprocessed_images:
+            # Use the actual processed image shape for grouping, not target_shape
+            actual_shape = tuple(image_info['img_processed'].shape)
+            if actual_shape not in shape_groups:
+                shape_groups[actual_shape] = []
+            shape_groups[actual_shape].append((original_idx, image_info, current_image_id))
+
+        # Process each shape group as a batch
+        for target_shape, group_items in shape_groups.items():

+            try:
+                self._process_batch_group(group_items, results, detection_threshold, augment, verbose)
+            except Exception as e:
+                # If inference fails for the entire batch, mark all images in this batch as failed
+                print('Warning: batch inference failed for shape {}: {}'.format(target_shape, str(e)))

+                for original_idx, image_info, current_image_id in group_items:
+                    results[original_idx] = {
+                        'file': current_image_id,
+                        'detections': None,
+                        'failure': FAILURE_INFER
+                    }

-                                                          new_shape=target_shape,
-                                                          stride=self.letterbox_stride,
-                                                          auto=letterbox_auto,
-                                                          scaleFill=False,
-                                                          scaleup=letterbox_scaleup)
+        # ...for each shape group
+        return results

+    # ...def generate_detections_one_batch(...)

-        assert 'file' in result
-        result['img_processed'] = img
-        result['img_original'] = img_original
-        result['img_original_pil'] = img_original_pil
-        result['target_shape'] = target_shape
-        result['scaling_shape'] = scaling_shape
-        result['letterbox_ratio'] = letterbox_ratio
-        result['letterbox_pad'] = letterbox_pad
-        return result

+    def _process_batch_group(self, group_items, results, detection_threshold, augment, verbose):
+        """
+        Process a group of images with the same target shape as a single batch.

-        img = np.ascontiguousarray(img)
-        img = torch.from_numpy(img)
-        img = img.to(self.device)
-        img = img.half() if self.half_precision else img.float()
-        img /= 255
-
-        # In practice this is always true
-        if len(img.shape) == 3:
-            img = torch.unsqueeze(img, 0)
-
-        # Run the model
-        pred = self.model(img,augment=augment)[0]
-
-        if 'classic' in self.compatibility_mode:
-            nms_conf_thres = detection_threshold
-            nms_iou_thres = 0.45
-            nms_agnostic = False
-            nms_multi_label = False
-        else:
-            nms_conf_thres = detection_threshold # 0.01
-            nms_iou_thres = 0.6
-            nms_agnostic = False
-            nms_multi_label = True
+        Args:
+            group_items (list): List of (original_idx, image_info, current_image_id) tuples
+            results (list): Results list to populate (modified in place)
+            detection_threshold (float): Detection confidence threshold
+            augment (bool): Enable augmentation
+            verbose (bool): Enable verbose output

-            pred_nms = pred.cpu()
-        else:
-            pred_nms = pred
+        Returns:
+            list of dict: list of dictionaries the same length as group_items, with fields 'file',
+            'detections', 'max_detection_conf'.
+        """

-                conf_thres=nms_conf_thres,
-                iou_thres=nms_iou_thres,
-                agnostic=nms_agnostic,
-                multi_label=nms_multi_label)
+        if len(group_items) == 0:
+            return

+        # Extract batch data
+        batch_images = []
+        batch_metadata = []

+        # For each image in this batch...
+        for original_idx, image_info, current_image_id in group_items:

-            ratio_pad = None
+            img = image_info['img_processed']

+            # Convert HWC to CHW and prepare tensor
+            img_tensor = img.transpose((2, 0, 1))
+            img_tensor = np.ascontiguousarray(img_tensor)
+            img_tensor = torch.from_numpy(img_tensor)
+            batch_images.append(img_tensor)

+            metadata = {
+                'original_idx': original_idx,
+                'current_image_id': current_image_id,
+                'scaling_shape': image_info['scaling_shape'],
+                'letterbox_pad': image_info['letterbox_pad'],
+                'img_original': image_info['img_original']
+            }
+            batch_metadata.append(metadata)

-        # since we're not doing batch inference.
-        #
-        # det = pred[0]
-        #
-        # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
-        # in descending order by confidence.
-        #
-        # Columns are:
-        #
-        # x0,y0,x1,y1,confidence,class
-        #
-        # At this point, these are *non*-normalized values, referring to the size at which we
-        # ran inference (img.shape).
-        for det in pred:
+        # ...for each image in this batch

+        # Stack images into a batch tensor
+        batch_tensor = torch.stack(batch_images)

+        batch_tensor = batch_tensor.float()
+        batch_tensor /= 255.0
+
+        batch_tensor = batch_tensor.to(self.device)
+        if self.half_precision:
+            batch_tensor = batch_tensor.half()
+
+        # Run the model on the batch
+        pred = self.model(batch_tensor, augment=augment)[0]
+
+        # Configure NMS parameters
+        if 'classic' in self.compatibility_mode:
+            nms_iou_thres = 0.45
+        else:
+            nms_iou_thres = 0.6
+
+        pred = nms(prediction=pred,
+                   conf_thres=detection_threshold,
+                   iou_thres=nms_iou_thres)
+
+        # For posterity, the ultralytics implementation
+        if False:
+            pred = non_max_suppression(prediction=pred,
+                                       conf_thres=detection_threshold,
+                                       iou_thres=nms_iou_thres,
+                                       agnostic=False,
+                                       multi_label=False)

+        assert isinstance(pred, list)
+        assert len(pred) == len(batch_metadata), \
+            print('Mismatch between prediction length {} and batch size {}'.format(
+                len(pred),len(batch_metadata)))

+        # Process each image's detections
+        for i_image, det in enumerate(pred):
+
+            metadata = batch_metadata[i_image]
+            original_idx = metadata['original_idx']
+            current_image_id = metadata['current_image_id']
+            scaling_shape = metadata['scaling_shape']
+            letterbox_pad = metadata['letterbox_pad']
+            img_original = metadata['img_original']
+
+            detections = []
+            max_conf = 0.0
+
+            if len(det) > 0:
+
+                # Prepare scaling parameters
+                gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
+
+                if 'classic' in self.compatibility_mode:
+                    ratio = None
+                    ratio_pad = None
                 else:
+                    ratio = (img_original.shape[0]/scaling_shape[0],
+                             img_original.shape[1]/scaling_shape[1])
+                    ratio_pad = (ratio, letterbox_pad)

-                #
+                # Rescale boxes
+                if 'classic' in self.compatibility_mode:
+                    det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], img_original.shape).round()
+                else:
+                    det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()

+                # Process each detection
+                for *xyxy, conf, cls in reversed(det):
                     if conf < detection_threshold:
                         continue

-                    # Convert
+                    # Convert to YOLO format then to MD format
                     xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
-
-                    # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
-                    # left/top/w/h (i.e., MD format)
                     api_box = ct_utils.convert_yolo_to_xywh(xywh)

                     if 'classic' in self.compatibility_mode:
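Taken together, preprocess_image(), generate_detections_one_batch(), and _process_batch_group() split preprocessing from inference, so letterboxing can run on worker processes and images with identical processed shapes can be batched. A hedged usage sketch, with file names invented for illustration and 'detector' assumed to be an already-constructed PTDetector:

    from PIL import Image

    paths = ['img1.jpg', 'img2.jpg']
    images = [Image.open(p) for p in paths]

    # Option 1: pass PIL images plus one ID per image
    results = detector.generate_detections_one_batch(
        img_original=images, image_id=paths, detection_threshold=0.005)

    # Option 2: preprocess first (e.g., on a worker), then batch the resulting dicts;
    # image_id is ignored here because each dict already carries its 'file' field
    preprocessed = [detector.preprocess_image(im, image_id=p) for im, p in zip(images, paths)]
    results = detector.generate_detections_one_batch(img_original=preprocessed)

    for r in results:
        print(r['file'], r['failure'] if r['detections'] is None else len(r['detections']))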
@@ -1074,8 +1349,6 @@ class PTDetector:
                     conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)

                     if not self.use_model_native_classes:
-                        # The MegaDetector output format's categories start at 1, but all YOLO-based
-                        # MD models have category numbers starting at 0.
                         cls = int(cls.tolist()) + 1
                         if cls not in (1, 2, 3):
                             raise KeyError(f'{cls} is not a valid class.')
@@ -1089,47 +1362,73 @@ class PTDetector:
                     })
                     max_conf = max(max_conf, conf)

-                # ...for each detection
-
-            # ...for each detection batch (always one iteration)
+                # ...for each detection

+            # ...if there are > 0 detections

-        result['max_detection_conf'] = max_conf
-        result['detections'] = detections
-
-        return result
-
-    # ...def generate_detections_one_image(...)
-
-    # ...class PTDetector
+            # Store result for this image
+            results[original_idx] = {
+                'file': current_image_id,
+                'detections': detections,
+                'max_detection_conf': max_conf
+            }

+        # ...for each image

+    # ...def _process_batch_group(...)

+    def generate_detections_one_image(self,
+                                       img_original,
+                                       image_id='unknown',
+                                       detection_threshold=0.00001,
+                                       image_size=None,
+                                       augment=False,
+                                       verbose=False):
+        """
+        Run a detector on an image (wrapper around batch function).

+        Args:
+            img_original (Image, np.array, or dict): the image on which we should run the detector, with
+                EXIF rotation already handled, or a dict representing a preprocessed image with associated
+                letterbox parameters
+            image_id (str, optional): a path to identify the image; will be in the "file" field
+                of the output object
+            detection_threshold (float, optional): only detections above this confidence threshold
+                will be included in the return value
+            image_size (int, optional): image size (long side) to use for inference, or None to
+                use the default size specified at the time the model was loaded
+            augment (bool, optional): enable (implementation-specific) image augmentation
+            verbose (bool, optional): enable additional debug output

+        Returns:
+            dict: a dictionary with the following fields:
+                - 'file' (filename, always present)
+                - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+                - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+                - 'failure' (a failure string, or None if everything went fine)
+        """

+        # Prepare batch inputs
+        if isinstance(img_original, dict):
+            batch_results = self.generate_detections_one_batch(
+                img_original=[img_original],
+                image_id=None,
+                detection_threshold=detection_threshold,
+                image_size=image_size,
+                augment=augment,
+                verbose=verbose)
+        else:
+            batch_results = self.generate_detections_one_batch(
+                img_original=[img_original],
+                image_id=[image_id],
+                detection_threshold=detection_threshold,
+                image_size=image_size,
+                augment=augment,
+                verbose=verbose)

+        # Return the single result
+        return batch_results[0]

-        image = vis_utils.load_image(im_file)
+    # ...def generate_detections_one_image(...)

-        print(res)
+    # ...class PTDetector