megadetector 5.0.22__py3-none-any.whl → 5.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (38) hide show
  1. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +2 -3
  2. megadetector/classification/merge_classification_detection_output.py +2 -2
  3. megadetector/data_management/coco_to_labelme.py +2 -1
  4. megadetector/data_management/databases/integrity_check_json_db.py +15 -14
  5. megadetector/data_management/databases/subset_json_db.py +49 -21
  6. megadetector/data_management/mewc_to_md.py +340 -0
  7. megadetector/data_management/wi_to_md.py +41 -0
  8. megadetector/data_management/yolo_output_to_md_output.py +15 -8
  9. megadetector/detection/process_video.py +24 -7
  10. megadetector/detection/pytorch_detector.py +841 -160
  11. megadetector/detection/run_detector.py +340 -146
  12. megadetector/detection/run_detector_batch.py +306 -70
  13. megadetector/detection/run_inference_with_yolov5_val.py +61 -4
  14. megadetector/detection/tf_detector.py +6 -1
  15. megadetector/postprocessing/{combine_api_outputs.py → combine_batch_outputs.py} +10 -13
  16. megadetector/postprocessing/compare_batch_results.py +68 -6
  17. megadetector/postprocessing/md_to_labelme.py +7 -7
  18. megadetector/postprocessing/md_to_wi.py +40 -0
  19. megadetector/postprocessing/merge_detections.py +1 -1
  20. megadetector/postprocessing/postprocess_batch_results.py +10 -3
  21. megadetector/postprocessing/separate_detections_into_folders.py +32 -4
  22. megadetector/postprocessing/validate_batch_results.py +9 -4
  23. megadetector/utils/ct_utils.py +172 -57
  24. megadetector/utils/gpu_test.py +107 -0
  25. megadetector/utils/md_tests.py +363 -108
  26. megadetector/utils/path_utils.py +9 -2
  27. megadetector/utils/wi_utils.py +1794 -0
  28. megadetector/visualization/visualization_utils.py +82 -16
  29. megadetector/visualization/visualize_db.py +25 -7
  30. megadetector/visualization/visualize_detector_output.py +60 -13
  31. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/LICENSE +0 -0
  32. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/METADATA +129 -143
  33. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/RECORD +35 -33
  34. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/top_level.txt +0 -0
  35. megadetector/detection/detector_training/__init__.py +0 -0
  36. megadetector/detection/detector_training/model_main_tf2.py +0 -114
  37. megadetector/utils/torch_test.py +0 -32
  38. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/WHEEL +0 -0
@@ -2,17 +2,31 @@
2
2
 
3
3
  pytorch_detector.py
4
4
 
5
- Module to run MegaDetector v5.
5
+ Module to run YOLO-based MegaDetector models.
6
6
 
