megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1451 @@
"""

pytorch_detector.py

Module to run YOLO-based MegaDetector models.

"""

#%% Imports and constants

import os
import sys
import math
import zipfile
import tempfile
import shutil
import uuid
import json
import inspect

import cv2
import torch
import numpy as np

from megadetector.detection.run_detector import \
    CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, FAILURE_IMAGE_OPEN, \
    get_detector_version_from_model_file, \
    known_models
from megadetector.utils.ct_utils import parse_bool_string
from megadetector.utils.ct_utils import is_running_in_gha
from megadetector.utils import ct_utils
import torchvision

# We support a few ways of accessing the YOLOv5 dependencies:
#
# * The standard configuration as of 2023.09 expects that the YOLOv5 repo is checked
#   out and on the PYTHONPATH (import utils)
#
# * Supported but non-default (used for PyPI packaging):
#
#   pip install ultralytics-yolov5
#
# * Works, but not supported:
#
#   pip install yolov5
#
# * Unfinished:
#
#   pip install ultralytics

yolo_model_type_imported = None

def _get_model_type_for_model(model_file,
                              prefer_model_type_source='table',
                              default_model_type='yolov5',
                              verbose=False):
    """
    Determine the model type (i.e., the inference library we need to use) for a .pt file.

    Args:
        model_file (str): the model file to read
        prefer_model_type_source (str, optional): how should we handle the (very unlikely)
            case where the metadata in the file indicates one model type, but the global model
            type table says something else. Should be "table" (trust the table) or "file"
            (trust the file).
        default_model_type (str, optional): return value for the case where we can't find
            appropriate metadata in the file or in the global table.
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type indicated for this model
    """

    model_info = read_metadata_from_megadetector_model_file(model_file)

    # Check whether the model file itself specified a model type
    model_type_from_model_file_metadata = None

    if model_info is not None and 'model_type' in model_info:
        model_type_from_model_file_metadata = model_info['model_type']
        if verbose:
            print('Parsed model type {} from model {}'.format(
                model_type_from_model_file_metadata,
                model_file))

    model_type_from_model_version = None

    # Check whether this is a known model version with a specific model type
    model_version_from_file = get_detector_version_from_model_file(model_file)

    if model_version_from_file is not None and model_version_from_file in known_models:
        model_info = known_models[model_version_from_file]
        if 'model_type' in model_info:
            model_type_from_model_version = model_info['model_type']
            if verbose:
                print('Parsed model type {} from global metadata'.format(model_type_from_model_version))
        else:
            model_type_from_model_version = None

    if model_type_from_model_file_metadata is None and \
       model_type_from_model_version is None:
        if verbose:
            print('Could not determine model type for {}, assuming {}'.format(
                model_file,default_model_type))
        model_type = default_model_type

    elif model_type_from_model_file_metadata is not None and \
         model_type_from_model_version is not None:
        if model_type_from_model_version == model_type_from_model_file_metadata:
            model_type = model_type_from_model_file_metadata
        else:
            print('Warning: model type from model version is {}, from file metadata is {}'.format(
                model_type_from_model_version,model_type_from_model_file_metadata))
            if prefer_model_type_source == 'table':
                model_type = model_type_from_model_version
            else:
                model_type = model_type_from_model_file_metadata

    elif model_type_from_model_file_metadata is not None:

        model_type = model_type_from_model_file_metadata

    elif model_type_from_model_version is not None:

        model_type = model_type_from_model_version

    return model_type

# ...def _get_model_type_for_model(...)
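
# Illustrative usage sketch (hypothetical filename): for a model file carrying
# embedded metadata, this returns the embedded/table value; for MDv5a/MDv5b,
# which carry no metadata, it falls back to the global table or to
# default_model_type, typically one of the three supported families.
#
#   model_type = _get_model_type_for_model('my_model.pt', verbose=True)
#   assert model_type in ('yolov5', 'yolov9', 'ultralytics')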


def _initialize_yolo_imports_for_model(model_file,
                                       prefer_model_type_source='table',
                                       default_model_type='yolov5',
                                       detector_options=None,
                                       verbose=False):
    """
    Initialize the appropriate YOLO imports for a model file.

    Args:
        model_file (str): The model file for which we're loading support
        prefer_model_type_source (str, optional): how should we handle the (very unlikely)
            case where the metadata in the file indicates one model type, but the global model
            type table says something else. Should be "table" (trust the table) or "file"
            (trust the file).
        default_model_type (str, optional): return value for the case where we can't find
            appropriate metadata in the file or in the global table.
        detector_options (dict, optional): dictionary of detector options that mean
            different things to different models
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type for which we initialized support
    """

    global yolo_model_type_imported

    if detector_options is not None and 'model_type' in detector_options:
        model_type = detector_options['model_type']
        print('Model type {} provided in detector options'.format(model_type))
    else:
        model_type = _get_model_type_for_model(model_file,
                                               prefer_model_type_source=prefer_model_type_source,
                                               default_model_type=default_model_type)

    if yolo_model_type_imported is not None:
        if model_type == yolo_model_type_imported:
            print('Bypassing imports for model type {}'.format(model_type))
            return model_type
        else:
            print('Previously set up imports for model type {}, re-importing as {}'.format(
                yolo_model_type_imported,model_type))

    _initialize_yolo_imports(model_type,verbose=verbose)

    return model_type


def _clean_yolo_imports(verbose=False, aggressive_cleanup=False):
    """
    Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
    of another YOLO library version. The reason we jump through all these hoops, rather than
    just, e.g., handling different libraries in different modules, is that we need to make sure
    *pickle* sees the right version of modules during module loading, including modules we don't
    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
    do that is to remove all the "wrong" versions from sys.modules and sys.path.

    Args:
        verbose (bool, optional): enable additional debug output
        aggressive_cleanup (bool, optional): err on the side of removing modules,
            at least by ignoring whether they are/aren't in a site-packages folder.
            By default, only modules in a folder that includes "site-packages" will
            be considered for unloading.
    """

    modules_to_delete = []

    for module_name in sys.modules.keys():

        module = sys.modules[module_name]
        if not hasattr(module,'__file__') or (module.__file__ is None):
            continue
        try:
            module_file = module.__file__.replace('\\','/')
            if not aggressive_cleanup:
                if 'site-packages' not in module_file:
                    continue
            tokens = module_file.split('/')

            # For local path imports, a module filename that should be unloaded might
            # look like:
            #
            # c:/git/yolov9/models/common.py
            #
            # For pip imports, a module filename that should be unloaded might look like:
            #
            # c:/users/user/miniforge3/envs/megadetector/lib/site-packages/yolov9/utils/__init__.py
            first_token_to_check = len(tokens) - 4
            for i_token,token in enumerate(tokens):
                if i_token < first_token_to_check:
                    continue
                # Don't remove anything based on the environment name, which
                # always follows "envs" in the path
                if (i_token > 1) and (tokens[i_token-1] == 'envs'):
                    continue
                if ('yolov5' in token) or ('yolov9' in token) or ('ultralytics' in token):
                    if verbose:
                        print('Module {} ({}) looks deletable'.format(module_name,module_file))
                    modules_to_delete.append(module_name)
                    break
        except Exception as e:
            if verbose:
                print('Exception during module review: {}'.format(str(e)))
            pass

    # ...for each module in the global namespace

    for module_name in modules_to_delete:

        if module_name in sys.modules.keys():
            if verbose:
                try:
                    module = sys.modules[module_name]
                    module_file = module.__file__.replace('\\','/')
                    print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
                except Exception:
                    pass
            del sys.modules[module_name]

    # ...for each module we want to remove from the global namespace

    paths_to_delete = []

    for p in sys.path:
        if p.endswith('yolov5') or p.endswith('yolov9') or p.endswith('ultralytics'):
            print('clean_yolo_imports: removing {} from path'.format(p))
            paths_to_delete.append(p)

    for p in paths_to_delete:
        sys.path.remove(p)

# ...def _clean_yolo_imports(...)
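
# Minimal sketch of the scenario this function exists for (assuming both
# support libraries are installed): running a yolov5-family model and then an
# ultralytics-family model in the same process requires unloading one set of
# modules before importing the other, which _initialize_yolo_imports() does by
# calling _clean_yolo_imports() when the requested model type changes.
#
#   _initialize_yolo_imports('yolov5')
#   _initialize_yolo_imports('ultralytics')  # internally calls _clean_yolo_imports()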


def _initialize_yolo_imports(model_type='yolov5',
                             allow_fallback_import=True,
                             force_reimport=False,
                             verbose=False):
    """
    Imports required functions from one or more yolo libraries (yolov5, yolov9,
    ultralytics, targeting support for [model_type]).

    Args:
        model_type (str): The model type for which we're loading support
        allow_fallback_import (bool, optional): If we can't import from the package for
            which we're trying to load support, fall back to "import utils". This is
            typically used when the right support library is on the current PYTHONPATH.
        force_reimport (bool, optional): import the appropriate libraries even if the
            requested model type matches the current initialization state
        verbose (bool, optional): include additional debug output

    Returns:
        str: the model type for which we initialized support
    """

    # When running in pytest, the megadetector 'utils' module is put in the global
    # namespace, which creates conflicts with yolov5; remove it from the global
    # namespace.
    if ('PYTEST_CURRENT_TEST' in os.environ):
        print('*** pytest detected ***')
        if ('utils' in sys.modules):
            utils_module = sys.modules['utils']
            if hasattr(utils_module, '__file__') and 'megadetector' in str(utils_module.__file__):
                print(f"Removing conflicting utils module: {utils_module.__file__}")
                sys.modules.pop('utils', None)
                # Also remove any submodules
                to_remove = [name for name in sys.modules if name.startswith('utils.')]
                for name in to_remove:
                    sys.modules.pop(name, None)

    global yolo_model_type_imported

    if model_type is None:
        model_type = 'yolov5'

    # The point of this function is to make the appropriate version
    # of the following functions available at module scope
    global non_max_suppression
    global xyxy2xywh
    global letterbox
    global scale_coords

    if yolo_model_type_imported is not None:
        if (yolo_model_type_imported == model_type) and (not force_reimport):
            print('Bypassing imports for YOLO model type {}'.format(model_type))
            return model_type
        else:
            _clean_yolo_imports()

    try_yolov5_import = (model_type == 'yolov5')
    try_yolov9_import = (model_type == 'yolov9')
    try_ultralytics_import = (model_type == 'ultralytics')

    utils_imported = False

    # First try importing from the yolov5 package; this is how the pip
    # package finds YOLOv5 utilities.
    if try_yolov5_import and not utils_imported:

        try:
            # from yolov5.utils.general import non_max_suppression # type: ignore
            from yolov5.utils.general import xyxy2xywh # noqa
            from yolov5.utils.augmentations import letterbox # noqa
            try:
                from yolov5.utils.general import scale_boxes as scale_coords
            except Exception:
                from yolov5.utils.general import scale_coords
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv5 package')

        except Exception as e: # noqa
            # print('yolov5 module import failed: {}'.format(e))
            # print(traceback.format_exc())
            pass

    # Next try importing from the yolov9 package
    if try_yolov9_import and not utils_imported:

        try:

            # from yolov9.utils.general import non_max_suppression # noqa
            from yolov9.utils.general import xyxy2xywh # noqa
            from yolov9.utils.augmentations import letterbox # noqa
            from yolov9.utils.general import scale_boxes as scale_coords # noqa
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv9 package')

        except Exception as e: # noqa

            # print('yolov9 module import failed: {}'.format(e))
            # print(traceback.format_exc())
            pass

    # If we haven't succeeded yet, import from the ultralytics package
    if try_ultralytics_import and not utils_imported:

        try:

            import ultralytics # type: ignore # noqa

        except Exception:

            print('It looks like you are trying to run a model that requires the ultralytics package, '
                  'but the ultralytics package is not installed. For licensing reasons, this '
                  'is not installed by default with the MegaDetector Python package. Run '
                  '"pip install ultralytics" to install it, and try again.')
            raise

        try:

            # The non_max_suppression() function moved from the ops module to the nms module
            # in mid-2025
            try:
                from ultralytics.utils.ops import non_max_suppression # type: ignore # noqa
            except Exception:
                from ultralytics.utils.nms import non_max_suppression # type: ignore # noqa
            from ultralytics.utils.ops import xyxy2xywh # type: ignore # noqa

            # In the ultralytics package, scale_boxes and scale_coords both exist;
            # we want scale_boxes.
            #
            # from ultralytics.utils.ops import scale_coords # noqa
            from ultralytics.utils.ops import scale_boxes as scale_coords # type: ignore # noqa
            from ultralytics.data.augment import LetterBox # type: ignore # noqa

            # letterbox() became a LetterBox class in the ultralytics package. Create a
            # backwards-compatible letterbox() function that wraps the class.
            def letterbox(img,new_shape,auto=False,scaleFill=False, #noqa
                          scaleup=True,center=True,stride=32):

                # Ultralytics changed the "scaleFill" parameter to "scale_fill"; we want to
                # support both conventions.
                use_old_scalefill_arg = False
                try:
                    sig = inspect.signature(LetterBox.__init__)
                    if 'scaleFill' in sig.parameters:
                        use_old_scalefill_arg = True
                except Exception:
                    pass

                if use_old_scalefill_arg:
                    if verbose:
                        print('Using old scaleFill calling convention')
                    letterbox_transformer = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,
                                                      scaleup=scaleup,center=center,stride=stride)
                else:
                    letterbox_transformer = LetterBox(new_shape,auto=auto,scale_fill=scaleFill,
                                                      scaleup=scaleup,center=center,stride=stride)

                letterbox_result = letterbox_transformer(image=img)

                if isinstance(new_shape,int):
                    new_shape = [new_shape,new_shape]

                # The letterboxing is done, we just need to reverse-engineer what it did
                shape = img.shape[:2]

                r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
                if not scaleup:
                    r = min(r, 1.0)
                ratio = r, r

                new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
                dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
                if auto:
                    dw, dh = np.mod(dw, stride), np.mod(dh, stride)
                elif scaleFill:
                    dw, dh = 0.0, 0.0
                    new_unpad = (new_shape[1], new_shape[0])
                    ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])

                dw /= 2
                dh /= 2
                pad = (dw,dh)

                return [letterbox_result,ratio,pad]

            utils_imported = True
            if verbose:
                print('Imported utils from ultralytics package')

        except Exception as e:

            print('Ultralytics module import failed: {}'.format(str(e)))
            pass

    # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
    if (not utils_imported) and allow_fallback_import:

        try:

            # Import pre- and post-processing functions from the YOLOv5 repo
            # from utils.general import non_max_suppression # type: ignore
            from utils.general import xyxy2xywh # type: ignore
            from utils.augmentations import letterbox # type: ignore

            # scale_coords() is scale_boxes() in some YOLOv5 versions
            try:
                from utils.general import scale_coords # type: ignore
            except ImportError:
                from utils.general import scale_boxes as scale_coords # type: ignore
            utils_imported = True
            imported_file = sys.modules[scale_coords.__module__].__file__
            if verbose:
                print('Imported utils from {}'.format(imported_file))

        except ModuleNotFoundError as e:

            raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e)))

    assert utils_imported, 'YOLO utils import error'

    yolo_model_type_imported = model_type
    if verbose:
        print('Prepared YOLO imports for model type {}'.format(model_type))

    return model_type

# ...def _initialize_yolo_imports(...)
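
# Whichever library the imports above came from, the module-scope letterbox()
# now follows the YOLOv5 return convention. A minimal sketch (synthetic image;
# the exact padded shape depends on 'auto' and 'stride'):
#
#   img = np.zeros((3000, 4000, 3), dtype=np.uint8)
#   img_lb, ratio, pad = letterbox(img, new_shape=1280, auto=False, stride=64)
#   # img_lb is the resized/padded array, ratio is the per-axis resize ratio,
#   # and pad is the (dw, dh) padding added on each side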


#%% NMS

def nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
    """
    Non-maximum suppression (a wrapper around torchvision.ops.nms())

    Args:
        prediction (torch.Tensor): Model predictions with shape [batch_size, num_anchors, num_classes + 5]
            Format: [x_center, y_center, width, height, objectness, class1_conf, class2_conf, ...]
            Coordinates are in pixels relative to the model's input size.
        conf_thres (float): Confidence threshold for filtering detections
        iou_thres (float): IoU threshold for NMS
        max_det (int): Maximum number of detections per image

    Returns:
        list: List of tensors, one per image in batch. Each tensor has shape [N, 6] where:
            - N is the number of detections for that image
            - Columns are [x1, y1, x2, y2, confidence, class_id]
            - Coordinates are in absolute pixels relative to input image size
            - class_id is the integer class index (0-based)
    """

    batch_size = prediction.shape[0]
    num_classes = prediction.shape[2] - 5 # noqa
    output = []

    # Process each image in the batch
    for img_idx in range(batch_size):

        x = prediction[img_idx]  # Shape: [num_anchors, num_classes + 5]

        # Filter by objectness confidence
        obj_conf = x[:, 4]
        valid_detections = obj_conf > conf_thres
        x = x[valid_detections]

        if x.shape[0] == 0:
            # No detections for this image
            output.append(torch.zeros((0, 6), device=prediction.device))
            continue

        # Convert box coordinates from [x_center, y_center, w, h] to [x1, y1, x2, y2]
        box = x[:, :4].clone()
        box[:, 0] = x[:, 0] - x[:, 2] / 2.0  # x1 = center_x - width/2
        box[:, 1] = x[:, 1] - x[:, 3] / 2.0  # y1 = center_y - height/2
        box[:, 2] = x[:, 0] + x[:, 2] / 2.0  # x2 = center_x + width/2
        box[:, 3] = x[:, 1] + x[:, 3] / 2.0  # y2 = center_y + height/2

        # Get class predictions: multiply objectness by class probabilities
        class_conf = x[:, 5:] * x[:, 4:5]  # shape: [N, num_classes]

        # For each detection, take the class with highest confidence (single-label)
        best_class_conf, best_class_idx = class_conf.max(1, keepdim=True)

        # Filter by class confidence threshold
        conf_mask = best_class_conf.view(-1) > conf_thres
        if conf_mask.sum() == 0:
            # No detections pass confidence threshold
            output.append(torch.zeros((0, 6), device=prediction.device))
            continue

        box = box[conf_mask]
        best_class_conf = best_class_conf[conf_mask]
        best_class_idx = best_class_idx[conf_mask]

        # Prepare for NMS: group detections by class
        unique_classes = best_class_idx.unique()
        final_detections = []

        for class_id in unique_classes:

            class_mask = (best_class_idx == class_id).view(-1)
            class_boxes = box[class_mask]
            class_scores = best_class_conf[class_mask].view(-1)

            if class_boxes.shape[0] == 0:
                continue

            # Apply NMS for this class
            keep_indices = torchvision.ops.nms(class_boxes, class_scores, iou_thres)

            if len(keep_indices) > 0:
                kept_boxes = class_boxes[keep_indices]
                kept_scores = class_scores[keep_indices]
                kept_classes = torch.full((len(keep_indices), 1), class_id.item(),
                                          device=prediction.device, dtype=torch.float)

                # Combine: [x1, y1, x2, y2, conf, class]
                class_detections = torch.cat([kept_boxes, kept_scores.unsqueeze(1), kept_classes], 1)
                final_detections.append(class_detections)

        # ...for each category

        if final_detections:

            # Combine all classes and sort by confidence
            all_detections = torch.cat(final_detections, 0)
            conf_sort_indices = all_detections[:, 4].argsort(descending=True)
            all_detections = all_detections[conf_sort_indices]

            # Limit to max_det
            if all_detections.shape[0] > max_det:
                all_detections = all_detections[:max_det]

            output.append(all_detections)
        else:
            output.append(torch.zeros((0, 6), device=prediction.device))

    # ...for each image in the batch

    return output

# ...def nms(...)
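
# Minimal self-contained sketch of the nms() contract, using a random
# prediction tensor shaped like a 3-class YOLO output:
#
#   pred = torch.rand(2, 100, 8)  # batch of 2, 100 candidate boxes, 3 classes + 5
#   pred[:, :, 2:4] += 1.0        # make sure widths/heights are positive
#   results = nms(pred, conf_thres=0.25, iou_thres=0.45)
#   assert len(results) == 2
#   for per_image in results:
#       # each row is [x1, y1, x2, y2, confidence, class_id]
#       assert per_image.shape[1] == 6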


#%% Model metadata functions

def add_metadata_to_megadetector_model_file(model_file_in,
                                            model_file_out,
                                            metadata,
                                            destination_path='megadetector_info.json'):
    """
    Adds a .json file to the specified MegaDetector model file containing metadata used
    by this module. Always overwrites the output file.

    Args:
        model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
        model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
            May be the same as model_file_in.
        metadata (dict): The metadata dict to add to the output model file
        destination_path (str, optional): The relative path within the main folder of the
            model archive where we should write the metadata. This is not relative to the root
            of the archive, it's relative to the one and only folder at the root of the archive
            (this is a PyTorch convention).
    """

    tmp_base = os.path.join(tempfile.gettempdir(),'md_metadata')
    os.makedirs(tmp_base,exist_ok=True)
    metadata_tmp_file_relative = 'megadetector_info_' + str(uuid.uuid1()) + '.json'
    metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)

    with open(metadata_tmp_file_abs,'w') as f:
        json.dump(metadata,f,indent=1)

    # Copy the input file to the output file
    shutil.copyfile(model_file_in,model_file_out)

    # Write metadata to the output file
    with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:

        # Torch doesn't like anything in the root folder of the zipfile, so we put
        # it in the one and only folder.
        names = zipf.namelist()
        root_folders = set()
        for name in names:
            root_folder = name.split('/')[0]
            root_folders.add(root_folder)
        assert len(root_folders) == 1,\
            'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
        root_folder = next(iter(root_folders))

        zipf.write(metadata_tmp_file_abs,
                   root_folder + '/' + destination_path,
                   compresslevel=9,
                   compress_type=zipfile.ZIP_DEFLATED)

    try:
        os.remove(metadata_tmp_file_abs)
    except Exception as e:
        print('Warning: error deleting file {}: {}'.format(metadata_tmp_file_abs,str(e)))

# ...def add_metadata_to_megadetector_model_file(...)


def read_metadata_from_megadetector_model_file(model_file,
                                               relative_path='megadetector_info.json',
                                               verbose=False):
    """
    Reads custom MegaDetector metadata from a modified MegaDetector model file.

    Args:
        model_file (str): The model filename to read, typically .pt (.zip is also sensible)
        relative_path (str, optional): The relative path within the main folder of the model
            archive from which we should read the metadata. This is not relative to the root
            of the archive, it's relative to the one and only folder at the root of the archive
            (this is a PyTorch convention).
        verbose (bool, optional): enable additional debug output

    Returns:
        object: whatever we read from the metadata file, always a dict in practice. Returns
            None if we failed to read the specified metadata file.
    """

    with zipfile.ZipFile(model_file,'r') as zipf:

        # Torch doesn't like anything in the root folder of the zipfile, so we put
        # it in the one and only folder.
        names = zipf.namelist()
        root_folders = set()
        for name in names:
            root_folder = name.split('/')[0]
            root_folders.add(root_folder)
        if len(root_folders) != 1:
            print('Warning: this archive does not have exactly one folder at the top level; ' + \
                  'are you sure it\'s a Torch model file?')
            return None
        root_folder = next(iter(root_folders))

        metadata_file = root_folder + '/' + relative_path
        if metadata_file not in names:
            # This is the case for MDv5a and MDv5b
            if verbose:
                print('Warning: could not find metadata file {} in zip archive {}'.format(
                    metadata_file,os.path.basename(model_file)))
            return None

        try:
            path = zipfile.Path(zipf,metadata_file)
            contents = path.read_text()
            d = json.loads(contents)
        except Exception as e:
            print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
            return None

        return d

    # ...with zipfile.ZipFile(...)

# ...def read_metadata_from_megadetector_model_file(...)
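
# Round-trip sketch for the two functions above (hypothetical filenames; the
# input must be a real Torch zip archive with a single top-level folder):
#
#   metadata = {'image_size': 1280, 'model_type': 'yolov9'}
#   add_metadata_to_megadetector_model_file('md_in.pt', 'md_out.pt', metadata)
#   assert read_metadata_from_megadetector_model_file('md_out.pt') == metadata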


#%% Inference classes

default_compatibility_mode = 'classic'

# This is a useful hack when I want to verify that my test driver (md_tests.py) is
# correctly forcing a specific compatibility mode (I use "classic-test" in that case)
require_non_default_compatibility_mode = False

class PTDetector:
    """
    Class that runs a PyTorch-based MegaDetector model. Also used as a preprocessor
    for images that will later be run through an instance of PTDetector.
    """

    def __init__(self, model_path, detector_options=None, verbose=False):
        """
        PTDetector constructor. If detector_options['preprocess_only'] exists and is
        True, this instance is being used as a preprocessor, so we don't load model weights.
        """

        if verbose:
            print('Initializing PTDetector (verbose)')

        # Set up the import environment for this model, unloading previous
        # YOLO library versions if necessary.
        _initialize_yolo_imports_for_model(model_path,
                                           detector_options=detector_options,
                                           verbose=verbose)

        # Parse options specific to this detector family
        force_cpu = False
        use_model_native_classes = False
        compatibility_mode = default_compatibility_mode

        if detector_options is not None:

            if 'force_cpu' in detector_options:
                force_cpu = parse_bool_string(detector_options['force_cpu'])
            if 'use_model_native_classes' in detector_options:
                use_model_native_classes = parse_bool_string(detector_options['use_model_native_classes'])
            if 'compatibility_mode' in detector_options:
                if detector_options['compatibility_mode'] is None:
                    compatibility_mode = default_compatibility_mode
                else:
                    compatibility_mode = detector_options['compatibility_mode']

        # This is a global option used only during testing, to make sure I'm hitting
        # the cases where we are not using "classic" preprocessing.
        if require_non_default_compatibility_mode:

            print('### DEBUG: requiring non-default compatibility mode ###')
            assert compatibility_mode != 'classic'
            assert compatibility_mode != 'default'

        preprocess_only = False
        if (detector_options is not None) and \
           ('preprocess_only' in detector_options) and \
           (detector_options['preprocess_only']):
            preprocess_only = True

        if verbose or (not preprocess_only):
            print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))

        self.model_metadata = read_metadata_from_megadetector_model_file(model_path)

        #: Image size passed to the letterbox() function; 1280 means "1280 on the long side,
        #: preserving aspect ratio".
        if self.model_metadata is not None and 'image_size' in self.model_metadata:
            self.default_image_size = self.model_metadata['image_size']
            print('Loaded image size {} from model metadata'.format(self.default_image_size))
        else:
            # This is not the default for most YOLO models, but most of the time, if someone
            # is loading a model here that does not have metadata, it's MDv5[ab].0.0
            print('No image size available in model metadata, defaulting to 1280')
            self.default_image_size = 1280

        #: Either a string ('cpu','cuda:0') or a torch.device()
        self.device = 'cpu'

        #: Have we already printed a warning about using a non-standard image size?
        #:
        #: :meta private:
        self.printed_image_size_warning = False

        #: If this is False, we assume the underlying model is producing class indices in the
        #: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
        #: MD classes (1,2,3) before generating output. If this is True, we use whatever
        #: indices the model provides.
        self.use_model_native_classes = use_model_native_classes

        #: This allows us to maintain backwards compatibility across a set of changes to the
        #: way this class does inference. Currently should start with either "default" or
        #: "classic".
        self.compatibility_mode = compatibility_mode

        #: Stride size passed to the YOLO letterbox() function
        self.letterbox_stride = 32

        # This is a convenient heuristic to determine the stride size without actually loading
        # the model: the only models in the YOLO family with a stride size of 64 are the
        # YOLOv5*6 and YOLOv5*6u models, which are 1280px models.
        #
        # See:
        #
        # github.com/ultralytics/ultralytics/issues/21544
        #
        # Note to self, though, if I decide later to require loading the model on preprocessing
        # workers so I can more reliably choose a stride, this is the right way to determine the
        # stride:
        #
        # self.letterbox_stride = int(self.model.stride.max())
        if self.default_image_size == 1280:
            self.letterbox_stride = 64

        print('Using model stride: {}'.format(self.letterbox_stride))

        #: Use half-precision inference... fixed by the model, generally don't mess with this
        self.half_precision = False

        if preprocess_only:
            return

        if not force_cpu:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
            try:
                if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                    # MPS inference fails on GitHub runners as of 2025.08. This is
                    # independent of model size. So, we disable MPS when running in GHA.
                    if is_running_in_gha():
                        print('GitHub actions detected, bypassing MPS backend')
                    else:
                        print('Using MPS device')
                        self.device = 'mps'
            except AttributeError:
                pass

        # AddaxAI depends on this printout, don't remove it
        print('PTDetector using device {}'.format(str(self.device).lower()))

        try:
            self.model = PTDetector._load_model(model_path,
                                                device=self.device,
                                                compatibility_mode=self.compatibility_mode)

        except Exception as e:
            # In a very esoteric scenario where an old version of YOLOv5 is used to run
            # newer models, we run into an issue because the "Model" class became
            # "DetectionModel". New YOLOv5 code handles this case by just setting them
            # to be the same, so doing that externally doesn't seem *that* rude.
            if "Can't get attribute 'DetectionModel'" in str(e):
                print('Forward-compatibility issue detected, patching')
                from models import yolo # type: ignore
                yolo.DetectionModel = yolo.Model
                self.model = PTDetector._load_model(model_path,
                                                    device=self.device,
                                                    compatibility_mode=self.compatibility_mode,
                                                    verbose=verbose)
            else:
                raise

        if (self.device != 'cpu'):
            if verbose:
                print('Sending model to GPU')
            self.model.to(self.device)


    @staticmethod
    def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):

        if verbose:
            print(f'Using PyTorch version {torch.__version__}')

        # I get quirky errors when loading YOLOv5 models on MPS hardware using
        # map_location, but this is the recommended method, so I'm using it everywhere
        # other than MPS devices.
        use_map_location = (device != 'mps')

        if use_map_location:
            try:
                checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
            # For a transitional period, we want to support torch 1.1x, where the weights_only
            # parameter doesn't exist
            except Exception as e:
                if "'weights_only' is an invalid keyword" in str(e):
                    checkpoint = torch.load(model_pt_path, map_location=device)
                else:
                    raise
        else:
            try:
                checkpoint = torch.load(model_pt_path, weights_only=False)
            # For a transitional period, we want to support torch 1.1x, where the weights_only
            # parameter doesn't exist
            except Exception as e:
                if "'weights_only' is an invalid keyword" in str(e):
                    checkpoint = torch.load(model_pt_path)
                else:
                    raise

        # Compatibility fix that allows us to load older YOLOv5 models with
        # newer versions of YOLOv5/PT
        for m in checkpoint['model'].modules():
            t = type(m)
            if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
                m.recompute_scale_factor = None

        # Calling .to(device) should no longer be necessary now that we're using map_location=device
        # model = checkpoint['model'].float().fuse().eval().to(device)
        model = checkpoint['model'].float().fuse().eval()

        return model

    # ...def _load_model(...)
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
def preprocess_image(self,
|
|
940
|
+
img_original,
|
|
941
|
+
image_id='unknown',
|
|
942
|
+
image_size=None,
|
|
943
|
+
verbose=False):
|
|
944
|
+
"""
|
|
945
|
+
Prepare an image for detection, including scaling and letterboxing.
|
|
946
|
+
|
|
947
|
+
Args:
|
|
948
|
+
img_original (Image or np.array): the image on which we should run the detector, with
|
|
949
|
+
EXIF rotation already handled
|
|
950
|
+
image_id (str, optional): a path to identify the image; will be in the "file" field
|
|
951
|
+
of the output object
|
|
952
|
+
detection_threshold (float, optional): only detections above this confidence threshold
|
|
953
|
+
will be included in the return value
|
|
954
|
+
image_size (int, optional): image size (long side) to use for inference, or None to
|
|
955
|
+
use the default size specified at the time the model was loaded
|
|
956
|
+
verbose (bool, optional): enable additional debug output
|
|
957
|
+
|
|
958
|
+
Returns:
|
|
959
|
+
dict: dict with fields:
|
|
960
|
+
- file (filename)
|
|
961
|
+
- img (the preprocessed np.array)
|
|
962
|
+
- img_original (the input image before preprocessing, as an np.array)
|
|
963
|
+
- img_original_pil (the input image before preprocessing, as a PIL Image)
|
|
964
|
+
- target_shape (the 2D shape to which the image was resized during preprocessing)
|
|
965
|
+
- scaling_shape (the 2D original size, for normalizing coordinates later)
|
|
966
|
+
- letterbox_ratio (letterbox parameter used for normalizing coordinates later)
|
|
967
|
+
- letterbox_pad (letterbox parameter used for normalizing coordinates later)
|
|
968
|
+
"""
|
|
969
|
+
|
|
970
|
+
# Prepare return dict
|
|
971
|
+
result = {'file': image_id }
|
|
972
|
+
|
|
973
|
+
# Store the PIL version of the original image, the caller may want to use
|
|
974
|
+
# it for metadata extraction later.
|
|
975
|
+
img_original_pil = None
|
|
976
|
+
|
|
977
|
+
# If we were given a PIL image, rather than a numpy array
|
|
978
|
+
if not isinstance(img_original,np.ndarray):
|
|
979
|
+
img_original_pil = img_original
|
|
980
|
+
img_original = np.asarray(img_original)
|
|
981
|
+
|
|
982
|
+
# PIL images are RGB already
|
|
983
|
+
# img_original = img_original[:, :, ::-1]
|
|
984
|
+
|
|
985
|
+
# Save the original shape for scaling boxes later
|
|
986
|
+
scaling_shape = img_original.shape
|
|
        # If the caller is requesting a specific target size...
        if image_size is not None:

            assert isinstance(image_size,int)

            if not self.printed_image_size_warning:
                print('Using user-supplied image size {}'.format(image_size))
                self.printed_image_size_warning = True

        # Otherwise resize to self.default_image_size
        else:

            image_size = self.default_image_size
            self.printed_image_size_warning = False

        # ...if the caller has specified an image size

        # In "classic mode", we only do the letterboxing resize, we don't do an
        # additional initial resizing operation
        if 'classic' in self.compatibility_mode:

            resize_ratio = 1.0

        # Resize the image so the long side matches the target image size. This is not
        # letterboxing (i.e., padding) yet, just resizing.
        else:

            use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)

            h,w = img_original.shape[:2]
            resize_ratio = image_size / max(h,w)

            # Only resize if we have to
            if resize_ratio != 1:

                # Match what yolov5 does: use linear interpolation for upsizing;
                # area interpolation for downsizing
                if resize_ratio > 1:
                    interpolation_method = cv2.INTER_LINEAR
                else:
                    interpolation_method = cv2.INTER_AREA

                if use_ceil_for_resize:
                    target_w = math.ceil(w * resize_ratio)
                    target_h = math.ceil(h * resize_ratio)
                else:
                    target_w = int(w * resize_ratio)
                    target_h = int(h * resize_ratio)

                img_original = cv2.resize(
                    img_original, (target_w, target_h),
                    interpolation=interpolation_method)

        if 'classic' in self.compatibility_mode:

            letterbox_auto = True
            letterbox_scaleup = True
            target_shape = image_size

        else:

            letterbox_auto = False
            letterbox_scaleup = False

            # The padding to apply as a fraction of the stride size
            pad = 0.5

            # Resize to a multiple of the model stride
            #
            # This is how we would determine the stride if we knew the model had been loaded:
            #
            # model_stride = int(self.model.stride.max())
            #
            # ...but because we do this on preprocessing workers now, we try to avoid loading the model
            # just for preprocessing, and we assume the stride was determined at the time the PTDetector
            # object was created.
            try:
                model_stride = int(self.model.stride.max())
                if model_stride != self.letterbox_stride:
                    print('*** Warning: model stride is {}, stride at construction time was {} ***'.format(
                        model_stride,self.letterbox_stride
                    ))
            except Exception:
                pass

            model_stride = self.letterbox_stride
            max_dimension = max(img_original.shape)
            normalized_shape = [img_original.shape[0] / max_dimension,
                                img_original.shape[1] / max_dimension]
            target_shape = np.ceil(((np.array(normalized_shape) * image_size) / model_stride) + \
                                   pad).astype(int) * model_stride
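
            # Worked example (illustrative numbers, not from the package): for a
            # 1080x1920 image with image_size=1280 and a stride of 32, max_dimension
            # is 1920 and normalized_shape is [0.5625, 1.0]. Scaling by image_size
            # gives [720.0, 1280.0]; dividing by the stride, adding the half-stride
            # pad, taking the ceiling, and multiplying by the stride again yields a
            # target_shape of [736, 1312], i.e. each side rounded up to a multiple
            # of 32 with a little headroom for padding.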

        # Now we letterbox, which is just padding, since we've already resized
        img,letterbox_ratio,letterbox_pad = letterbox(img_original,
                                                      new_shape=target_shape,
                                                      stride=self.letterbox_stride,
                                                      auto=letterbox_auto,
                                                      scaleFill=False,
                                                      scaleup=letterbox_scaleup)
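
        # In the yolov5-style letterbox() assumed here, letterbox_ratio is the
        # (width, height) scaling that was applied, and letterbox_pad is the
        # per-side (left/right, top/bottom) padding in pixels; both are needed
        # later to map detection coordinates back to the original image.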

        result['img_processed'] = img
        result['img_original'] = img_original
        result['img_original_pil'] = img_original_pil
        result['target_shape'] = target_shape
        result['scaling_shape'] = scaling_shape
        result['letterbox_ratio'] = letterbox_ratio
        result['letterbox_pad'] = letterbox_pad
        return result

    # ...def preprocess_image(...)
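
    # A minimal usage sketch (not part of this module; the image path is a
    # placeholder, and this assumes a PTDetector instance named 'detector' has
    # already been constructed):
    #
    #   from PIL import Image
    #   img = Image.open('IMG_0001.JPG')
    #   info = detector.preprocess_image(img_original=img, image_id='IMG_0001.JPG')
    #   results = detector.generate_detections_one_batch(img_original=[info])
    #
    # The returned dict can be passed to generate_detections_one_batch() or
    # generate_detections_one_image() in place of a raw image.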


    def generate_detections_one_batch(self,
                                      img_original,
                                      image_id=None,
                                      detection_threshold=0.00001,
                                      image_size=None,
                                      augment=False,
                                      verbose=False):
        """
        Run a detector on a batch of images.

        Args:
            img_original (list): list of images on which we should run the detector; each
                element is either an image (PIL Image or np.array, with EXIF rotation
                already handled) or a dict representing a preprocessed image with its
                associated letterbox parameters
            image_id (list or None): list of paths to identify the images; will be in the
                "file" field of the output objects. Ignored when img_original contains
                preprocessed dicts, which carry their own identifiers.
            detection_threshold (float, optional): only detections above this confidence
                threshold will be included in the return value
            image_size (int, optional): image size (long side) to use for inference, or None to
                use the default size specified at the time the model was loaded
            augment (bool, optional): enable (implementation-specific) image augmentation
            verbose (bool, optional): enable additional debug output

        Returns:
            list: a list of dictionaries, each with the following fields:
                - 'file' (filename, always present)
                - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
                - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
                - 'failure' (a failure string; present only when processing failed)
        """

        # Validate inputs
        if not isinstance(img_original, list):
            raise ValueError('img_original must be a list for batch processing')

        if len(img_original) == 0:
            return []

        # Check input consistency
        if isinstance(img_original[0], dict):
            # All items in img_original should be preprocessed dicts
            for i, img in enumerate(img_original):
                if not isinstance(img, dict):
                    raise ValueError(f'Mixed input types in batch: item {i} is not a dict, but item 0 is a dict')
        else:
            # All items in img_original should be PIL/numpy images, and image_id should be a list of strings
            if image_id is None:
                raise ValueError('image_id must be a list when img_original contains PIL/numpy images')
            if not isinstance(image_id, list):
                raise ValueError('image_id must be a list for batch processing')
            if len(image_id) != len(img_original):
                raise ValueError(
                    'Length mismatch: img_original has {} items, image_id has {} items'.format(
                        len(img_original),len(image_id)))
            for i_img, img in enumerate(img_original):
                if isinstance(img, dict):
                    raise ValueError(
                        'Mixed input types in batch: item {} is a dict, but item 0 is not a dict'.format(
                            i_img))

        if detection_threshold is None:
            detection_threshold = 0.0

        batch_size = len(img_original)
        results = [None] * batch_size

        # Preprocess all images, handling failures
        preprocessed_images = []
        preprocessing_failed_indices = set()

        for i_img, img in enumerate(img_original):

            try:
                if isinstance(img, dict):
                    # Already preprocessed
                    image_info = img
                    current_image_id = image_info['file']
                else:
                    # Need to preprocess
                    current_image_id = image_id[i_img]
                    image_info = self.preprocess_image(
                        img_original=img,
                        image_id=current_image_id,
                        image_size=image_size,
                        verbose=verbose)

                preprocessed_images.append((i_img, image_info, current_image_id))

            except Exception as e:
                print('Warning: preprocessing failed for image {}: {}'.format(
                    image_id[i_img] if image_id else f'index_{i_img}', str(e)))

                preprocessing_failed_indices.add(i_img)
                current_image_id = image_id[i_img] if image_id else f'index_{i_img}'
                results[i_img] = {
                    'file': current_image_id,
                    'detections': None,
                    'failure': FAILURE_IMAGE_OPEN
                }

        # ...for each image in this batch

        # Group preprocessed images by actual processed image shape for batching
        shape_groups = {}
        for original_idx, image_info, current_image_id in preprocessed_images:
            # Use the actual processed image shape for grouping, not target_shape
            actual_shape = tuple(image_info['img_processed'].shape)
            if actual_shape not in shape_groups:
                shape_groups[actual_shape] = []
            shape_groups[actual_shape].append((original_idx, image_info, current_image_id))
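
        # e.g., shape_groups might look like this (values are illustrative):
        #
        #   {(736, 1312, 3): [(0, info_0, 'a.jpg'), (2, info_2, 'c.jpg')],
        #    (1312, 736, 3): [(1, info_1, 'b.jpg')]}
        #
        # torch.stack() requires identical tensor shapes, so only images that
        # letterboxed to the same shape can share a forward pass.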

        # Process each shape group as a batch
        for target_shape, group_items in shape_groups.items():

            try:
                self._process_batch_group(group_items, results, detection_threshold, augment, verbose)
            except Exception as e:
                # If inference fails for the entire batch, mark all images in this batch as failed
                print('Warning: batch inference failed for shape {}: {}'.format(target_shape, str(e)))

                for original_idx, image_info, current_image_id in group_items:
                    results[original_idx] = {
                        'file': current_image_id,
                        'detections': None,
                        'failure': FAILURE_INFER
                    }

        # ...for each shape group
        return results

    # ...def generate_detections_one_batch(...)

    def _process_batch_group(self, group_items, results, detection_threshold, augment, verbose):
        """
        Process a group of images with the same target shape as a single batch.

        Args:
            group_items (list): List of (original_idx, image_info, current_image_id) tuples
            results (list): Results list to populate (modified in place)
            detection_threshold (float): Detection confidence threshold
            augment (bool): Enable augmentation
            verbose (bool): Enable verbose output

        Returns:
            None. For each item in group_items, a dict with fields 'file', 'detections',
            and 'max_detection_conf' is written into [results] at that item's original
            index.
        """

        if len(group_items) == 0:
            return

        # Extract batch data
        batch_images = []
        batch_metadata = []

        # For each image in this batch...
        for original_idx, image_info, current_image_id in group_items:

            img = image_info['img_processed']

            # Convert HWC to CHW and prepare tensor
            img_tensor = img.transpose((2, 0, 1))
            img_tensor = np.ascontiguousarray(img_tensor)
            img_tensor = torch.from_numpy(img_tensor)
            batch_images.append(img_tensor)
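
            # At this point each tensor is channel-first; e.g. a letterboxed
            # (736, 1312, 3) image becomes a (3, 736, 1312) tensor (sizes are
            # illustrative).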

            metadata = {
                'original_idx': original_idx,
                'current_image_id': current_image_id,
                'scaling_shape': image_info['scaling_shape'],
                'letterbox_pad': image_info['letterbox_pad'],
                'img_original': image_info['img_original']
            }
            batch_metadata.append(metadata)

        # ...for each image in this batch

        # Stack images into a batch tensor
        batch_tensor = torch.stack(batch_images)

        batch_tensor = batch_tensor.float()
        batch_tensor /= 255.0
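
        # This matches standard YOLOv5 preprocessing: uint8 values in [0, 255]
        # become float values in [0.0, 1.0].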

        batch_tensor = batch_tensor.to(self.device)
        if self.half_precision:
            batch_tensor = batch_tensor.half()

        # Run the model on the batch
        pred = self.model(batch_tensor, augment=augment)[0]

        # Configure NMS parameters
        if 'classic' in self.compatibility_mode:
            nms_iou_thres = 0.45
        else:
            nms_iou_thres = 0.6

        use_library_nms = False

        # Model output format changed in recent ultralytics packages, and the nms implementation
        # in this module hasn't been updated to handle that format yet.
        if (yolo_model_type_imported is not None) and (yolo_model_type_imported == 'ultralytics'):
            use_library_nms = True

        if use_library_nms:
            pred = non_max_suppression(prediction=pred,
                                       conf_thres=detection_threshold,
                                       iou_thres=nms_iou_thres,
                                       agnostic=False,
                                       multi_label=False)
        else:
            pred = nms(prediction=pred,
                       conf_thres=detection_threshold,
                       iou_thres=nms_iou_thres)

        assert isinstance(pred, list)
        assert len(pred) == len(batch_metadata), \
            'Mismatch between prediction length {} and batch size {}'.format(
                len(pred),len(batch_metadata))

        # Process each image's detections
        for i_image, det in enumerate(pred):

            metadata = batch_metadata[i_image]
            original_idx = metadata['original_idx']
            current_image_id = metadata['current_image_id']
            scaling_shape = metadata['scaling_shape']
            letterbox_pad = metadata['letterbox_pad']
            img_original = metadata['img_original']

            detections = []
            max_conf = 0.0

            if len(det) > 0:

                # Prepare scaling parameters
                gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
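
                # gn reorders the (h, w, ...) shape into a [w, h, w, h] tensor,
                # used below to normalize pixel-space box coordinates into the
                # [0, 1] range relative to the pre-letterbox image size.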

                if 'classic' in self.compatibility_mode:
                    ratio = None
                    ratio_pad = None
                else:
                    ratio = (img_original.shape[0]/scaling_shape[0],
                             img_original.shape[1]/scaling_shape[1])
                    ratio_pad = (ratio, letterbox_pad)

                # Rescale boxes
                if 'classic' in self.compatibility_mode:
                    det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], img_original.shape).round()
                else:
                    det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()

                # Process each detection
                for *xyxy, conf, cls in reversed(det):
                    if conf < detection_threshold:
                        continue

                    # Convert to YOLO format then to MD format
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                    api_box = ct_utils.convert_yolo_to_xywh(xywh)
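
                    # e.g. a normalized YOLO box [0.5, 0.5, 0.2, 0.1]
                    # (center x, center y, width, height) becomes the MD-format
                    # box [0.4, 0.45, 0.2, 0.1] (left, top, width, height);
                    # numbers are illustrative.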

                    if 'classic' in self.compatibility_mode:
                        api_box = ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
                        conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
                    else:
                        api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
                        conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)

                    if not self.use_model_native_classes:
                        cls = int(cls.tolist()) + 1
                        if cls not in (1, 2, 3):
                            raise KeyError(f'{cls} is not a valid class.')
                    else:
                        cls = int(cls.tolist())
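
                    # With the default (non-native) mapping, categories '1', '2',
                    # and '3' are MegaDetector's animal, person, and vehicle
                    # classes, respectively.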

                    detections.append({
                        'category': str(cls),
                        'conf': conf,
                        'bbox': api_box
                    })
                    max_conf = max(max_conf, conf)

                # ...for each detection

            # ...if there are > 0 detections

            # Store result for this image
            results[original_idx] = {
                'file': current_image_id,
                'detections': detections,
                'max_detection_conf': max_conf
            }

        # ...for each image

    # ...def _process_batch_group(...)

    def generate_detections_one_image(self,
                                      img_original,
                                      image_id='unknown',
                                      detection_threshold=0.00001,
                                      image_size=None,
                                      augment=False,
                                      verbose=False):
        """
        Run a detector on an image (wrapper around the batch function).

        Args:
            img_original (Image, np.array, or dict): the image on which we should run the
                detector, with EXIF rotation already handled, or a dict representing a
                preprocessed image with its associated letterbox parameters
            image_id (str, optional): a path to identify the image; will be in the "file" field
                of the output object
            detection_threshold (float, optional): only detections above this confidence
                threshold will be included in the return value
            image_size (int, optional): image size (long side) to use for inference, or None to
                use the default size specified at the time the model was loaded
            augment (bool, optional): enable (implementation-specific) image augmentation
            verbose (bool, optional): enable additional debug output

        Returns:
            dict: a dictionary with the following fields:
                - 'file' (filename, always present)
                - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
                - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
                - 'failure' (a failure string; present only when processing failed)
        """

        # Prepare batch inputs; preprocessed dicts carry their own identifiers,
        # so image_id is only passed along for raw images
        if isinstance(img_original, dict):
            batch_results = self.generate_detections_one_batch(
                img_original=[img_original],
                image_id=None,
                detection_threshold=detection_threshold,
                image_size=image_size,
                augment=augment,
                verbose=verbose)
        else:
            batch_results = self.generate_detections_one_batch(
                img_original=[img_original],
                image_id=[image_id],
                detection_threshold=detection_threshold,
                image_size=image_size,
                augment=augment,
                verbose=verbose)

        # Return the single result
        return batch_results[0]

    # ...def generate_detections_one_image(...)

# ...class PTDetector
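
# A minimal end-to-end sketch (hypothetical; the model filename and image path
# are placeholders, and the PTDetector constructor arguments shown here are
# assumptions rather than the package's documented API):
#
#   from PIL import Image
#   detector = PTDetector('md_v5a.0.0.pt')
#   img = Image.open('IMG_0001.JPG')
#   result = detector.generate_detections_one_image(img, image_id='IMG_0001.JPG',
#                                                   detection_threshold=0.2)
#   for d in result['detections']:
#       print(d['category'], d['conf'], d['bbox'])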