megadetector 5.0.22__py3-none-any.whl → 5.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +2 -3
- megadetector/classification/merge_classification_detection_output.py +2 -2
- megadetector/data_management/coco_to_labelme.py +2 -1
- megadetector/data_management/databases/integrity_check_json_db.py +15 -14
- megadetector/data_management/databases/subset_json_db.py +49 -21
- megadetector/data_management/mewc_to_md.py +340 -0
- megadetector/data_management/wi_to_md.py +41 -0
- megadetector/data_management/yolo_output_to_md_output.py +15 -8
- megadetector/detection/process_video.py +24 -7
- megadetector/detection/pytorch_detector.py +841 -160
- megadetector/detection/run_detector.py +340 -146
- megadetector/detection/run_detector_batch.py +306 -70
- megadetector/detection/run_inference_with_yolov5_val.py +61 -4
- megadetector/detection/tf_detector.py +6 -1
- megadetector/postprocessing/{combine_api_outputs.py → combine_batch_outputs.py} +10 -13
- megadetector/postprocessing/compare_batch_results.py +68 -6
- megadetector/postprocessing/md_to_labelme.py +7 -7
- megadetector/postprocessing/md_to_wi.py +40 -0
- megadetector/postprocessing/merge_detections.py +1 -1
- megadetector/postprocessing/postprocess_batch_results.py +10 -3
- megadetector/postprocessing/separate_detections_into_folders.py +32 -4
- megadetector/postprocessing/validate_batch_results.py +9 -4
- megadetector/utils/ct_utils.py +172 -57
- megadetector/utils/gpu_test.py +107 -0
- megadetector/utils/md_tests.py +363 -108
- megadetector/utils/path_utils.py +9 -2
- megadetector/utils/wi_utils.py +1794 -0
- megadetector/visualization/visualization_utils.py +82 -16
- megadetector/visualization/visualize_db.py +25 -7
- megadetector/visualization/visualize_detector_output.py +60 -13
- {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/LICENSE +0 -0
- {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/METADATA +129 -143
- {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/RECORD +35 -33
- {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/top_level.txt +0 -0
- megadetector/detection/detector_training/__init__.py +0 -0
- megadetector/detection/detector_training/model_main_tf2.py +0 -114
- megadetector/utils/torch_test.py +0 -32
- {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/WHEEL +0 -0
|
@@ -2,17 +2,31 @@
|
|
|
2
2
|
|
|
3
3
|
pytorch_detector.py
|
|
4
4
|
|
|
5
|
-
Module to run MegaDetector
|
|
5
|
+
Module to run YOLO-based MegaDetector models.
|
|
6
6
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
#%% Imports and constants
|
|
10
10
|
|
|
11
|
+
import os
|
|
12
|
+
import sys
|
|
13
|
+
import math
|
|
14
|
+
import zipfile
|
|
15
|
+
import tempfile
|
|
16
|
+
import shutil
|
|
17
|
+
import traceback
|
|
18
|
+
import uuid
|
|
19
|
+
import json
|
|
20
|
+
|
|
21
|
+
import cv2
|
|
11
22
|
import torch
|
|
12
23
|
import numpy as np
|
|
13
|
-
import traceback
|
|
14
24
|
|
|
15
|
-
from megadetector.detection.run_detector import
|
|
25
|
+
from megadetector.detection.run_detector import \
|
|
26
|
+
CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, \
|
|
27
|
+
get_detector_version_from_model_file, \
|
|
28
|
+
known_models
|
|
29
|
+
from megadetector.utils.ct_utils import parse_bool_string
|
|
16
30
|
from megadetector.utils import ct_utils
|
|
17
31
|
|
|
18
32
|
# We support a few ways of accessing the YOLOv5 dependencies:
|
|
@@ -31,95 +45,575 @@ from megadetector.utils import ct_utils
|
|
|
31
45
|
# * Unfinished:
|
|
32
46
|
#
|
|
33
47
|
# pip install ultralytics
|
|
34
|
-
#
|
|
35
|
-
# If try_ultralytics_import is True, we'll try to import all YOLOv5 dependencies from
|
|
36
|
-
# ultralytics.utils and ultralytics.data. But as of 2023.11, this results in a "No
|
|
37
|
-
# module named 'models'" error when running MDv5, and there's no upside to this approach
|
|
38
|
-
# compared to using either of the YOLOv5 PyPI packages, so... punting on this for now.
|
|
39
48
|
|
|
40
|
-
|
|
41
|
-
|
|
49
|
+
yolo_model_type_imported = None
|
|
50
|
+
|
|
51
|
+
def _get_model_type_for_model(model_file,
                              prefer_model_type_source='table',
                              default_model_type='yolov5',
                              verbose=False):
    """
    Determine the model type (i.e., the inference library we need to use) for a .pt file.

    Two sources are consulted: the metadata embedded in the model file (via
    read_metadata_from_megadetector_model_file()), and the global known_models table,
    indexed by the version string that get_detector_version_from_model_file() returns.

    Args:
        model_file (str): the model file to read
        prefer_model_type_source (str, optional): how should we handle the (very unlikely)
            case where the metadata in the file indicates one model type, but the global model
            type table says something else.  Should be "table" (trust the table) or "file"
            (trust the file).  Any value other than "table" is treated as "file".
        default_model_type (str, optional): return value for the case where we can't find
            appropriate metadata in the file or in the global table.
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type indicated for this model
    """

    # Check whether the model file itself specified a model type
    model_type_from_model_file_metadata = None

    file_metadata = read_metadata_from_megadetector_model_file(model_file)

    if file_metadata is not None and 'model_type' in file_metadata:
        model_type_from_model_file_metadata = file_metadata['model_type']
        if verbose:
            print('Parsed model type {} from model {}'.format(
                model_type_from_model_file_metadata,
                model_file))

    # Check whether this is a known model version with a specific model type
    model_type_from_model_version = None

    model_version_from_file = get_detector_version_from_model_file(model_file)

    if model_version_from_file is not None and model_version_from_file in known_models:
        table_entry = known_models[model_version_from_file]
        if 'model_type' in table_entry:
            model_type_from_model_version = table_entry['model_type']
            if verbose:
                print('Parsed model type {} from global metadata'.format(
                    model_type_from_model_version))

    have_file_type = (model_type_from_model_file_metadata is not None)
    have_table_type = (model_type_from_model_version is not None)

    # Neither source knows anything about this model
    if (not have_file_type) and (not have_table_type):

        if verbose:
            print('Could not determine model type for {}, assuming {}'.format(
                model_file,default_model_type))
        return default_model_type

    # Only one source has an opinion
    if not have_table_type:
        return model_type_from_model_file_metadata
    if not have_file_type:
        return model_type_from_model_version

    # Both sources have an opinion, and they agree
    if model_type_from_model_version == model_type_from_model_file_metadata:
        return model_type_from_model_file_metadata

    # Both sources have an opinion, and they disagree
    print('Warning: model type from model version is {}, from file metadata is {}'.format(
        model_type_from_model_version,model_type_from_model_file_metadata))

    if prefer_model_type_source == 'table':
        # "table" means "trust the global known_models table"
        return model_type_from_model_version
    else:
        return model_type_from_model_file_metadata

# ...def _get_model_type_for_model(...)
|
|
42
128
|
|
|
43
|
-
# See above; this should remain as "False" unless we update the MegaDetector .pt file
|
|
44
|
-
# to use more recent YOLOv5 namespace conventions.
|
|
45
|
-
try_ultralytics_import = False
|
|
46
129
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
130
|
+
def _initialize_yolo_imports_for_model(model_file,
                                       prefer_model_type_source='table',
                                       default_model_type='yolov5',
                                       detector_options=None,
                                       verbose=False):
    """
    Initialize the appropriate YOLO imports for a model file.

    If [detector_options] contains a 'model_type' entry, that value is used as-is;
    otherwise the model type is determined by _get_model_type_for_model().  If imports
    have already been set up for that model type, nothing is re-imported.

    Args:
        model_file (str): The model file for which we're loading support
        prefer_model_type_source (str, optional): how should we handle the (very unlikely)
            case where the metadata in the file indicates one model type, but the global model
            type table says something else.  Should be "table" (trust the table) or "file"
            (trust the file).
        default_model_type (str, optional): return value for the case where we can't find
            appropriate metadata in the file or in the global table.
        detector_options (dict, optional): dictionary of detector options that mean
            different things to different models
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type for which we initialized support
    """

    if detector_options is not None and 'model_type' in detector_options:
        model_type = detector_options['model_type']
        print('Model type {} provided in detector options'.format(model_type))
    else:
        model_type = _get_model_type_for_model(model_file,
                                               prefer_model_type_source=prefer_model_type_source,
                                               default_model_type=default_model_type,
                                               verbose=verbose)

    if yolo_model_type_imported is not None:
        if model_type == yolo_model_type_imported:
            print('Bypassing imports for model type {}'.format(model_type))
            # Return the model type on this path too, so the return value is always
            # the model type regardless of whether we had to import anything.
            return model_type
        else:
            print('Previously set up imports for model type {}, re-importing as {}'.format(
                yolo_model_type_imported,model_type))

    _initialize_yolo_imports(model_type,verbose=verbose)

    return model_type
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _clean_yolo_imports(verbose=False):
    """
    Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
    of another YOLO library version.  The reason we jump through all these hoops, rather than
    just, e.g., handling different libraries in different modules, is that we need to make sure
    *pickle* sees the right version of modules during module loading, including modules we don't
    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
    do that is to remove all the "wrong" versions from sys.modules and sys.path.

    A module is considered YOLO-related if its file lives under "site-packages" and one of
    the last four components of its path contains "yolov5", "yolov9", or "ultralytics".  A
    sys.path entry is considered YOLO-related if it ends with one of those strings.

    Args:
        verbose (bool, optional): enable additional debug output
    """

    yolo_tokens = ('yolov5','yolov9','ultralytics')

    # Maps each module name we're going to delete to the file it was loaded from.  We
    # record the file here, rather than looking it up again later, so the debug output
    # below always refers to the module that's actually being deleted.
    modules_to_delete = {}

    # Iterate over a snapshot, so nothing that touches sys.modules can break the iteration
    for module_name,module in list(sys.modules.items()):

        # Best-effort: some entries in sys.modules have no usable __file__ attribute
        # (built-ins, namespace packages, placeholder objects); skip those.
        try:
            module_file = module.__file__.replace('\\','/')
        except Exception:
            continue

        if 'site-packages' not in module_file:
            continue

        path_tokens = module_file.split('/')[-4:]
        if any((yolo_token in path_token) \
               for path_token in path_tokens for yolo_token in yolo_tokens):
            modules_to_delete[module_name] = module_file

    for module_name,module_file in modules_to_delete.items():
        if module_name in sys.modules:
            if verbose:
                print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
            del sys.modules[module_name]

    # Collect first, then remove, so we don't modify sys.path while iterating over it
    paths_to_delete = []

    for p in sys.path:
        if isinstance(p,str) and p.endswith(yolo_tokens):
            print('clean_yolo_imports: removing {} from path'.format(p))
            paths_to_delete.append(p)

    for p in paths_to_delete:
        sys.path.remove(p)

# ...def _clean_yolo_imports(...)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _initialize_yolo_imports(model_type='yolov5',
                             allow_fallback_import=True,
                             force_reimport=False,
                             verbose=False):
    """
    Imports required functions from one or more yolo libraries (yolov5, yolov9,
    ultralytics, targeting support for [model_type]).

    On success, non_max_suppression, xyxy2xywh, letterbox, and scale_coords are bound at
    module scope, and the module-level yolo_model_type_imported is set to [model_type].

    Args:
        model_type (str): The model type for which we're loading support; None is
            treated as 'yolov5'
        allow_fallback_import (bool, optional): If we can't import from the package for
            which we're trying to load support, fall back to "import utils".  This is
            typically used when the right support library is on the current PYTHONPATH.
        force_reimport (bool, optional): import the appropriate libraries even if the
            requested model type matches the current initialization state
        verbose (bool, optional): include additional debug output

    Returns:
        str: the model type for which we initialized support

    Raises:
        ModuleNotFoundError: if the fallback "utils" import fails
        AssertionError: if no import path succeeded
        Exception: whatever "import ultralytics" raised, if [model_type] is 'ultralytics'
            and that package can't be imported
    """

    global yolo_model_type_imported

    if model_type is None:
        model_type = 'yolov5'

    # The point of this function is to make the appropriate version
    # of the following functions available at module scope
    global non_max_suppression
    global xyxy2xywh
    global letterbox
    global scale_coords

    if yolo_model_type_imported is not None:
        if yolo_model_type_imported == model_type:
            if not force_reimport:
                print('Bypassing imports for YOLO model type {}'.format(model_type))
                return model_type
        else:
            # A different library is currently loaded; unload it before importing
            _clean_yolo_imports(verbose=verbose)

    try_yolov5_import = (model_type == 'yolov5')
    try_yolov9_import = (model_type == 'yolov9')
    try_ultralytics_import = (model_type == 'ultralytics')

    utils_imported = False

    # First try importing from the yolov5 package; this is how the pip
    # package finds YOLOv5 utilities.
    if try_yolov5_import and not utils_imported:

        try:
            from yolov5.utils.general import non_max_suppression, xyxy2xywh # noqa
            from yolov5.utils.augmentations import letterbox # noqa
            # scale_coords() is scale_boxes() in some YOLOv5 versions
            try:
                from yolov5.utils.general import scale_boxes as scale_coords
            except Exception:
                from yolov5.utils.general import scale_coords
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv5 package')

        except Exception as e:

            # This is not fatal, we may still succeed via the fallback import below
            if verbose:
                print('yolov5 module import failed: {}'.format(e))

    # Next try importing from the yolov9 package
    if try_yolov9_import and not utils_imported:

        try:

            from yolov9.utils.general import non_max_suppression, xyxy2xywh # noqa
            from yolov9.utils.augmentations import letterbox # noqa
            from yolov9.utils.general import scale_boxes as scale_coords # noqa
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv9 package')

        except Exception as e:

            # This is not fatal, we may still succeed via the fallback import below
            if verbose:
                print('yolov9 module import failed: {}'.format(e))

    # If we haven't succeeded yet, import from the ultralytics package
    if try_ultralytics_import and not utils_imported:

        try:

            import ultralytics # noqa

        except Exception:

            print('It looks like you are trying to run a model that requires the ultralytics package, '
                  'but the ultralytics package is not installed. For licensing reasons, this '
                  'is not installed by default with the MegaDetector Python package. Run '
                  '"pip install ultralytics" to install it, and try again.')
            raise

        try:

            from ultralytics.utils.ops import non_max_suppression # noqa
            from ultralytics.utils.ops import xyxy2xywh # noqa

            # In the ultralytics package, scale_boxes and scale_coords both exist;
            # we want scale_boxes.
            #
            # from ultralytics.utils.ops import scale_coords # noqa
            from ultralytics.utils.ops import scale_boxes as scale_coords # noqa
            from ultralytics.data.augment import LetterBox

            # letterbox() became a LetterBox class in the ultralytics package.  Create a
            # backwards-compatible letterbox function wrapper that wraps the class up.
            def letterbox(img,new_shape,auto=False,scaleFill=False,scaleup=True,center=True,stride=32): # noqa

                L = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,scaleup=scaleup,center=center,stride=stride)
                letterbox_result = L(image=img)

                if isinstance(new_shape,int):
                    new_shape = [new_shape,new_shape]

                # The letterboxing is done, we just need to reverse-engineer what it did,
                # so we can return the [image,ratio,pad] triple.
                shape = img.shape[:2]

                r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
                if not scaleup:
                    r = min(r, 1.0)
                ratio = r, r

                new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
                dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
                if auto:
                    dw, dh = np.mod(dw, stride), np.mod(dh, stride)
                elif scaleFill:
                    dw, dh = 0.0, 0.0
                    new_unpad = (new_shape[1], new_shape[0])
                    ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])

                # Padding is split evenly between the two sides
                dw /= 2
                dh /= 2
                pad = (dw,dh)

                return [letterbox_result,ratio,pad]

            utils_imported = True
            if verbose:
                print('Imported utils from ultralytics package')

        except Exception as e:

            # This is not fatal, we may still succeed via the fallback import below
            if verbose:
                print('Ultralytics module import failed: {}'.format(e))

    # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
    if (not utils_imported) and allow_fallback_import:

        try:

            # import pre- and post-processing functions from the YOLOv5 repo
            from utils.general import non_max_suppression, xyxy2xywh # noqa
            from utils.augmentations import letterbox # noqa

            # scale_coords() is scale_boxes() in some YOLOv5 versions
            try:
                from utils.general import scale_coords # noqa
            except ImportError:
                from utils.general import scale_boxes as scale_coords
            utils_imported = True
            imported_file = sys.modules[scale_coords.__module__].__file__
            if verbose:
                print('Imported utils from {}'.format(imported_file))

        except ModuleNotFoundError as e:

            raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e))) from e

    assert utils_imported, 'YOLO utils import error'

    yolo_model_type_imported = model_type
    if verbose:
        print('Prepared YOLO imports for model type {}'.format(model_type))

    return model_type