7
7
  """
8
8
 
9
9
  #%% Imports and constants
10
10
 
11
+ import os
12
+ import sys
13
+ import math
14
+ import zipfile
15
+ import tempfile
16
+ import shutil
17
+ import traceback
18
+ import uuid
19
+ import json
20
+
21
+ import cv2
11
22
  import torch
12
23
  import numpy as np
13
- import traceback
14
24
 
15
- from megadetector.detection.run_detector import CONF_DIGITS, COORD_DIGITS, FAILURE_INFER
25
+ from megadetector.detection.run_detector import \
26
+ CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, \
27
+ get_detector_version_from_model_file, \
28
+ known_models
29
+ from megadetector.utils.ct_utils import parse_bool_string
16
30
  from megadetector.utils import ct_utils
17
31
 
18
32
  # We support a few ways of accessing the YOLOv5 dependencies:
@@ -31,95 +45,575 @@ from megadetector.utils import ct_utils
31
45
  # * Unfinished:
32
46
  #
33
47
  # pip install ultralytics
34
- #
35
- # If try_ultralytics_import is True, we'll try to import all YOLOv5 dependencies from
36
- # ultralytics.utils and ultralytics.data. But as of 2023.11, this results in a "No
37
- # module named 'models'" error when running MDv5, and there's no upside to this approach
38
- # compared to using either of the YOLOv5 PyPI packages, so... punting on this for now.
39
48
 
40
- utils_imported = False
41
- try_yolov5_import = True
49
# Which YOLO model type ('yolov5', 'yolov9', 'ultralytics', ...) the module-level YOLO
# support functions are currently imported for; None until _initialize_yolo_imports()
# has completed successfully.
yolo_model_type_imported = None


def _get_model_type_for_model(model_file,
                              prefer_model_type_source='table',
                              default_model_type='yolov5',
                              verbose=False):
    """
    Determine the model type (i.e., the inference library we need to use) for a .pt file.

    Two sources are consulted: the "model_type" field of the metadata embedded in the
    model file (via read_metadata_from_megadetector_model_file()), and the "model_type"
    field of the known_models entry for the model version inferred from the file (via
    get_detector_version_from_model_file()).

    Args:
        model_file (str): the model file to read
        prefer_model_type_source (str, optional): how should we handle the (very unlikely)
            case where the metadata in the file indicates one model type, but the global model
            type table says something else. Should be "table" (trust the table) or "file"
            (trust the file); any value other than "table" is treated as "file".
        default_model_type (str, optional): return value for the case where we can't find
            appropriate metadata in the file or in the global table.
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type indicated for this model
    """

    # Check whether the model file itself specified a model type
    model_type_from_model_file_metadata = None

    model_file_metadata = read_metadata_from_megadetector_model_file(model_file)

    if model_file_metadata is not None and 'model_type' in model_file_metadata:
        model_type_from_model_file_metadata = model_file_metadata['model_type']
        if verbose:
            print('Parsed model type {} from model {}'.format(
                model_type_from_model_file_metadata,
                model_file))

    # Check whether this is a known model version with a specific model type
    model_type_from_model_version = None

    model_version_from_file = get_detector_version_from_model_file(model_file)

    if model_version_from_file is not None and model_version_from_file in known_models:
        known_model_info = known_models[model_version_from_file]
        if 'model_type' in known_model_info:
            model_type_from_model_version = known_model_info['model_type']
            if verbose:
                print('Parsed model type {} from global metadata'.format(
                    model_type_from_model_version))

    if model_type_from_model_file_metadata is None and \
       model_type_from_model_version is None:

        if verbose:
            print('Could not determine model type for {}, assuming {}'.format(
                model_file,default_model_type))
        model_type = default_model_type

    elif model_type_from_model_file_metadata is not None and \
         model_type_from_model_version is not None:

        if model_type_from_model_version == model_type_from_model_file_metadata:
            model_type = model_type_from_model_file_metadata
        else:
            print('Warning: model type from model version is {}, from file metadata is {}'.format(
                model_type_from_model_version,model_type_from_model_file_metadata))
            # "table" means trust the global known_models table, i.e. the type we
            # looked up from the model version; otherwise trust the file metadata.
            if prefer_model_type_source == 'table':
                model_type = model_type_from_model_version
            else:
                model_type = model_type_from_model_file_metadata

    elif model_type_from_model_file_metadata is not None:

        model_type = model_type_from_model_file_metadata

    else:

        model_type = model_type_from_model_version

    return model_type

# ...def _get_model_type_for_model(...)
42
128
 
43
- # See above; this should remain as "False" unless we update the MegaDetector .pt file
44
- # to use more recent YOLOv5 namespace conventions.
45
- try_ultralytics_import = False
46
129
 
47
- # First try importing from the yolov5 package; this is how the pip
48
- # package finds YOLOv5 utilities.
49
- if try_yolov5_import and not utils_imported:
130
def _initialize_yolo_imports_for_model(model_file,
                                       prefer_model_type_source='table',
                                       default_model_type='yolov5',
                                       detector_options=None,
                                       verbose=False):
    """
    Initialize the appropriate YOLO imports for a model file.

    The model type is taken from detector_options['model_type'] when present; otherwise it
    is determined from the model file via _get_model_type_for_model(). If imports for that
    model type are already in place, nothing is re-imported.

    Args:
        model_file (str): The model file for which we're loading support
        prefer_model_type_source (str, optional): how should we handle the (very unlikely)
            case where the metadata in the file indicates one model type, but the global model
            type table says something else. Should be "table" (trust the table) or "file"
            (trust the file).
        default_model_type (str, optional): return value for the case where we can't find
            appropriate metadata in the file or in the global table.
        detector_options (dict, optional): dictionary of detector options that mean
            different things to different models
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type for which we initialized support (also returned when the
        imports for this model type were already in place)
    """

    global yolo_model_type_imported

    if detector_options is not None and 'model_type' in detector_options:
        model_type = detector_options['model_type']
        print('Model type {} provided in detector options'.format(model_type))
    else:
        model_type = _get_model_type_for_model(model_file,
                                               prefer_model_type_source=prefer_model_type_source,
                                               default_model_type=default_model_type,
                                               verbose=verbose)

    if yolo_model_type_imported is not None:
        if model_type == yolo_model_type_imported:
            print('Bypassing imports for model type {}'.format(model_type))
            return model_type
        print('Previously set up imports for model type {}, re-importing as {}'.format(
            yolo_model_type_imported,model_type))

    _initialize_yolo_imports(model_type,verbose=verbose)

    return model_type
176
+
177
+
178
def _clean_yolo_imports(verbose=False):
    """
    Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
    of another YOLO library version. The reason we jump through all these hoops, rather than
    just, e.g., handling different libraries in different modules, is that we need to make sure
    *pickle* sees the right version of modules during module loading, including modules we don't
    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
    do that is to remove all the "wrong" versions from sys.modules and sys.path.

    A module is removed if its file lives under "site-packages" and any of the last four
    components of its path contains "yolov5", "yolov9", or "ultralytics". A sys.path entry
    is removed if it ends with one of those names (ignoring a trailing path separator).

    Args:
        verbose (bool, optional): enable additional debug output
    """

    yolo_tokens = ('yolov5','yolov9','ultralytics')

    # Iterate over a snapshot: touching module attributes can in principle trigger imports
    # that mutate sys.modules, which would break iteration over the live dict.
    modules_to_delete = []
    for module_name, module in list(sys.modules.items()):
        try:
            module_file = module.__file__.replace('\\','/')
        except Exception:
            # Built-in modules, namespace packages (__file__ is None), odd module objects
            continue
        if 'site-packages' not in module_file:
            continue
        path_tail = module_file.split('/')[-4:]
        if any(yolo_token in token for token in path_tail for yolo_token in yolo_tokens):
            modules_to_delete.append((module_name,module_file))

    for module_name, module_file in modules_to_delete:
        if module_name in sys.modules:
            if verbose:
                print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
            del sys.modules[module_name]

    paths_to_delete = [p for p in sys.path
                       if isinstance(p,str) and p.rstrip('/\\').endswith(yolo_tokens)]

    # Duplicate entries appear once per occurrence here, so every copy gets removed
    for p in paths_to_delete:
        print('clean_yolo_imports: removing {} from path'.format(p))
        sys.path.remove(p)

# ...def _clean_yolo_imports(...)
224
+
225
+
226
def _initialize_yolo_imports(model_type='yolov5',
                             allow_fallback_import=True,
                             force_reimport=False,
                             verbose=False):
    """
    Imports required functions from one or more yolo libraries (yolov5, yolov9,
    ultralytics), targeting support for [model_type].

    On success, the module-level names non_max_suppression, xyxy2xywh, letterbox, and
    scale_coords are bound to the implementations from the selected library, and
    yolo_model_type_imported is set to [model_type].

    Args:
        model_type (str): The model type for which we're loading support; None is
            treated as 'yolov5'
        allow_fallback_import (bool, optional): If we can't import from the package for
            which we're trying to load support, fall back to "import utils". This is
            typically used when the right support library is on the current PYTHONPATH.
        force_reimport (bool, optional): import the appropriate libraries even if the
            requested model type matches the current initialization state
        verbose (bool, optional): include additional debug output

    Returns:
        str: the model type for which we initialized support (also returned when the
        imports were already in place and nothing needed to be done)

    Raises:
        Exception: re-raises the import error if model_type is 'ultralytics' and the
            ultralytics package can't be imported
        ModuleNotFoundError: if the fallback "utils" import is attempted and fails
    """

    # The point of this function is to make the appropriate version of the following
    # functions available at module scope, and to record which model type they support.
    global yolo_model_type_imported
    global non_max_suppression
    global xyxy2xywh
    global letterbox
    global scale_coords

    if model_type is None:
        model_type = 'yolov5'

    if yolo_model_type_imported is not None:
        if (yolo_model_type_imported == model_type) and (not force_reimport):
            print('Bypassing imports for YOLO model type {}'.format(model_type))
            return model_type
        # Unload the previously-imported YOLO library, so pickle resolves modules
        # against the library we're about to import.
        _clean_yolo_imports(verbose=verbose)

    try_yolov5_import = (model_type == 'yolov5')
    try_yolov9_import = (model_type == 'yolov9')
    try_ultralytics_import = (model_type == 'ultralytics')

    utils_imported = False

    # First try importing from the yolov5 package; this is how the pip
    # package finds YOLOv5 utilities.
    if try_yolov5_import and not utils_imported:

        try:
            from yolov5.utils.general import non_max_suppression, xyxy2xywh # noqa
            from yolov5.utils.augmentations import letterbox # noqa
            # scale_coords() is called scale_boxes() in some YOLOv5 versions
            try:
                from yolov5.utils.general import scale_boxes as scale_coords
            except Exception:
                from yolov5.utils.general import scale_coords
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv5 package')

        except Exception:
            # Leave utils_imported False; the fallback import below may still succeed
            pass

    # Next try importing from the yolov9 package
    if try_yolov9_import and not utils_imported:

        try:
            from yolov9.utils.general import non_max_suppression, xyxy2xywh # noqa
            from yolov9.utils.augmentations import letterbox # noqa
            from yolov9.utils.general import scale_boxes as scale_coords # noqa
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv9 package')

        except Exception:
            # Leave utils_imported False; the fallback import below may still succeed
            pass

    # If we haven't succeeded yet, import from the ultralytics package
    if try_ultralytics_import and not utils_imported:

        try:
            import ultralytics # noqa
        except Exception:
            print('It looks like you are trying to run a model that requires the ultralytics package, '
                  'but the ultralytics package is not installed. For licensing reasons, this '
                  'is not installed by default with the MegaDetector Python package. Run '
                  '"pip install ultralytics" to install it, and try again.')
            raise

        try:

            from ultralytics.utils.ops import non_max_suppression # noqa
            from ultralytics.utils.ops import xyxy2xywh # noqa

            # In the ultralytics package, scale_boxes and scale_coords both exist;
            # we want scale_boxes.
            from ultralytics.utils.ops import scale_boxes as scale_coords # noqa
            from ultralytics.data.augment import LetterBox

            # letterbox() became a LetterBox class in the ultralytics package. Create a
            # backwards-compatible letterbox function wrapper that wraps the class up, and
            # also returns the (ratio, pad) values the YOLOv5 letterbox() function returns.
            def letterbox(img,new_shape,auto=False,scaleFill=False,scaleup=True,center=True,stride=32): # noqa

                L = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,scaleup=scaleup,center=center,stride=stride)
                letterbox_result = L(image=img)

                if isinstance(new_shape,int):
                    new_shape = [new_shape,new_shape]

                # The letterboxing is done, we just need to reverse-engineer what it did
                shape = img.shape[:2]

                r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
                if not scaleup:
                    r = min(r, 1.0)
                ratio = r, r

                new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
                dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
                if auto:
                    dw, dh = np.mod(dw, stride), np.mod(dh, stride)
                elif scaleFill:
                    dw, dh = 0.0, 0.0
                    new_unpad = (new_shape[1], new_shape[0])
                    ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])

                # NOTE(review): this assumes centered padding (center=True); confirm the
                # reported pad is still what callers expect if center=False is ever used.
                dw /= 2
                dh /= 2
                pad = (dw,dh)

                return [letterbox_result,ratio,pad]

            utils_imported = True
            if verbose:
                print('Imported utils from ultralytics package')

        except Exception:
            # Leave utils_imported False; the fallback import below may still succeed
            pass

    # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
    if (not utils_imported) and allow_fallback_import:

        try:

            # import pre- and post-processing functions from the YOLOv5 repo
            from utils.general import non_max_suppression, xyxy2xywh # noqa
            from utils.augmentations import letterbox # noqa

            # scale_coords() is scale_boxes() in some YOLOv5 versions
            try:
                from utils.general import scale_coords # noqa
            except ImportError:
                from utils.general import scale_boxes as scale_coords
            utils_imported = True
            imported_file = sys.modules[scale_coords.__module__].__file__
            if verbose:
                print('Imported utils from {}'.format(imported_file))

        except ModuleNotFoundError as e:

            raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e)))

    assert utils_imported, 'YOLO utils import error'

    yolo_model_type_imported = model_type
    if verbose:
        print('Prepared YOLO imports for model type {}'.format(model_type))

    return model_type
99
410
 
100
- assert utils_imported, 'YOLOv5 import error'
411
+ # ...def _initialize_yolo_imports(...)
412
+
101
413
 
102
- print(f'Using PyTorch version {torch.__version__}')
414
+ #%% Model metadata functions
103
415
 
416
def add_metadata_to_megadetector_model_file(model_file_in,
                                            model_file_out,
                                            metadata,
                                            destination_path='megadetector_info.json'):
    """
    Adds a .json file to the specified MegaDetector model file containing metadata used
    by this module. Always over-writes the output file.

    If the archive already contains a file at the destination path, a second entry with
    the same name is appended (zipfile warns about the duplicate name); readers that look
    the entry up by name get the most recently added one.

    Args:
        model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
        model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
            May be the same as model_file_in, in which case the file is updated in place.
        metadata (dict): The metadata dict to add to the output model file
        destination_path (str, optional): The relative path within the main folder of the
            model archive where we should write the metadata. This is not relative to the root
            of the archive, it's relative to the one and only folder at the root of the archive
            (this is a PyTorch convention).

    Raises:
        AssertionError: if the archive does not have exactly one folder at the top level
    """

    # Serialize first, so non-serializable metadata fails before we touch any files
    metadata_string = json.dumps(metadata,indent=1)

    # Copy the input file to the output file, unless we're updating a file in place
    # (shutil.copyfile raises SameFileError when source and destination are the same file)
    updating_in_place = os.path.exists(model_file_out) and \
        os.path.samefile(model_file_in,model_file_out)
    if not updating_in_place:
        shutil.copyfile(model_file_in,model_file_out)

    # Write metadata to the output file
    with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:

        # Torch doesn't like anything in the root folder of the zipfile, so we put
        # it in the one and only folder.
        root_folders = {name.split('/')[0] for name in zipf.namelist()}
        assert len(root_folders) == 1,\
            'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
        root_folder = next(iter(root_folders))

        zipf.writestr(root_folder + '/' + destination_path,
                      metadata_string,
                      compress_type=zipfile.ZIP_DEFLATED,
                      compresslevel=9)

# ...def add_metadata_to_megadetector_model_file(...)
471
+
472
+
473
def read_metadata_from_megadetector_model_file(model_file,
                                               relative_path='megadetector_info.json',
                                               verbose=False):
    """
    Reads custom MegaDetector metadata from a modified MegaDetector model file.

    Args:
        model_file (str): The model filename to read, typically .pt (.zip is also sensible)
        relative_path (str, optional): The relative path within the main folder of the model
            archive from which we should read the metadata. This is not relative to the root
            of the archive, it's relative to the one and only folder at the root of the archive
            (this is a PyTorch convention).
        verbose (bool, optional): enable additional debug output

    Returns:
        object: Whatever we read from the metadata file, always a dict in practice. Returns
        None if the model file is not a zip archive, the archive does not have exactly one
        top-level folder, the metadata file is absent, or the metadata can't be parsed.
    """

    try:
        zipf = zipfile.ZipFile(model_file,'r')
    except zipfile.BadZipFile:
        print('Warning: {} is not a zip archive, cannot read metadata from it'.format(model_file))
        return None

    with zipf:

        # Torch doesn't like anything in the root folder of the zipfile, so we put
        # it in the one and only folder.
        names = zipf.namelist()
        root_folders = {name.split('/')[0] for name in names}
        if len(root_folders) != 1:
            print('Warning: this archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?')
            return None
        root_folder = next(iter(root_folders))

        metadata_file = root_folder + '/' + relative_path
        if metadata_file not in names:
            # This is the case for MDv5a and MDv5b
            if verbose:
                print('Warning: could not find metadata file {} in zip archive'.format(metadata_file))
            return None

        try:
            contents = zipf.read(metadata_file).decode('utf-8')
            d = json.loads(contents)
        except Exception as e:
            print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
            return None

    return d

# ...def read_metadata_from_megadetector_model_file(...)
523
+
524
+
525
+ #%% Inference classes
526
+
527
# Compatibility mode PTDetector falls back to when detector_options does not supply one
# (or supplies None).
default_compatibility_mode: str = 'classic'

# Debugging switch: when True, PTDetector asserts that the requested compatibility mode is
# neither "classic" nor "default". Useful for checking that the test driver (md_tests.py)
# really is forcing a specific compatibility mode ("classic-test" is used for that).
require_non_default_compatibility_mode: bool = False
119
532
 
120
- def __init__(self, model_path, force_cpu=False, use_model_native_classes= False):
533
+ class PTDetector:
534
+
535
+ def __init__(self, model_path, detector_options=None, verbose=False):
536
+
537
+ # Set up the import environment for this model, unloading previous
538
+ # YOLO library versions if necessary.
539
+ _initialize_yolo_imports_for_model(model_path,
540
+ detector_options=detector_options,
541
+ verbose=verbose)
542
+
543
+ # Parse options specific to this detector family
544
+ force_cpu = False
545
+ use_model_native_classes = False
546
+ compatibility_mode = default_compatibility_mode
547
+
548
+ if detector_options is not None:
549
+
550
+ if 'force_cpu' in detector_options:
551
+ force_cpu = parse_bool_string(detector_options['force_cpu'])
552
+ if 'use_model_native_classes' in detector_options:
553
+ use_model_native_classes = parse_bool_string(detector_options['use_model_native_classes'])
554
+ if 'compatibility_mode' in detector_options:
555
+ if detector_options['compatibility_mode'] is None:
556
+ compatibility_mode = default_compatibility_mode
557
+ else:
558
+ compatibility_mode = detector_options['compatibility_mode']
559
+
560
+ if require_non_default_compatibility_mode:
561
+
562
+ print('### DEBUG: requiring non-default compatibility mode ###')
563
+ assert compatibility_mode != 'classic'
564
+ assert compatibility_mode != 'default'
565
+
566
+ preprocess_only = False
567
+ if (detector_options is not None) and \
568
+ ('preprocess_only' in detector_options) and \
569
+ (detector_options['preprocess_only']):
570
+ preprocess_only = True
571
+
572
+ if verbose or (not preprocess_only):
573
+ print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))
574
+
575
+ model_metadata = read_metadata_from_megadetector_model_file(model_path)
121
576
 
577
+ #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
578
+ #: aspect ratio".
579
+ if model_metadata is not None and 'image_size' in model_metadata:
580
+ self.default_image_size = model_metadata['image_size']
581
+ if verbose:
582
+ print('Loaded image size {} from model metadata'.format(self.default_image_size))
583
+ else:
584
+ self.default_image_size = 1280
585
+
586
+ #: Either a string ('cpu','cuda:0') or a torch.device()
122
587
  self.device = 'cpu'
588
+
589
+ #: Have we already printed a warning about using a non-standard image size?
590
+ #:
591
+ #: :meta private:
592
+ self.printed_image_size_warning = False
593
+
594
+ #: If this is False, we assume the underlying model is producing class indices in the
595
+ #: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
596
+ #: MD classes (1,2,3) before generating output. If this is True, we use whatever
597
+ #: indices the model provides
598
+ self.use_model_native_classes = use_model_native_classes
599
+
600
+ #: This allows us to maintain backwards compatibility across a set of changes to the
601
+ #: way this class does inference. Currently should start with either "default" or
602
+ #: "classic".
603
+ self.compatibility_mode = compatibility_mode
604
+
605
+ #: Stride size passed to YOLOv5's letterbox() function
606
+ self.letterbox_stride = 32
607
+
608
+ if 'classic' in self.compatibility_mode:
609
+ self.letterbox_stride = 64
610
+
611
+ #: Use half-precision inference... fixed by the model, generally don't mess with this
612
+ self.half_precision = False
613
+
614
+ if preprocess_only:
615
+ return
616
+
123
617
  if not force_cpu:
124
618
  if torch.cuda.is_available():
125
619
  self.device = torch.device('cuda:0')
@@ -129,41 +623,53 @@ class PTDetector:
129
623
  except AttributeError:
130
624
  pass
131
625
  try:
132
- self.model = PTDetector._load_model(model_path, self.device)
626
+ self.model = PTDetector._load_model(model_path,
627
+ device=self.device,
628
+ compatibility_mode=self.compatibility_mode)
629
+
133
630
  except Exception as e:
134
631
  # In a very esoteric scenario where an old version of YOLOv5 is used to run
135
632
  # newer models, we run into an issue because the "Model" class became
136
633
  # "DetectionModel". New YOLOv5 code handles this case by just setting them
137
- # to be the same, so doing that via monkey-patch doesn't seem *that* rude.
634
+ # to be the same, so doing that externally doesn't seem *that* rude.
138
635
  if "Can't get attribute 'DetectionModel'" in str(e):
139
636
  print('Forward-compatibility issue detected, patching')
140
637
  from models import yolo
141
- yolo.DetectionModel = yolo.Model
142
- self.model = PTDetector._load_model(model_path, self.device)
638
+ yolo.DetectionModel = yolo.Model
639
+ self.model = PTDetector._load_model(model_path,
640
+ device=self.device,
641
+ compatibility_mode=self.compatibility_mode,
642
+ verbose=verbose)
143
643
  else:
144
644
  raise
145
645
  if (self.device != 'cpu'):
146
- print('Sending model to GPU')
646
+ if verbose:
647
+ print('Sending model to GPU')
147
648
  self.model.to(self.device)
148
-
149
- self.printed_image_size_warning = False
150
- self.use_model_native_classes = use_model_native_classes
151
-
649
+
152
650
 
153
651
  @staticmethod
154
- def _load_model(model_pt_path, device):
652
+ def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
155
653
 
654
+ if verbose:
655
+ print(f'Using PyTorch version {torch.__version__}')
656
+
156
657
  # There are two very slightly different ways to load the model, (1) using the
157
658
  # map_location=device parameter to torch.load and (2) calling .to(device) after
158
659
  # loading the model. The former is what we did for a zillion years, but is not
159
660
  # supported on Apple silicon at of 2029.09. Switching to the latter causes
160
661
  # very slight changes to the output, which always make me nervous, so I'm not
161
662
  # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
162
- use_map_location = (device != 'mps')
663
+ if 'classic' in compatibility_mode:
664
+ use_map_location = (device != 'mps')
665
+ else:
666
+ use_map_location = False
163
667
 
164
668
  if use_map_location:
165
669
  try:
166
670
  checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
671
+ # For a transitional period, we want to support torch 1.1x, where the weights_only
672
+ # parameter doesn't exist
167
673
  except Exception as e:
168
674
  if "'weights_only' is an invalid keyword" in str(e):
169
675
  checkpoint = torch.load(model_pt_path, map_location=device)
@@ -172,6 +678,8 @@ class PTDetector:
172
678
  else:
173
679
  try:
174
680
  checkpoint = torch.load(model_pt_path, weights_only=False)
681
+ # For a transitional period, we want to support torch 1.1x, where the weights_only
682
+ # parameter doesn't exist
175
683
  except Exception as e:
176
684
  if "'weights_only' is an invalid keyword" in str(e):
177
685
  checkpoint = torch.load(model_pt_path)
@@ -185,26 +693,31 @@ class PTDetector:
185
693
  if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
186
694
  m.recompute_scale_factor = None
187
695
 
188
- if use_map_location:
189
- model = checkpoint['model'].float().fuse().eval()
696
+ if use_map_location:
697
+ model = checkpoint['model'].float().fuse().eval()
190
698
  else:
191
- model = checkpoint['model'].float().fuse().eval().to(device)
699
+ model = checkpoint['model'].float().fuse().eval().to(device)
192
700
 
193
701
  return model
194
702
 
703
+ # ...def _load_model(...)
704
+
705
+
195
706
  def generate_detections_one_image(self,
196
707
  img_original,
197
708
  image_id='unknown',
198
709
  detection_threshold=0.00001,
199
710
  image_size=None,
200
711
  skip_image_resizing=False,
201
- augment=False):
712
+ augment=False,
713
+ preprocess_only=False,
714
+ verbose=False):
202
715
  """