# ...def _initialize_yolo_imports(...)
|
|
412
|
+
|
|
101
413
|
|
|
102
|
-
|
|
414
|
+
#%% Model metadata functions
|
|
103
415
|
|
|
416
|
+
def add_metadata_to_megadetector_model_file(model_file_in,
                                            model_file_out,
                                            metadata,
                                            destination_path='megadetector_info.json'):
    """
    Adds a .json file to the specified MegaDetector model file containing metadata used
    by this module.  Always over-writes the output file.

    Args:
        model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
        model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
            May be the same as model_file_in, in which case the file is updated in place.
        metadata (dict): The metadata dict to add to the output model file
        destination_path (str, optional): The relative path within the main folder of the
            model archive where we should write the metadata.  This is not relative to the root
            of the archive, it's relative to the one and only folder at the root of the archive
            (this is a PyTorch convention).

    Raises:
        AssertionError: if the archive does not have exactly one folder at the top level
    """

    # Serialize before touching the output file, so un-serializable metadata
    # doesn't leave a half-updated output file behind.
    metadata_json = json.dumps(metadata,indent=1)

    # Copy the input file to the output file
    try:
        shutil.copyfile(model_file_in,model_file_out)
    except shutil.SameFileError:
        # This is the in-place case; the output file already has the right content
        pass

    # Write metadata to the output file
    with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:

        # Torch doesn't like anything in the root folder of the zipfile, so we put
        # it in the one and only folder.
        root_folders = {name.split('/')[0] for name in zipf.namelist()}
        assert len(root_folders) == 1,\
            'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
        root_folder = next(iter(root_folders))

        # Write directly from memory, so there's no temporary file to clean up
        zipf.writestr(root_folder + '/' + destination_path,
                      metadata_json,
                      compress_type=zipfile.ZIP_DEFLATED,
                      compresslevel=9)