203
716
  Applies the detector to an image.
204
717
 
205
718
  Args:
206
719
  img_original (Image): the PIL Image object (or numpy array) on which we should run the
207
- detector, with EXIF rotation already handled.
720
+ detector, with EXIF rotation already handled
208
721
  image_id (str, optional): a path to identify the image; will be in the "file" field
209
722
  of the output object
210
723
  detection_threshold (float, optional): only detections above this confidence threshold
@@ -212,8 +725,11 @@ class PTDetector:
212
725
  image_size (tuple, optional): image size to use for inference, only mess with this if
213
726
  (a) you're using a model other than MegaDetector or (b) you know what you're getting into
214
727
  skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
215
- external resizing)
728
+ external resizing), only mess with this if (a) you're using a model other than MegaDetector
729
+ or (b) you know what you're getting into
216
730
  augment (bool, optional): enable (implementation-specific) image augmentation
731
+ preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
732
+ verbose (bool, optional): enable additional debug output
217
733
 
218
734
  Returns:
219
735
  dict: a dictionary with the following fields:
@@ -227,111 +743,275 @@ class PTDetector:
227
743
  detections = []
228
744
  max_conf = 0.0
229
745
 
746
+ if preprocess_only:
747
+ assert 'classic' in self.compatibility_mode, \
748
+ 'Standalone preprocessing only supported in "classic" mode'
749
+ assert not skip_image_resizing, \
750
+ 'skip_image_resizing and preprocess_only are exclusive'
751
+
230
752
  if detection_threshold is None:
231
753
 
232
754
  detection_threshold = 0
233
755
 
234
756
  try:
235
757
 
236
- if not isinstance(img_original,np.ndarray):
237
- img_original = np.asarray(img_original)
238
-
239
- # Padded resize
240
- target_size = PTDetector.IMAGE_SIZE
241
-
242
- # Image size can be an int (which translates to a square target size) or (h,w)
243
- if image_size is not None:
244
-
245
- assert isinstance(image_size,int) or (len(image_size)==2)
758
+ # If the caller wants us to skip all the resizing operations...
759
+ if skip_image_resizing:
246
760
 
247
- if not self.printed_image_size_warning:
248
- print('Warning: using user-supplied image size {}'.format(image_size))
249
- self.printed_image_size_warning = True
250
-
251
- target_size = image_size
761
+ if isinstance(img_original,dict):
762
+ image_info = img_original
763
+ img = image_info['img_processed']
764
+ scaling_shape = image_info['scaling_shape']
765
+ letterbox_pad = image_info['letterbox_pad']
766
+ letterbox_ratio = image_info['letterbox_ratio']
767
+ img_original = image_info['img_original']
768
+ img_original_pil = image_info['img_original_pil']
769
+ else:
770
+ img = img_original
252
771
 
253
772
  else:
254
773
 