# ...def add_metadata_to_megadetector_model_file(...)
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def read_metadata_from_megadetector_model_file(model_file,
                                               relative_path='megadetector_info.json',
                                               verbose=False):
    """
    Reads custom MegaDetector metadata from a modified MegaDetector model file.

    Args:
        model_file (str): The model filename to read, typically .pt (.zip is also sensible)
        relative_path (str, optional): The relative path within the main folder of the model
            archive from which we should read the metadata.  This is not relative to the root
            of the archive, it's relative to the one and only folder at the root of the archive
            (this is a PyTorch convention).
        verbose (bool, optional): print a warning when the archive contains no metadata file

    Returns:
        object: Whatever we read from the metadata file, always a dict in practice.  Returns
        None if we failed to read the specified metadata file, including the case where
        [model_file] is not a zip archive at all.
    """

    try:

        with zipfile.ZipFile(model_file,'r') as zipf:

            # Torch doesn't like anything in the root folder of the zipfile, so we put
            # it in the one and only folder.
            names = zipf.namelist()
            root_folders = {name.split('/')[0] for name in names}
            if len(root_folders) != 1:
                print('Warning: this archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?')
                return None
            root_folder = next(iter(root_folders))

            metadata_file = root_folder + '/' + relative_path
            if metadata_file not in names:
                # This is the case for MDv5a and MDv5b
                if verbose:
                    print('Warning: could not find metadata file {} in zip archive'.format(metadata_file))
                return None

            try:
                # json.loads() accepts bytes and detects the encoding itself, so we
                # don't depend on the platform's default text encoding here.
                contents = zipf.read(metadata_file)
                d = json.loads(contents)
            except Exception as e:
                print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
                return None

            return d

    except zipfile.BadZipFile as e:

        # Not a zip archive, so there's no metadata to read
        print('Warning: could not open {} as a zip archive: {}'.format(model_file,str(e)))
        return None