255
- self.printed_image_size_warning = False
774
+ img_original_pil = None
775
+ # If we were given a PIL image
776
+
777
+ if not isinstance(img_original,np.ndarray):
778
+ img_original_pil = img_original
779
+ img_original = np.asarray(img_original)
780
+
781
+ # PIL images are RGB already
782
+ # img_original = img_original[:, :, ::-1]
256
783
 
257
- # ...if the caller has specified an image size
784
+ # Save the original shape for scaling boxes later
785
+ scaling_shape = img_original.shape
786
+
787
+ # If the caller is requesting a specific target size...
788
+ if image_size is not None:
789
+
790
+ assert isinstance(image_size,int)
791
+
792
+ if not self.printed_image_size_warning:
793
+ print('Using user-supplied image size {}'.format(image_size))
794
+ self.printed_image_size_warning = True
795
+
796
+ # Otherwise resize to self.default_image_size
797
+ else:
798
+
799
+ image_size = self.default_image_size
800
+ self.printed_image_size_warning = False
801
+
802
+ # ...if the caller has specified an image size
803
+
804
+ # In "classic mode", we only do the letterboxing resize, we don't do an
805
+ # additional initial resizing operation
806
+ if 'classic' in self.compatibility_mode:
807
+
808
+ resize_ratio = 1.0
809
+
810
+ # Resize the image so the long side matches the target image size. This is not
811
+ # letterboxing (i.e., padding) yet, just resizing.
812
+ else:
813
+
814
+ use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
815
+
816
+ h,w = img_original.shape[:2]
817
+ resize_ratio = image_size / max(h,w)
818
+
819
+ # Only resize if we have to
820
+ if resize_ratio != 1:
821
+
822
+ # Match what yolov5 does: use linear interpolation for upsizing;
823
+ # area interpolation for downsizing
824
+ if resize_ratio > 1:
825
+ interpolation_method = cv2.INTER_LINEAR
826
+ else:
827
+ interpolation_method = cv2.INTER_AREA
828
+
829
+ if use_ceil_for_resize:
830
+ target_w = math.ceil(w * resize_ratio)
831
+ target_h = math.ceil(h * resize_ratio)
832
+ else:
833
+ target_w = int(w * resize_ratio)
834
+ target_h = int(h * resize_ratio)
835
+
836
+ img_original = cv2.resize(
837
+ img_original, (target_w, target_h),
838
+ interpolation=interpolation_method)
839
+
840
+ if 'classic' in self.compatibility_mode:
841
+
842
+ letterbox_auto = True
843
+ letterbox_scaleup = True
844
+ target_shape = image_size
845
+
846
+ else:
847
+
848
+ letterbox_auto = False
849
+ letterbox_scaleup = False
850
+
851
+ # The padding to apply as a fraction of the stride size
852
+ pad = 0.5
853
+
854
+ model_stride = int(self.model.stride.max())
855
+
856
+ max_dimension = max(img_original.shape)
857
+ normalized_shape = [img_original.shape[0] / max_dimension,
858
+ img_original.shape[1] / max_dimension]
859
+ target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
860
+ pad).astype(int) * model_stride
861
+
862
+ # Now we letterbox, which is just padding, since we've already resized.
863
+ img,letterbox_ratio,letterbox_pad = letterbox(img_original,
864
+ new_shape=target_shape,
865
+ stride=self.letterbox_stride,
866
+ auto=letterbox_auto,
867
+ scaleFill=False,
868
+ scaleup=letterbox_scaleup)
258
869
 
259
- if skip_image_resizing:
260
- img = img_original
261
- else:
262
- letterbox_result = letterbox(img_original,
263
- new_shape=target_size,
264
- stride=PTDetector.STRIDE,
265
- auto=True)
266
- img = letterbox_result[0]
267
-
268
- # HWC to CHW; PIL Image is RGB already
269
- img = img.transpose((2, 0, 1))
870
+ if preprocess_only:
871
+
872
+ assert 'file' in result
873
+ result['img_processed'] = img
874
+ result['img_original'] = img_original
875
+ result['img_original_pil'] = img_original_pil
876
+ result['target_shape'] = target_shape
877
+ result['scaling_shape'] = scaling_shape
878
+ result['letterbox_ratio'] = letterbox_ratio
879
+ result['letterbox_pad'] = letterbox_pad
880
+ return result
881
+
882
+ # ...are we doing resizing here, or were images already resized?
883
+
884
+ # Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
885
+ # so we don't need to mess with the color channels.
886
+ #
887
+ # TODO, this could be moved into the preprocessing loop
888
+
889
+ img = img.transpose((2, 0, 1)) # [::-1]
270
890
  img = np.ascontiguousarray(img)
271
891
  img = torch.from_numpy(img)
272
892
  img = img.to(self.device)
273
- img = img.float()
893
+ img = img.half() if self.half_precision else img.float()
274
894
  img /= 255