# ...def read_metadata_from_megadetector_model_file(...)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
#%% Inference classes
|
|
526
|
+
|
|
527
|
+
default_compatibility_mode = 'classic'
|
|
528
|
+
|
|
529
|
+
# This is a useful hack when I want to verify that my test driver (md_tests.py) is
|
|
530
|
+
# correctly forcing a specific compatibility mode (I use "classic-test" in that case)
|
|
531
|
+
require_non_default_compatibility_mode = False
|
|
119
532
|
|
|
120
|
-
|
|
533
|
+
class PTDetector:
|
|
534
|
+
|
|
535
|
+
def __init__(self, model_path, detector_options=None, verbose=False):
|
|
536
|
+
|
|
537
|
+
# Set up the import environment for this model, unloading previous
|
|
538
|
+
# YOLO library versions if necessary.
|
|
539
|
+
_initialize_yolo_imports_for_model(model_path,
|
|
540
|
+
detector_options=detector_options,
|
|
541
|
+
verbose=verbose)
|
|
542
|
+
|
|
543
|
+
# Parse options specific to this detector family
|
|
544
|
+
force_cpu = False
|
|
545
|
+
use_model_native_classes = False
|
|
546
|
+
compatibility_mode = default_compatibility_mode
|
|
547
|
+
|
|
548
|
+
if detector_options is not None:
|
|
549
|
+
|
|
550
|
+
if 'force_cpu' in detector_options:
|
|
551
|
+
force_cpu = parse_bool_string(detector_options['force_cpu'])
|
|
552
|
+
if 'use_model_native_classes' in detector_options:
|
|
553
|
+
use_model_native_classes = parse_bool_string(detector_options['use_model_native_classes'])
|
|
554
|
+
if 'compatibility_mode' in detector_options:
|
|
555
|
+
if detector_options['compatibility_mode'] is None:
|
|
556
|
+
compatibility_mode = default_compatibility_mode
|
|
557
|
+
else:
|
|
558
|
+
compatibility_mode = detector_options['compatibility_mode']
|
|
559
|
+
|
|
560
|
+
if require_non_default_compatibility_mode:
|
|
561
|
+
|
|
562
|
+
print('### DEBUG: requiring non-default compatibility mode ###')
|
|
563
|
+
assert compatibility_mode != 'classic'
|
|
564
|
+
assert compatibility_mode != 'default'
|
|
565
|
+
|
|
566
|
+
preprocess_only = False
|
|
567
|
+
if (detector_options is not None) and \
|
|
568
|
+
('preprocess_only' in detector_options) and \
|
|
569
|
+
(detector_options['preprocess_only']):
|
|
570
|
+
preprocess_only = True
|
|
571
|
+
|
|
572
|
+
if verbose or (not preprocess_only):
|
|
573
|
+
print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))
|
|
574
|
+
|
|
575
|
+
model_metadata = read_metadata_from_megadetector_model_file(model_path)
|
|
121
576
|
|
|
577
|
+
#: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
|
|
578
|
+
#: aspect ratio".
|
|
579
|
+
if model_metadata is not None and 'image_size' in model_metadata:
|
|
580
|
+
self.default_image_size = model_metadata['image_size']
|
|
581
|
+
if verbose:
|
|
582
|
+
print('Loaded image size {} from model metadata'.format(self.default_image_size))
|
|
583
|
+
else:
|
|
584
|
+
self.default_image_size = 1280
|
|
585
|
+
|
|
586
|
+
#: Either a string ('cpu','cuda:0') or a torch.device()
|
|
122
587
|
self.device = 'cpu'
|
|
588
|
+
|
|
589
|
+
#: Have we already printed a warning about using a non-standard image size?
|
|
590
|
+
#:
|
|
591
|
+
#: :meta private:
|
|
592
|
+
self.printed_image_size_warning = False
|
|
593
|
+
|
|
594
|
+
#: If this is False, we assume the underlying model is producing class indices in the
|
|
595
|
+
#: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
|
|
596
|
+
#: MD classes (1,2,3) before generating output. If this is True, we use whatever
|
|
597
|
+
#: indices the model provides
|
|
598
|
+
self.use_model_native_classes = use_model_native_classes
|
|
599
|
+
|
|
600
|
+
#: This allows us to maintain backwards compatibility across a set of changes to the
|
|
601
|
+
#: way this class does inference. Currently should start with either "default" or
|
|
602
|
+
#: "classic".
|
|
603
|
+
self.compatibility_mode = compatibility_mode
|
|
604
|
+
|
|
605
|
+
#: Stride size passed to YOLOv5's letterbox() function
|
|
606
|
+
self.letterbox_stride = 32
|
|
607
|
+
|
|
608
|
+
if 'classic' in self.compatibility_mode:
|
|
609
|
+
self.letterbox_stride = 64
|
|
610
|
+
|
|
611
|
+
#: Use half-precision inference... fixed by the model, generally don't mess with this
|
|
612
|
+
self.half_precision = False
|
|
613
|
+
|
|
614
|
+
if preprocess_only:
|
|
615
|
+
return
|
|
616
|
+
|
|
123
617
|
if not force_cpu:
|
|
124
618
|
if torch.cuda.is_available():
|
|
125
619
|
self.device = torch.device('cuda:0')
|
|
@@ -129,41 +623,53 @@ class PTDetector:
|
|
|
129
623
|
except AttributeError:
|
|
130
624
|
pass
|
|
131
625
|
try:
|
|
132
|
-
self.model = PTDetector._load_model(model_path,
|
|
626
|
+
self.model = PTDetector._load_model(model_path,
|
|
627
|
+
device=self.device,
|
|
628
|
+
compatibility_mode=self.compatibility_mode)
|
|
629
|
+
|
|
133
630
|
except Exception as e:
|
|
134
631
|
# In a very esoteric scenario where an old version of YOLOv5 is used to run
|
|
135
632
|
# newer models, we run into an issue because the "Model" class became
|
|
136
633
|
# "DetectionModel". New YOLOv5 code handles this case by just setting them
|
|
137
|
-
# to be the same, so doing that
|
|
634
|
+
# to be the same, so doing that externally doesn't seem *that* rude.
|
|
138
635
|
if "Can't get attribute 'DetectionModel'" in str(e):
|
|
139
636
|
print('Forward-compatibility issue detected, patching')
|
|
140
637
|
from models import yolo
|
|
141
|
-
yolo.DetectionModel = yolo.Model
|
|
142
|
-
self.model = PTDetector._load_model(model_path,
|
|
638
|
+
yolo.DetectionModel = yolo.Model
|
|
639
|
+
self.model = PTDetector._load_model(model_path,
|
|
640
|
+
device=self.device,
|
|
641
|
+
compatibility_mode=self.compatibility_mode,
|
|
642
|
+
verbose=verbose)
|
|
143
643
|
else:
|
|
144
644
|
raise
|
|
145
645
|
if (self.device != 'cpu'):
|
|
146
|
-
|
|
646
|
+
if verbose:
|
|
647
|
+
print('Sending model to GPU')
|
|
147
648
|
self.model.to(self.device)
|
|
148
|
-
|
|
149
|
-
self.printed_image_size_warning = False
|
|
150
|
-
self.use_model_native_classes = use_model_native_classes
|
|
151
|
-
|
|
649
|
+
|
|
152
650
|
|
|
153
651
|
@staticmethod
|
|
154
|
-
def _load_model(model_pt_path, device):
|
|
652
|
+
def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
|
|
155
653
|
|
|
654
|
+
if verbose:
|
|
655
|
+
print(f'Using PyTorch version {torch.__version__}')
|
|
656
|
+
|
|
156
657
|
# There are two very slightly different ways to load the model, (1) using the
|
|
157
658
|
# map_location=device parameter to torch.load and (2) calling .to(device) after
|
|
158
659
|
# loading the model. The former is what we did for a zillion years, but is not
|
|
159
660
|
# supported on Apple silicon as of 2029.09. Switching to the latter causes
|
|
160
661
|
# very slight changes to the output, which always make me nervous, so I'm not
|
|
161
662
|
# doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
|
|
162
|
-
|
|
663
|
+
if 'classic' in compatibility_mode:
|
|
664
|
+
use_map_location = (device != 'mps')
|
|
665
|
+
else:
|
|
666
|
+
use_map_location = False
|
|
163
667
|
|
|
164
668
|
if use_map_location:
|
|
165
669
|
try:
|
|
166
670
|
checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
|
|
671
|
+
# For a transitional period, we want to support torch 1.1x, where the weights_only
|
|
672
|
+
# parameter doesn't exist
|
|
167
673
|
except Exception as e:
|
|
168
674
|
if "'weights_only' is an invalid keyword" in str(e):
|
|
169
675
|
checkpoint = torch.load(model_pt_path, map_location=device)
|
|
@@ -172,6 +678,8 @@ class PTDetector:
|
|
|
172
678
|
else:
|
|
173
679
|
try:
|
|
174
680
|
checkpoint = torch.load(model_pt_path, weights_only=False)
|
|
681
|
+
# For a transitional period, we want to support torch 1.1x, where the weights_only
|
|
682
|
+
# parameter doesn't exist
|
|
175
683
|
except Exception as e:
|
|
176
684
|
if "'weights_only' is an invalid keyword" in str(e):
|
|
177
685
|
checkpoint = torch.load(model_pt_path)
|
|
@@ -185,26 +693,31 @@ class PTDetector:
|
|
|
185
693
|
if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
|
186
694
|
m.recompute_scale_factor = None
|
|
187
695
|
|
|
188
|
-
if use_map_location:
|
|
189
|
-
model = checkpoint['model'].float().fuse().eval()
|
|
696
|
+
if use_map_location:
|
|
697
|
+
model = checkpoint['model'].float().fuse().eval()
|
|
190
698
|
else:
|
|
191
|
-
model = checkpoint['model'].float().fuse().eval().to(device)
|
|
699
|
+
model = checkpoint['model'].float().fuse().eval().to(device)
|
|
192
700
|
|
|
193
701
|
return model
|
|
194
702
|
|
|
703
|
+
# ...def _load_model(...)
|
|
704
|
+
|
|
705
|
+
|
|
195
706
|
def generate_detections_one_image(self,
|
|
196
707
|
img_original,
|
|
197
708
|
image_id='unknown',
|
|
198
709
|
detection_threshold=0.00001,
|
|
199
710
|
image_size=None,
|
|
200
711
|
skip_image_resizing=False,
|
|
201
|
-
augment=False
|
|
712
|
+
augment=False,
|
|
713
|
+
preprocess_only=False,
|
|
714
|
+
verbose=False):
|
|
202
715
|
"""
|
|
203
716
|
Applies the detector to an image.
|
|
204
717
|
|
|
205
718
|
Args:
|
|
206
719
|
img_original (Image): the PIL Image object (or numpy array) on which we should run the
|
|
207
|
-
detector, with EXIF rotation already handled
|
|
720
|
+
detector, with EXIF rotation already handled
|
|
208
721
|
image_id (str, optional): a path to identify the image; will be in the "file" field
|
|
209
722
|
of the output object
|
|
210
723
|
detection_threshold (float, optional): only detections above this confidence threshold
|
|
@@ -212,8 +725,11 @@ class PTDetector:
|
|
|
212
725
|
image_size (tuple, optional): image size to use for inference, only mess with this if
|
|
213
726
|
(a) you're using a model other than MegaDetector or (b) you know what you're getting into
|
|
214
727
|
skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
|
|
215
|
-
external resizing)
|
|
728
|
+
external resizing), only mess with this if (a) you're using a model other than MegaDetector
|
|
729
|
+
or (b) you know what you're getting into
|
|
216
730
|
augment (bool, optional): enable (implementation-specific) image augmentation
|
|
731
|
+
preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
|
|
732
|
+
verbose (bool, optional): enable additional debug output
|
|
217
733
|
|
|
218
734
|
Returns:
|
|
219
735
|
dict: a dictionary with the following fields:
|
|
@@ -227,111 +743,275 @@ class PTDetector:
|
|
|
227
743
|
detections = []
|
|
228
744
|
max_conf = 0.0
|
|
229
745
|
|
|
746
|
+
if preprocess_only:
|
|
747
|
+
assert 'classic' in self.compatibility_mode, \
|
|
748
|
+
'Standalone preprocessing only supported in "classic" mode'
|
|
749
|
+
assert not skip_image_resizing, \
|
|
750
|
+
'skip_image_resizing and preprocess_only are exclusive'
|
|
751
|
+
|
|
230
752
|
if detection_threshold is None:
|
|
231
753
|
|
|
232
754
|
detection_threshold = 0
|
|
233
755
|
|
|
234
756
|
try:
|
|
235
757
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
# Padded resize
|
|
240
|
-
target_size = PTDetector.IMAGE_SIZE
|
|
241
|
-
|
|
242
|
-
# Image size can be an int (which translates to a square target size) or (h,w)
|
|
243
|
-
if image_size is not None:
|
|
244
|
-
|
|
245
|
-
assert isinstance(image_size,int) or (len(image_size)==2)
|
|
758
|
+
# If the caller wants us to skip all the resizing operations...
|
|
759
|
+
if skip_image_resizing:
|
|
246
760
|
|
|
247
|
-
if
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
761
|
+
if isinstance(img_original,dict):
|
|
762
|
+
image_info = img_original
|
|
763
|
+
img = image_info['img_processed']
|
|
764
|
+
scaling_shape = image_info['scaling_shape']
|
|
765
|
+
letterbox_pad = image_info['letterbox_pad']
|
|
766
|
+
letterbox_ratio = image_info['letterbox_ratio']
|
|
767
|
+
img_original = image_info['img_original']
|
|
768
|
+
img_original_pil = image_info['img_original_pil']
|
|
769
|
+
else:
|
|
770
|
+
img = img_original
|
|
252
771
|
|
|
253
772
|
else:
|
|
254
773
|
|
|
255
|
-
|
|
774
|
+
img_original_pil = None
|
|
775
|
+
# If we were given a PIL image
|
|
776
|
+
|
|
777
|
+
if not isinstance(img_original,np.ndarray):
|
|
778
|
+
img_original_pil = img_original
|
|
779
|
+
img_original = np.asarray(img_original)
|
|
780
|
+
|
|
781
|
+
# PIL images are RGB already
|
|
782
|
+
# img_original = img_original[:, :, ::-1]
|
|
256
783
|
|
|
257
|
-
|
|
784
|
+
# Save the original shape for scaling boxes later
|
|
785
|
+
scaling_shape = img_original.shape
|
|
786
|
+
|
|
787
|
+
# If the caller is requesting a specific target size...
|
|
788
|
+
if image_size is not None:
|
|
789
|
+
|
|
790
|
+
assert isinstance(image_size,int)
|
|
791
|
+
|
|
792
|
+
if not self.printed_image_size_warning:
|
|
793
|
+
print('Using user-supplied image size {}'.format(image_size))
|
|
794
|
+
self.printed_image_size_warning = True
|
|
795
|
+
|
|
796
|
+
# Otherwise resize to self.default_image_size
|
|
797
|
+
else:
|
|
798
|
+
|
|
799
|
+
image_size = self.default_image_size
|
|
800
|
+
self.printed_image_size_warning = False
|
|
801
|
+
|
|
802
|
+
# ...if the caller has specified an image size
|
|
803
|
+
|
|
804
|
+
# In "classic mode", we only do the letterboxing resize, we don't do an
|
|
805
|
+
# additional initial resizing operation
|
|
806
|
+
if 'classic' in self.compatibility_mode:
|
|
807
|
+
|
|
808
|
+
resize_ratio = 1.0
|
|
809
|
+
|
|
810
|
+
# Resize the image so the long side matches the target image size. This is not
|
|
811
|
+
# letterboxing (i.e., padding) yet, just resizing.
|
|
812
|
+
else:
|
|
813
|
+
|
|
814
|
+
use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
|
|
815
|
+
|
|
816
|
+
h,w = img_original.shape[:2]
|
|
817
|
+
resize_ratio = image_size / max(h,w)
|
|
818
|
+
|
|
819
|
+
# Only resize if we have to
|
|
820
|
+
if resize_ratio != 1:
|
|
821
|
+
|
|
822
|
+
# Match what yolov5 does: use linear interpolation for upsizing;
|
|
823
|
+
# area interpolation for downsizing
|
|
824
|
+
if resize_ratio > 1:
|
|
825
|
+
interpolation_method = cv2.INTER_LINEAR
|
|
826
|
+
else:
|
|
827
|
+
interpolation_method = cv2.INTER_AREA
|
|
828
|
+
|
|
829
|
+
if use_ceil_for_resize:
|
|
830
|
+
target_w = math.ceil(w * resize_ratio)
|
|
831
|
+
target_h = math.ceil(h * resize_ratio)
|
|
832
|
+
else:
|
|
833
|
+
target_w = int(w * resize_ratio)
|
|
834
|
+
target_h = int(h * resize_ratio)
|
|
835
|
+
|
|
836
|
+
img_original = cv2.resize(
|
|
837
|
+
img_original, (target_w, target_h),
|
|
838
|
+
interpolation=interpolation_method)
|
|
839
|
+
|
|
840
|
+
if 'classic' in self.compatibility_mode:
|
|
841
|
+
|
|
842
|
+
letterbox_auto = True
|
|
843
|
+
letterbox_scaleup = True
|
|
844
|
+
target_shape = image_size
|
|
845
|
+
|
|
846
|
+
else:
|
|
847
|
+
|
|
848
|
+
letterbox_auto = False
|
|
849
|
+
letterbox_scaleup = False
|
|
850
|
+
|
|
851
|
+
# The padding to apply as a fraction of the stride size
|
|
852
|
+
pad = 0.5
|
|
853
|
+
|
|
854
|
+
model_stride = int(self.model.stride.max())
|
|
855
|
+
|
|
856
|
+
max_dimension = max(img_original.shape)
|
|
857
|
+
normalized_shape = [img_original.shape[0] / max_dimension,
|
|
858
|
+
img_original.shape[1] / max_dimension]
|
|
859
|
+
target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
|
|
860
|
+
pad).astype(int) * model_stride
|
|
861
|
+
|
|
862
|
+
# Now we letterbox, which is just padding, since we've already resized.
|
|
863
|
+
img,letterbox_ratio,letterbox_pad = letterbox(img_original,
|
|
864
|
+
new_shape=target_shape,
|
|
865
|
+
stride=self.letterbox_stride,
|
|
866
|
+
auto=letterbox_auto,
|
|
867
|
+
scaleFill=False,
|
|
868
|
+
scaleup=letterbox_scaleup)
|
|
258
869
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
870
|
+
if preprocess_only:
|
|
871
|
+
|
|
872
|
+
assert 'file' in result
|
|
873
|
+
result['img_processed'] = img
|
|
874
|
+
result['img_original'] = img_original
|
|
875
|
+
result['img_original_pil'] = img_original_pil
|
|
876
|
+
result['target_shape'] = target_shape
|
|
877
|
+
result['scaling_shape'] = scaling_shape
|
|
878
|
+
result['letterbox_ratio'] = letterbox_ratio
|
|
879
|
+
result['letterbox_pad'] = letterbox_pad
|
|
880
|
+
return result
|
|
881
|
+
|
|
882
|
+
# ...are we doing resizing here, or were images already resized?
|
|
883
|
+
|
|
884
|
+
# Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
|
|
885
|
+
# so we don't need to mess with the color channels.
|
|
886
|
+
#
|
|
887
|
+
# TODO, this could be moved into the preprocessing loop
|
|
888
|
+
|
|
889
|
+
img = img.transpose((2, 0, 1)) # [::-1]
|
|
270
890
|
img = np.ascontiguousarray(img)
|
|
271
891
|
img = torch.from_numpy(img)
|
|
272
892
|
img = img.to(self.device)
|
|
273
|
-
img = img.float()
|
|
893
|
+
img = img.half() if self.half_precision else img.float()
|
|
274
894
|
img /= 255
|
|
275
895
|
|
|
276
896
|
# In practice this is always true
|
|
277
897
|
if len(img.shape) == 3:
|
|
278
898
|
img = torch.unsqueeze(img, 0)
|
|
279
899
|
|
|
900
|
+
# Run the model
|
|
280
901
|
pred = self.model(img,augment=augment)[0]
|
|
281
902
|
|
|
282
|
-
|
|
903
|
+
if 'classic' in self.compatibility_mode:
|
|
904
|
+
nms_conf_thres = detection_threshold
|
|
905
|
+
nms_iou_thres = 0.45
|
|
906
|
+
nms_agnostic = False
|
|
907
|
+
nms_multi_label = False
|
|
908
|
+
else:
|
|
909
|
+
nms_conf_thres = detection_threshold # 0.01
|
|
910
|
+
nms_iou_thres = 0.6
|
|
911
|
+
nms_agnostic = False
|
|
912
|
+
nms_multi_label = True
|
|
913
|
+
|
|
914
|
+
# As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
|
|
915
|
+
#
|
|
916
|
+
# Send predictions back to the CPU for NMS.
|
|
283
917
|
if self.device == 'mps':
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
918
|
+
pred_nms = pred.cpu()
|
|
919
|
+
else:
|
|
920
|
+
pred_nms = pred
|
|
921
|
+
|
|
922
|
+
# NMS
|
|
923
|
+
pred = non_max_suppression(prediction=pred_nms,
|
|
924
|
+
conf_thres=nms_conf_thres,
|
|
925
|
+
iou_thres=nms_iou_thres,
|
|
926
|
+
agnostic=nms_agnostic,
|
|
927
|
+
multi_label=nms_multi_label)
|
|
290
928
|
|
|
291
|
-
#
|
|
292
|
-
|
|
293
|
-
# normalization gain whwh
|
|
294
|
-
gn = torch.tensor(img_original.shape)[[1, 0, 1, 0]]
|
|
929
|
+
# In practice this is [w,h,w,h] of the original image
|
|
930
|
+
gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
|
|
295
931
|
|
|
932
|
+
if 'classic' in self.compatibility_mode:
|
|
933
|
+
|
|
934
|
+
ratio = None
|
|
935
|
+
ratio_pad = None
|
|
936
|
+
|
|
937
|
+
else:
|
|
938
|
+
|
|
939
|
+
# letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
|
|
940
|
+
#
|
|
941
|
+
# ratio is a 2-tuple specifying the scaling that was applied to each dimension.
|
|
942
|
+
#
|
|
943
|
+
# The scale_boxes function expects a 2-tuple with these things combined.
|
|
944
|
+
ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
|
|
945
|
+
ratio_pad = (ratio, letterbox_pad)
|
|
946
|
+
|
|
296
947
|
# This is a loop over detection batches, which will always be length 1 in our case,
|
|
297
948
|
# since we're not doing batch inference.
|
|
949
|
+
#
|
|
950
|
+
# det = pred[0]
|
|
951
|
+
#
|
|
952
|
+
# det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
|
|
953
|
+
# in descending order by confidence.
|
|
954
|
+
#
|
|
955
|
+
# Columns are:
|
|
956
|
+
#
|
|
957
|
+
# x0,y0,x1,y1,confidence,class
|
|
958
|
+
#
|
|
959
|
+
# At this point, these are *non*-normalized values, referring to the size at which we
|
|
960
|
+
# ran inference (img.shape).
|
|
298
961
|
for det in pred:
|
|
299
962
|
|
|
300
|
-
if len(det):
|
|
963
|
+
if len(det) == 0:
|
|
964
|
+
continue
|
|
965
|
+
|
|
966
|
+
# Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
|
|
967
|
+
if 'classic' in self.compatibility_mode:
|
|
301
968
|
|
|
302
|
-
# Rescale boxes from img_size to im0 size
|
|
303
969
|
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
|
|
970
|
+
|
|
971
|
+
else:
|
|
972
|
+
# After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
|
|
973
|
+
# original pixel dimensions of the image, followed by the confidence and class
|
|
974
|
+
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()
|
|
304
975
|
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
976
|
+
# Loop over detections
|
|
977
|
+
for *xyxy, conf, cls in reversed(det):
|
|
978
|
+
|
|
979
|
+
if conf < detection_threshold:
|
|
980
|
+
continue
|
|
981
|
+
|
|
982
|
+
# Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
|
|
983
|
+
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
|
|
309
984
|
|
|
310
|
-
|
|
985
|
+
# Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
|
|
986
|
+
# left/top/w/h (i.e., MD format)
|
|
987
|
+
api_box = ct_utils.convert_yolo_to_xywh(xywh)
|
|
311
988
|
|
|
989
|
+
if 'classic' in self.compatibility_mode:
|
|
990
|
+
api_box = ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
|
|
312
991
|
conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
|
|
992
|
+
else:
|
|
993
|
+
api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
|
|
994
|
+
conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
|
|
995
|
+
|
|
996
|
+
if not self.use_model_native_classes:
|
|
997
|
+
# The MegaDetector output format's categories start at 1, but all YOLO-based
|
|
998
|
+
# MD models have category numbers starting at 0.
|
|
999
|
+
cls = int(cls.tolist()) + 1
|
|
1000
|
+
if cls not in (1, 2, 3):
|
|
1001
|
+
raise KeyError(f'{cls} is not a valid class.')
|
|
1002
|
+
else:
|
|
1003
|
+
cls = int(cls.tolist())
|
|
313
1004
|
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
'category': str(cls),
|
|
325
|
-
'conf': conf,
|
|
326
|
-
'bbox': ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
|
|
327
|
-
})
|
|
328
|
-
max_conf = max(max_conf, conf)
|
|
329
|
-
|
|
330
|
-
# ...for each detection in this batch
|
|
331
|
-
|
|
332
|
-
# ...if this is a non-empty batch
|
|
333
|
-
|
|
334
|
-
# ...for each detection batch
|
|
1005
|
+
detections.append({
|
|
1006
|
+
'category': str(cls),
|
|
1007
|
+
'conf': conf,
|
|
1008
|
+
'bbox': api_box
|
|
1009
|
+
})
|
|
1010
|
+
max_conf = max(max_conf, conf)
|
|
1011
|
+
|
|
1012
|
+
# ...for each detection in this batch
|
|
1013
|
+
|
|
1014
|
+
# ...for each detection batch (always one iteration)
|
|
335
1015
|
|
|
336
1016
|
# ...try
|
|
337
1017
|
|
|
@@ -339,7 +1019,8 @@ class PTDetector:
|
|
|
339
1019
|
|
|
340
1020
|
result['failure'] = FAILURE_INFER
|
|
341
1021
|
print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
|
|
342
|
-
traceback.print_exc(e)
|
|
1022
|
+
# traceback.print_exc(e)
|
|
1023
|
+
print(traceback.format_exc())
|
|
343
1024
|
|
|
344
1025
|
result['max_detection_conf'] = max_conf
|
|
345
1026
|
result['detections'] = detections
|
|
@@ -361,7 +1042,7 @@ if __name__ == '__main__':
|
|
|
361
1042
|
|
|
362
1043
|
#%%
|
|
363
1044
|
|
|
364
|
-
import os
|
|
1045
|
+
import os #noqa
|
|
365
1046
|
from megadetector.visualization import visualization_utils as vis_utils
|
|
366
1047
|
|
|
367
1048
|
model_file = os.environ['MDV5A']
|