275
895
 
276
896
  # In practice this is always true
277
897
  if len(img.shape) == 3:
278
898
  img = torch.unsqueeze(img, 0)
279
899
 
900
+ # Run the model
280
901
  pred = self.model(img,augment=augment)[0]
281
902
 
282
- # NMS
903
+ if 'classic' in self.compatibility_mode:
904
+ nms_conf_thres = detection_threshold
905
+ nms_iou_thres = 0.45
906
+ nms_agnostic = False
907
+ nms_multi_label = False
908
+ else:
909
+ nms_conf_thres = detection_threshold # 0.01
910
+ nms_iou_thres = 0.6
911
+ nms_agnostic = False
912
+ nms_multi_label = True
913
+
914
+ # As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
915
+ #
916
+ # Send predictions back to the CPU for NMS.
283
917
  if self.device == 'mps':
284
- # As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
285
- #
286
- # Send predictions back to the CPU for NMS.
287
- pred = non_max_suppression(prediction=pred.cpu(), conf_thres=detection_threshold)
288
- else:
289
- pred = non_max_suppression(prediction=pred, conf_thres=detection_threshold)
918
+ pred_nms = pred.cpu()
919
+ else:
920
+ pred_nms = pred
921
+
922
+ # NMS
923
+ pred = non_max_suppression(prediction=pred_nms,
924
+ conf_thres=nms_conf_thres,
925
+ iou_thres=nms_iou_thres,
926
+ agnostic=nms_agnostic,
927
+ multi_label=nms_multi_label)
290
928
 
291
- # format detections/bounding boxes
292
- #
293
- # normalization gain whwh
294
- gn = torch.tensor(img_original.shape)[[1, 0, 1, 0]]
929
+ # In practice this is [w,h,w,h] of the original image
930
+ gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
295
931
 
932
+ if 'classic' in self.compatibility_mode:
933
+
934
+ ratio = None
935
+ ratio_pad = None
936
+
937
+ else:
938
+
939
+ # letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
940
+ #
941
+ # ratio is a 2-tuple specifying the scaling that was applied to each dimension.
942
+ #
943
+ # The scale_boxes function expects a 2-tuple with these things combined.
944
+ ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
945
+ ratio_pad = (ratio, letterbox_pad)
946
+
296
947
  # This is a loop over detection batches, which will always be length 1 in our case,
297
948
  # since we're not doing batch inference.
949
+ #
950
+ # det = pred[0]
951
+ #
952
+ # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
953
+ # in descending order by confidence.
954
+ #
955
+ # Columns are:
956
+ #
957
+ # x0,y0,x1,y1,confidence,class
958
+ #
959
+ # At this point, these are *non*-normalized values, referring to the size at which we
960
+ # ran inference (img.shape).
298
961
  for det in pred:
299
962
 
300
- if len(det):
963
+ if len(det) == 0:
964
+ continue
965
+
966
+ # Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
967
+ if 'classic' in self.compatibility_mode:
301
968
 
302
- # Rescale boxes from img_size to im0 size
303
969
  det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
970
+
971
+ else:
972
+ # After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
973
+ # original pixel dimension of the image, followed by the class and confidence
974
+ det[:, :4] = scale_coords(img.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()
304
975
 
305
- for *xyxy, conf, cls in reversed(det):
306
-
307
- # normalized center-x, center-y, width and height
308
- xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
976
+ # Loop over detections
977
+ for *xyxy, conf, cls in reversed(det):
978
+
979
+ if conf < detection_threshold:
980
+ continue
981
+
982
+ # Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
983
+ xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
309
984
 
310
- api_box = ct_utils.convert_yolo_to_xywh(xywh)
985
+ # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
986
+ # left/top/w/h (i.e., MD format)
987
+ api_box = ct_utils.convert_yolo_to_xywh(xywh)
311
988
 
989
+ if 'classic' in self.compatibility_mode:
990
+ api_box = ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
312
991
  conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
992
+ else:
993
+ api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
994
+ conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
995
+
996
+ if not self.use_model_native_classes:
997
+ # The MegaDetector output format's categories start at 1, but all YOLO-based
998
+ # MD models have category numbers starting at 0.
999
+ cls = int(cls.tolist()) + 1
1000
+ if cls not in (1, 2, 3):
1001
+ raise KeyError(f'{cls} is not a valid class.')
1002
+ else:
1003
+ cls = int(cls.tolist())
313
1004
 
314
- if not self.use_model_native_classes:
315
- # MegaDetector output format's categories start at 1, but the MD
316
- # model's categories start at 0.
317
- cls = int(cls.tolist()) + 1
318
- if cls not in (1, 2, 3):
319
- raise KeyError(f'{cls} is not a valid class.')
320
- else:
321
- cls = int(cls.tolist())
322
-
323
- detections.append({
324
- 'category': str(cls),
325
- 'conf': conf,
326
- 'bbox': ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
327
- })
328
- max_conf = max(max_conf, conf)
329
-
330
- # ...for each detection in this batch
331
-
332
- # ...if this is a non-empty batch
333
-
334
- # ...for each detection batch
1005
+ detections.append({
1006
+ 'category': str(cls),
1007
+ 'conf': conf,
1008
+ 'bbox': api_box
1009
+ })
1010
+ max_conf = max(max_conf, conf)
1011
+
1012
+ # ...for each detection in this batch
1013
+
1014
+ # ...for each detection batch (always one iteration)
335
1015
 
336
1016
  # ...try
337
1017
 
@@ -339,7 +1019,8 @@ class PTDetector:
339
1019
 
340
1020
  result['failure'] = FAILURE_INFER
341
1021
  print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
342
- traceback.print_exc(e)
1022
+ # traceback.print_exc(e)
1023
+ print(traceback.format_exc())
343
1024
 
344
1025
  result['max_detection_conf'] = max_conf
345
1026
  result['detections'] = detections
@@ -361,7 +1042,7 @@ if __name__ == '__main__':
361
1042
 
362
1043
  #%%
363
1044
 
364
- import os
1045
+ import os #noqa
365
1046
  from megadetector.visualization import visualization_utils as vis_utils
366
1047
 
367
1048
  model_file = os.environ['MDV5A']