megadetector 10.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (147)
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +701 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +563 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +192 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +665 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +984 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2172 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1604 -0
  81. megadetector/detection/run_tiled_inference.py +1044 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1943 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2140 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +231 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2872 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1766 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1973 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +498 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.15.dist-info/METADATA +115 -0
  144. megadetector-10.0.15.dist-info/RECORD +147 -0
  145. megadetector-10.0.15.dist-info/WHEEL +5 -0
  146. megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.15.dist-info/top_level.txt +1 -0
megadetector/detection/pytorch_detector.py
@@ -0,0 +1,1451 @@
+ """
+ 
+ pytorch_detector.py
+ 
+ Module to run YOLO-based MegaDetector models.
+ 
+ """
+ 
+ #%% Imports and constants
+ 
+ import os
+ import sys
+ import math
+ import zipfile
+ import tempfile
+ import shutil
+ import uuid
+ import json
+ import inspect
+ 
+ import cv2
+ import torch
+ import numpy as np
+ 
+ from megadetector.detection.run_detector import \
+     CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, FAILURE_IMAGE_OPEN, \
+     get_detector_version_from_model_file, \
+     known_models
+ from megadetector.utils.ct_utils import parse_bool_string
+ from megadetector.utils.ct_utils import is_running_in_gha
+ from megadetector.utils import ct_utils
+ import torchvision
+ 
+ # We support a few ways of accessing the YOLOv5 dependencies:
+ #
+ # * The standard configuration as of 2023.09 expects that the YOLOv5 repo is checked
+ #   out and on the PYTHONPATH (import utils)
+ #
+ # * Supported but non-default (used for PyPI packaging):
+ #
+ #   pip install ultralytics-yolov5
+ #
+ # * Works, but not supported:
+ #
+ #   pip install yolov5
+ #
+ # * Unfinished:
+ #
+ #   pip install ultralytics
+ 
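For orientation, a minimal usage sketch (not part of the packaged file; the model filename is hypothetical) showing how a caller can select one of these backends explicitly via detector_options, which _initialize_yolo_imports_for_model() below reads before falling back to model metadata:

    from megadetector.detection.pytorch_detector import PTDetector

    # 'model_type' short-circuits the metadata lookup in
    # _get_model_type_for_model(); 'force_cpu' is parsed with parse_bool_string()
    detector_options = {'model_type': 'yolov5', 'force_cpu': 'true'}
    detector = PTDetector('md_v5a.0.0.pt', detector_options=detector_options)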
+ yolo_model_type_imported = None
+ 
+ def _get_model_type_for_model(model_file,
+                               prefer_model_type_source='table',
+                               default_model_type='yolov5',
+                               verbose=False):
+     """
+     Determine the model type (i.e., the inference library we need to use) for a .pt file.
+ 
+     Args:
+         model_file (str): the model file to read
+         prefer_model_type_source (str, optional): how should we handle the (very unlikely)
+             case where the metadata in the file indicates one model type, but the global model
+             type table says something else. Should be "table" (trust the table) or "file"
+             (trust the file).
+         default_model_type (str, optional): return value for the case where we can't find
+             appropriate metadata in the file or in the global table.
+         verbose (bool, optional): enable additional debug output
+ 
+     Returns:
+         str: the model type indicated for this model
+     """
+ 
+     model_info = read_metadata_from_megadetector_model_file(model_file)
+ 
+     # Check whether the model file itself specified a model type
+     model_type_from_model_file_metadata = None
+ 
+     if model_info is not None and 'model_type' in model_info:
+         model_type_from_model_file_metadata = model_info['model_type']
+         if verbose:
+             print('Parsed model type {} from model {}'.format(
+                 model_type_from_model_file_metadata,
+                 model_file))
+ 
+     model_type_from_model_version = None
+ 
+     # Check whether this is a known model version with a specific model type
+     model_version_from_file = get_detector_version_from_model_file(model_file)
+ 
+     if model_version_from_file is not None and model_version_from_file in known_models:
+         model_info = known_models[model_version_from_file]
+         if 'model_type' in model_info:
+             model_type_from_model_version = model_info['model_type']
+             if verbose:
+                 print('Parsed model type {} from global metadata'.format(model_type_from_model_version))
+         else:
+             model_type_from_model_version = None
+ 
+     if model_type_from_model_file_metadata is None and \
+        model_type_from_model_version is None:
+ 
+         if verbose:
+             print('Could not determine model type for {}, assuming {}'.format(
+                 model_file,default_model_type))
+         model_type = default_model_type
+ 
+     elif model_type_from_model_file_metadata is not None and \
+          model_type_from_model_version is not None:
+ 
+         if model_type_from_model_version == model_type_from_model_file_metadata:
+             model_type = model_type_from_model_file_metadata
+         else:
+             print('Warning: model type from model version is {}, from file metadata is {}'.format(
+                 model_type_from_model_version,model_type_from_model_file_metadata))
+             if prefer_model_type_source == 'table':
+                 model_type = model_type_from_model_version
+             else:
+                 model_type = model_type_from_model_file_metadata
+ 
+     elif model_type_from_model_file_metadata is not None:
+ 
+         model_type = model_type_from_model_file_metadata
+ 
+     elif model_type_from_model_version is not None:
+ 
+         model_type = model_type_from_model_version
+ 
+     return model_type
+ 
+ # ...def _get_model_type_for_model(...)
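A short sketch of the resolution order above (filename hypothetical): the embedded metadata and the known_models table are consulted independently, and prefer_model_type_source only matters when the two disagree:

    # Checkpoint whose embedded megadetector_info.json says {'model_type': 'yolov9'},
    # but whose version string maps to 'yolov5' in known_models:
    model_type = _get_model_type_for_model('custom-md.pt',
                                           prefer_model_type_source='table')
    # -> 'yolov5' (the table wins); prefer_model_type_source='file' would yield 'yolov9'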
+ 
+ 
+ def _initialize_yolo_imports_for_model(model_file,
+                                        prefer_model_type_source='table',
+                                        default_model_type='yolov5',
+                                        detector_options=None,
+                                        verbose=False):
+     """
+     Initialize the appropriate YOLO imports for a model file.
+ 
+     Args:
+         model_file (str): The model file for which we're loading support
+         prefer_model_type_source (str, optional): how should we handle the (very unlikely)
+             case where the metadata in the file indicates one model type, but the global model
+             type table says something else. Should be "table" (trust the table) or "file"
+             (trust the file).
+         default_model_type (str, optional): fallback model type for the case where we can't
+             find appropriate metadata in the file or in the global table.
+         detector_options (dict, optional): dictionary of detector options that mean
+             different things to different models
+         verbose (bool, optional): enable additional debug output
+ 
+     Returns:
+         str: the model type for which we initialized support
+     """
+ 
+     global yolo_model_type_imported
+ 
+     if detector_options is not None and 'model_type' in detector_options:
+         model_type = detector_options['model_type']
+         print('Model type {} provided in detector options'.format(model_type))
+     else:
+         model_type = _get_model_type_for_model(model_file,
+                                                prefer_model_type_source=prefer_model_type_source,
+                                                default_model_type=default_model_type)
+ 
+     if yolo_model_type_imported is not None:
+         if model_type == yolo_model_type_imported:
+             print('Bypassing imports for model type {}'.format(model_type))
+             return model_type
+         else:
+             print('Previously set up imports for model type {}, re-importing as {}'.format(
+                 yolo_model_type_imported,model_type))
+ 
+     _initialize_yolo_imports(model_type,verbose=verbose)
+ 
+     return model_type
+ 
+ 
+ def _clean_yolo_imports(verbose=False, aggressive_cleanup=False):
+     """
+     Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
+     of another YOLO library version. The reason we jump through all these hoops, rather than
+     just, e.g., handling different libraries in different modules, is that we need to make sure
+     *pickle* sees the right version of modules during module loading, including modules we don't
+     load directly (i.e., every module loaded within a YOLO library), and the only way I know to
+     do that is to remove all the "wrong" versions from sys.modules and sys.path.
+ 
+     Args:
+         verbose (bool, optional): enable additional debug output
+         aggressive_cleanup (bool, optional): err on the side of removing modules,
+             at least by ignoring whether they are/aren't in a site-packages folder.
+             By default, only modules in a folder that includes "site-packages" will
+             be considered for unloading.
+     """
+ 
+     modules_to_delete = []
+ 
+     for module_name in list(sys.modules.keys()):
+ 
+         module = sys.modules[module_name]
+         if not hasattr(module,'__file__') or (module.__file__ is None):
+             continue
+         try:
+             module_file = module.__file__.replace('\\','/')
+             if not aggressive_cleanup:
+                 if 'site-packages' not in module_file:
+                     continue
+             tokens = module_file.split('/')
+ 
+             # For local path imports, a module filename that should be unloaded might
+             # look like:
+             #
+             # c:/git/yolov9/models/common.py
+             #
+             # For pip imports, a module filename that should be unloaded might look like:
+             #
+             # c:/users/user/miniforge3/envs/megadetector/lib/site-packages/yolov9/utils/__init__.py
+             first_token_to_check = len(tokens) - 4
+             for i_token,token in enumerate(tokens):
+                 if i_token < first_token_to_check:
+                     continue
+                 # Don't remove anything based on the environment name, which
+                 # always follows "envs" in the path
+                 if (i_token > 1) and (tokens[i_token-1] == 'envs'):
+                     continue
+                 if ('yolov5' in token) or ('yolov9' in token) or ('ultralytics' in token):
+                     if verbose:
+                         print('Module {} ({}) looks deletable'.format(module_name,module_file))
+                     modules_to_delete.append(module_name)
+                     break
+         except Exception as e:
+             if verbose:
+                 print('Exception during module review: {}'.format(str(e)))
+ 
+     # ...for each module in the global namespace
+ 
+     for module_name in modules_to_delete:
+ 
+         if module_name in sys.modules.keys():
+             if verbose:
+                 try:
+                     module = sys.modules[module_name]
+                     module_file = module.__file__.replace('\\','/')
+                     print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
+                 except Exception:
+                     pass
+             del sys.modules[module_name]
+ 
+     # ...for each module we want to remove from the global namespace
+ 
+     paths_to_delete = []
+ 
+     for p in sys.path:
+         if p.endswith('yolov5') or p.endswith('yolov9') or p.endswith('ultralytics'):
+             print('clean_yolo_imports: removing {} from path'.format(p))
+             paths_to_delete.append(p)
+ 
+     for p in paths_to_delete:
+         sys.path.remove(p)
+ 
+ # ...def _clean_yolo_imports(...)
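A minimal sketch of how this cleanup is used in practice (see _initialize_yolo_imports() below, which calls it whenever the requested backend changes):

    # Unload any yolov5/yolov9/ultralytics modules so that the next import, and
    # anything pickle resolves during torch.load(), sees a single backend.
    _clean_yolo_imports(verbose=True)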
+ 
+ 
+ def _initialize_yolo_imports(model_type='yolov5',
+                              allow_fallback_import=True,
+                              force_reimport=False,
+                              verbose=False):
+     """
+     Imports required functions from one or more yolo libraries (yolov5, yolov9,
+     ultralytics, targeting support for [model_type]).
+ 
+     Args:
+         model_type (str): The model type for which we're loading support
+         allow_fallback_import (bool, optional): If we can't import from the package for
+             which we're trying to load support, fall back to "import utils". This is
+             typically used when the right support library is on the current PYTHONPATH.
+         force_reimport (bool, optional): import the appropriate libraries even if the
+             requested model type matches the current initialization state
+         verbose (bool, optional): include additional debug output
+ 
+     Returns:
+         str: the model type for which we initialized support
+     """
+ 
+     # When running in pytest, the megadetector 'utils' module is put in the global
+     # namespace, which creates conflicts with yolov5; remove it from the global
+     # namespace.
+     if ('PYTEST_CURRENT_TEST' in os.environ):
+         print('*** pytest detected ***')
+         if ('utils' in sys.modules):
+             utils_module = sys.modules['utils']
+             if hasattr(utils_module, '__file__') and 'megadetector' in str(utils_module.__file__):
+                 print(f"Removing conflicting utils module: {utils_module.__file__}")
+                 sys.modules.pop('utils', None)
+                 # Also remove any submodules
+                 to_remove = [name for name in sys.modules if name.startswith('utils.')]
+                 for name in to_remove:
+                     sys.modules.pop(name, None)
+ 
+     global yolo_model_type_imported
+ 
+     if model_type is None:
+         model_type = 'yolov5'
+ 
+     # The point of this function is to make the appropriate version
+     # of the following functions available at module scope
+     global non_max_suppression
+     global xyxy2xywh
+     global letterbox
+     global scale_coords
+ 
+     if yolo_model_type_imported is not None:
+         if (yolo_model_type_imported == model_type) and (not force_reimport):
+             print('Bypassing imports for YOLO model type {}'.format(model_type))
+             return model_type
+         else:
+             _clean_yolo_imports()
+ 
+     try_yolov5_import = (model_type == 'yolov5')
+     try_yolov9_import = (model_type == 'yolov9')
+     try_ultralytics_import = (model_type == 'ultralytics')
+ 
+     utils_imported = False
+ 
+     # First try importing from the yolov5 package; this is how the pip
+     # package finds YOLOv5 utilities.
+     if try_yolov5_import and not utils_imported:
+ 
+         try:
+             # from yolov5.utils.general import non_max_suppression  # type: ignore
+             from yolov5.utils.general import xyxy2xywh  # noqa
+             from yolov5.utils.augmentations import letterbox  # noqa
+             try:
+                 from yolov5.utils.general import scale_boxes as scale_coords
+             except Exception:
+                 from yolov5.utils.general import scale_coords
+             utils_imported = True
+             if verbose:
+                 print('Imported utils from YOLOv5 package')
+ 
+         except Exception as e:  # noqa
+             # print('yolov5 module import failed: {}'.format(e))
+             # print(traceback.format_exc())
+             pass
+ 
+     # Next try importing from the yolov9 package
+     if try_yolov9_import and not utils_imported:
+ 
+         try:
+             # from yolov9.utils.general import non_max_suppression  # noqa
+             from yolov9.utils.general import xyxy2xywh  # noqa
+             from yolov9.utils.augmentations import letterbox  # noqa
+             from yolov9.utils.general import scale_boxes as scale_coords  # noqa
+             utils_imported = True
+             if verbose:
+                 print('Imported utils from YOLOv9 package')
+ 
+         except Exception as e:  # noqa
+             # print('yolov9 module import failed: {}'.format(e))
+             # print(traceback.format_exc())
+             pass
+ 
+     # If we haven't succeeded yet, import from the ultralytics package
+     if try_ultralytics_import and not utils_imported:
+ 
+         try:
+             import ultralytics  # type: ignore # noqa
+         except Exception:
+             print('It looks like you are trying to run a model that requires the ultralytics package, '
+                   'but the ultralytics package is not installed. For licensing reasons, this '
+                   'is not installed by default with the MegaDetector Python package. Run '
+                   '"pip install ultralytics" to install it, and try again.')
+             raise
+ 
+         try:
+             # The non_max_suppression() function moved from the ops module to the nms module
+             # in mid-2025
+             try:
+                 from ultralytics.utils.ops import non_max_suppression  # type: ignore # noqa
+             except Exception:
+                 from ultralytics.utils.nms import non_max_suppression  # type: ignore # noqa
+             from ultralytics.utils.ops import xyxy2xywh  # type: ignore # noqa
+ 
+             # In the ultralytics package, scale_boxes and scale_coords both exist;
+             # we want scale_boxes.
+             #
+             # from ultralytics.utils.ops import scale_coords  # noqa
+             from ultralytics.utils.ops import scale_boxes as scale_coords  # type: ignore # noqa
+             from ultralytics.data.augment import LetterBox  # type: ignore # noqa
+ 
+             # letterbox() became a LetterBox class in the ultralytics package. Create a
+             # backwards-compatible letterbox function wrapper that wraps the class up.
+             def letterbox(img,new_shape,auto=False,scaleFill=False,  # noqa
+                           scaleup=True,center=True,stride=32):
+ 
+                 # Ultralytics changed the "scaleFill" parameter to "scale_fill"; we want
+                 # to support both conventions.
+                 use_old_scalefill_arg = False
+                 try:
+                     sig = inspect.signature(LetterBox.__init__)
+                     if 'scaleFill' in sig.parameters:
+                         use_old_scalefill_arg = True
+                 except Exception:
+                     pass
+ 
+                 if use_old_scalefill_arg:
+                     if verbose:
+                         print('Using old scaleFill calling convention')
+                     letterbox_transformer = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,
+                                                       scaleup=scaleup,center=center,stride=stride)
+                 else:
+                     letterbox_transformer = LetterBox(new_shape,auto=auto,scale_fill=scaleFill,
+                                                       scaleup=scaleup,center=center,stride=stride)
+ 
+                 letterbox_result = letterbox_transformer(image=img)
+ 
+                 if isinstance(new_shape,int):
+                     new_shape = [new_shape,new_shape]
+ 
+                 # The letterboxing is done, we just need to reverse-engineer what it did
+                 shape = img.shape[:2]
+ 
+                 r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+                 if not scaleup:
+                     r = min(r, 1.0)
+                 ratio = r, r
+ 
+                 new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+                 dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
+                 if auto:
+                     dw, dh = np.mod(dw, stride), np.mod(dh, stride)
+                 elif scaleFill:
+                     dw, dh = 0.0, 0.0
+                     new_unpad = (new_shape[1], new_shape[0])
+                     ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])
+ 
+                 dw /= 2
+                 dh /= 2
+                 pad = (dw,dh)
+ 
+                 return [letterbox_result,ratio,pad]
+ 
+             utils_imported = True
+             if verbose:
+                 print('Imported utils from ultralytics package')
+ 
+         except Exception as e:
+             print('Ultralytics module import failed: {}'.format(str(e)))
+ 
+     # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
+     if (not utils_imported) and allow_fallback_import:
+ 
+         try:
+             # Import pre- and post-processing functions from the YOLOv5 repo
+             # from utils.general import non_max_suppression  # type: ignore
+             from utils.general import xyxy2xywh  # type: ignore
+             from utils.augmentations import letterbox  # type: ignore
+ 
+             # scale_coords() is scale_boxes() in some YOLOv5 versions
+             try:
+                 from utils.general import scale_coords  # type: ignore
+             except ImportError:
+                 from utils.general import scale_boxes as scale_coords  # type: ignore
+             utils_imported = True
+             imported_file = sys.modules[scale_coords.__module__].__file__
+             if verbose:
+                 print('Imported utils from {}'.format(imported_file))
+ 
+         except ModuleNotFoundError as e:
+             raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e))) from e
+ 
+     assert utils_imported, 'YOLO utils import error'
+ 
+     yolo_model_type_imported = model_type
+     if verbose:
+         print('Prepared YOLO imports for model type {}'.format(model_type))
+ 
+     return model_type
+ 
+ # ...def _initialize_yolo_imports(...)
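A usage sketch for the initializer above; switching backends triggers _clean_yolo_imports(), so that letterbox(), xyxy2xywh(), and scale_coords() all come from a single library:

    # Prepare module-level imports for an ultralytics-format model...
    _initialize_yolo_imports('ultralytics')

    # ...then switch back to yolov5 support, re-importing from scratch.
    _initialize_yolo_imports('yolov5', force_reimport=True)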
+ 
+ 
+ #%% NMS
+ 
+ def nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+     """
+     Non-maximum suppression (a wrapper around torchvision.ops.nms())
+ 
+     Args:
+         prediction (torch.Tensor): Model predictions with shape [batch_size, num_anchors, num_classes + 5]
+             Format: [x_center, y_center, width, height, objectness, class1_conf, class2_conf, ...]
+             Coordinates are in pixels, relative to the model input size.
+         conf_thres (float): Confidence threshold for filtering detections
+         iou_thres (float): IoU threshold for NMS
+         max_det (int): Maximum number of detections per image
+ 
+     Returns:
+         list: List of tensors, one per image in the batch. Each tensor has shape [N, 6] where:
+             - N is the number of detections for that image
+             - Columns are [x1, y1, x2, y2, confidence, class_id]
+             - Coordinates are in absolute pixels relative to input image size
+             - class_id is the integer class index (0-based)
+     """
+ 
+     batch_size = prediction.shape[0]
+     num_classes = prediction.shape[2] - 5  # noqa
+     output = []
+ 
+     # Process each image in the batch
+     for img_idx in range(batch_size):
+ 
+         x = prediction[img_idx]  # Shape: [num_anchors, num_classes + 5]
+ 
+         # Filter by objectness confidence
+         obj_conf = x[:, 4]
+         valid_detections = obj_conf > conf_thres
+         x = x[valid_detections]
+ 
+         if x.shape[0] == 0:
+             # No detections for this image
+             output.append(torch.zeros((0, 6), device=prediction.device))
+             continue
+ 
+         # Convert box coordinates from [x_center, y_center, w, h] to [x1, y1, x2, y2]
+         box = x[:, :4].clone()
+         box[:, 0] = x[:, 0] - x[:, 2] / 2.0  # x1 = center_x - width/2
+         box[:, 1] = x[:, 1] - x[:, 3] / 2.0  # y1 = center_y - height/2
+         box[:, 2] = x[:, 0] + x[:, 2] / 2.0  # x2 = center_x + width/2
+         box[:, 3] = x[:, 1] + x[:, 3] / 2.0  # y2 = center_y + height/2
+ 
+         # Get class predictions: multiply objectness by class probabilities
+         class_conf = x[:, 5:] * x[:, 4:5]  # shape: [N, num_classes]
+ 
+         # For each detection, take the class with highest confidence (single-label)
+         best_class_conf, best_class_idx = class_conf.max(1, keepdim=True)
+ 
+         # Filter by class confidence threshold
+         conf_mask = best_class_conf.view(-1) > conf_thres
+         if conf_mask.sum() == 0:
+             # No detections pass the confidence threshold
+             output.append(torch.zeros((0, 6), device=prediction.device))
+             continue
+ 
+         box = box[conf_mask]
+         best_class_conf = best_class_conf[conf_mask]
+         best_class_idx = best_class_idx[conf_mask]
+ 
+         # Prepare for NMS: group detections by class
+         unique_classes = best_class_idx.unique()
+         final_detections = []
+ 
+         for class_id in unique_classes:
+ 
+             class_mask = (best_class_idx == class_id).view(-1)
+             class_boxes = box[class_mask]
+             class_scores = best_class_conf[class_mask].view(-1)
+ 
+             if class_boxes.shape[0] == 0:
+                 continue
+ 
+             # Apply NMS for this class
+             keep_indices = torchvision.ops.nms(class_boxes, class_scores, iou_thres)
+ 
+             if len(keep_indices) > 0:
+                 kept_boxes = class_boxes[keep_indices]
+                 kept_scores = class_scores[keep_indices]
+                 kept_classes = torch.full((len(keep_indices), 1), class_id.item(),
+                                           device=prediction.device, dtype=torch.float)
+ 
+                 # Combine: [x1, y1, x2, y2, conf, class]
+                 class_detections = torch.cat([kept_boxes, kept_scores.unsqueeze(1), kept_classes], 1)
+                 final_detections.append(class_detections)
+ 
+         # ...for each category
+ 
+         if final_detections:
+ 
+             # Combine all classes and sort by confidence
+             all_detections = torch.cat(final_detections, 0)
+             conf_sort_indices = all_detections[:, 4].argsort(descending=True)
+             all_detections = all_detections[conf_sort_indices]
+ 
+             # Limit to max_det
+             if all_detections.shape[0] > max_det:
+                 all_detections = all_detections[:max_det]
+ 
+             output.append(all_detections)
+         else:
+             output.append(torch.zeros((0, 6), device=prediction.device))
+ 
+     # ...for each image in the batch
+ 
+     return output
+ 
+ # ...def nms(...)
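A self-contained sketch exercising nms() on random data (shapes only; the values are meaningless):

    import torch

    # Two images, 100 candidate boxes each, three classes:
    # [cx, cy, w, h, objectness, c1, c2, c3], in input-pixel units
    prediction = torch.rand(2, 100, 8)
    prediction[:, :, :4] *= 640

    detections = nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300)
    for per_image in detections:
        # per_image has shape [N, 6]: x1, y1, x2, y2, confidence, class_id
        print(per_image.shape)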
+ 
+ 
+ #%% Model metadata functions
+ 
+ def add_metadata_to_megadetector_model_file(model_file_in,
+                                             model_file_out,
+                                             metadata,
+                                             destination_path='megadetector_info.json'):
+     """
+     Adds a .json file to the specified MegaDetector model file containing metadata used
+     by this module. Always overwrites the output file.
+ 
+     Args:
+         model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
+         model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
+             May be the same as model_file_in.
+         metadata (dict): The metadata dict to add to the output model file
+         destination_path (str, optional): The relative path within the main folder of the
+             model archive where we should write the metadata. This is not relative to the root
+             of the archive, it's relative to the one and only folder at the root of the archive
+             (this is a PyTorch convention).
+     """
+ 
+     tmp_base = os.path.join(tempfile.gettempdir(),'md_metadata')
+     os.makedirs(tmp_base,exist_ok=True)
+     metadata_tmp_file_relative = 'megadetector_info_' + str(uuid.uuid1()) + '.json'
+     metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
+ 
+     with open(metadata_tmp_file_abs,'w') as f:
+         json.dump(metadata,f,indent=1)
+ 
+     # Copy the input file to the output file
+     shutil.copyfile(model_file_in,model_file_out)
+ 
+     # Write metadata to the output file
+     with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:
+ 
+         # Torch doesn't like anything in the root folder of the zipfile, so we put
+         # it in the one and only folder.
+         names = zipf.namelist()
+         root_folders = set()
+         for name in names:
+             root_folder = name.split('/')[0]
+             root_folders.add(root_folder)
+         assert len(root_folders) == 1, \
+             'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
+         root_folder = next(iter(root_folders))
+ 
+         zipf.write(metadata_tmp_file_abs,
+                    root_folder + '/' + destination_path,
+                    compresslevel=9,
+                    compress_type=zipfile.ZIP_DEFLATED)
+ 
+     try:
+         os.remove(metadata_tmp_file_abs)
+     except Exception as e:
+         print('Warning: error deleting file {}: {}'.format(metadata_tmp_file_abs,str(e)))
+ 
+ # ...def add_metadata_to_megadetector_model_file(...)
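A minimal write-path sketch (filenames hypothetical): embedding an image size and model type so that the model-type resolution and PTDetector code in this module can find them later:

    metadata = {'image_size': 1280, 'model_type': 'yolov5'}
    add_metadata_to_megadetector_model_file('md_v5a.0.0.pt',
                                            'md_v5a.0.0-with-metadata.pt',
                                            metadata)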
+ 
+ 
+ def read_metadata_from_megadetector_model_file(model_file,
+                                                relative_path='megadetector_info.json',
+                                                verbose=False):
+     """
+     Reads custom MegaDetector metadata from a modified MegaDetector model file.
+ 
+     Args:
+         model_file (str): The model filename to read, typically .pt (.zip is also sensible)
+         relative_path (str, optional): The relative path within the main folder of the model
+             archive from which we should read the metadata. This is not relative to the root
+             of the archive, it's relative to the one and only folder at the root of the archive
+             (this is a PyTorch convention).
+         verbose (bool, optional): enable additional debug output
+ 
+     Returns:
+         object: whatever we read from the metadata file, always a dict in practice. Returns
+         None if we failed to read the specified metadata file.
+     """
+ 
+     with zipfile.ZipFile(model_file,'r') as zipf:
+ 
+         # Torch doesn't like anything in the root folder of the zipfile, so we put
+         # it in the one and only folder.
+         names = zipf.namelist()
+         root_folders = set()
+         for name in names:
+             root_folder = name.split('/')[0]
+             root_folders.add(root_folder)
+         if len(root_folders) != 1:
+             print('Warning: this archive does not have exactly one folder at the top level; ' + \
+                   'are you sure it\'s a Torch model file?')
+             return None
+         root_folder = next(iter(root_folders))
+ 
+         metadata_file = root_folder + '/' + relative_path
+         if metadata_file not in names:
+             # This is the case for MDv5a and MDv5b
+             if verbose:
+                 print('Warning: could not find metadata file {} in zip archive {}'.format(
+                     metadata_file,os.path.basename(model_file)))
+             return None
+ 
+         try:
+             path = zipfile.Path(zipf,metadata_file)
+             contents = path.read_text()
+             d = json.loads(contents)
+         except Exception as e:
+             print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
+             return None
+ 
+         return d
+ 
+     # ...with zipfile.ZipFile(...)
+ 
+ # ...def read_metadata_from_megadetector_model_file(...)
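...and the companion read path (same hypothetical file); models with no embedded metadata, e.g. MDv5a/MDv5b, return None:

    model_info = read_metadata_from_megadetector_model_file('md_v5a.0.0-with-metadata.pt')
    if model_info is not None:
        print(model_info.get('image_size'), model_info.get('model_type'))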
+ 
+ 
+ #%% Inference classes
+ 
+ default_compatibility_mode = 'classic'
+ 
+ # This is a useful hack when I want to verify that my test driver (md_tests.py) is
+ # correctly forcing a specific compatibility mode (I use "classic-test" in that case)
+ require_non_default_compatibility_mode = False
+ 
+ class PTDetector:
+     """
+     Class that runs a PyTorch-based MegaDetector model. Also used as a preprocessor
+     for images that will later be run through an instance of PTDetector.
+     """
+ 
+     def __init__(self, model_path, detector_options=None, verbose=False):
+         """
+         PTDetector constructor. If detector_options['preprocess_only'] exists and is
+         True, this instance is being used as a preprocessor, so we don't load model weights.
+         """
+ 
+         if verbose:
+             print('Initializing PTDetector (verbose)')
+ 
+         # Set up the import environment for this model, unloading previous
+         # YOLO library versions if necessary.
+         _initialize_yolo_imports_for_model(model_path,
+                                            detector_options=detector_options,
+                                            verbose=verbose)
+ 
+         # Parse options specific to this detector family
+         force_cpu = False
+         use_model_native_classes = False
+         compatibility_mode = default_compatibility_mode
+ 
+         if detector_options is not None:
+ 
+             if 'force_cpu' in detector_options:
+                 force_cpu = parse_bool_string(detector_options['force_cpu'])
+             if 'use_model_native_classes' in detector_options:
+                 use_model_native_classes = parse_bool_string(detector_options['use_model_native_classes'])
+             if 'compatibility_mode' in detector_options:
+                 if detector_options['compatibility_mode'] is None:
+                     compatibility_mode = default_compatibility_mode
+                 else:
+                     compatibility_mode = detector_options['compatibility_mode']
+ 
+         # This is a global option used only during testing, to make sure I'm hitting
+         # the cases where we are not using "classic" preprocessing.
+         if require_non_default_compatibility_mode:
+ 
+             print('### DEBUG: requiring non-default compatibility mode ###')
+             assert compatibility_mode != 'classic'
+             assert compatibility_mode != 'default'
+ 
+         preprocess_only = False
+         if (detector_options is not None) and \
+            ('preprocess_only' in detector_options) and \
+            (detector_options['preprocess_only']):
+             preprocess_only = True
+ 
+         if verbose or (not preprocess_only):
+             print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))
+ 
+         self.model_metadata = read_metadata_from_megadetector_model_file(model_path)
+ 
+         #: Image size passed to the letterbox() function; 1280 means "1280 on the long side,
+         #: preserving aspect ratio".
+         if self.model_metadata is not None and 'image_size' in self.model_metadata:
+             self.default_image_size = self.model_metadata['image_size']
+             print('Loaded image size {} from model metadata'.format(self.default_image_size))
+         else:
+             # This is not the default for most YOLO models, but most of the time, if someone
+             # is loading a model here that does not have metadata, it's MDv5[ab].0.0
+             print('No image size available in model metadata, defaulting to 1280')
+             self.default_image_size = 1280
+ 
+         #: Either a string ('cpu','cuda:0') or a torch.device()
+         self.device = 'cpu'
+ 
+         #: Have we already printed a warning about using a non-standard image size?
+         #:
+         #: :meta private:
+         self.printed_image_size_warning = False
+ 
+         #: If this is False, we assume the underlying model is producing class indices in the
+         #: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
+         #: MD classes (1,2,3) before generating output. If this is True, we use whatever
+         #: indices the model provides.
+         self.use_model_native_classes = use_model_native_classes
+ 
+         #: This allows us to maintain backwards compatibility across a set of changes to the
+         #: way this class does inference. Currently should start with either "default" or
+         #: "classic".
+         self.compatibility_mode = compatibility_mode
+ 
+         #: Stride size passed to the YOLO letterbox() function
+         self.letterbox_stride = 32
+ 
+         # This is a convenient heuristic to determine the stride size without actually loading
+         # the model: the only models in the YOLO family with a stride size of 64 are the
+         # YOLOv5*6 and YOLOv5*6u models, which are 1280px models.
+         #
+         # See:
+         #
+         # github.com/ultralytics/ultralytics/issues/21544
+         #
+         # Note to self, though, if I decide later to require loading the model on preprocessing
+         # workers so I can more reliably choose a stride, this is the right way to determine the
+         # stride:
+         #
+         # self.letterbox_stride = int(self.model.stride.max())
+         if self.default_image_size == 1280:
+             self.letterbox_stride = 64
+ 
+         print('Using model stride: {}'.format(self.letterbox_stride))
+ 
+         #: Use half-precision inference... fixed by the model, generally don't mess with this
+         self.half_precision = False
+ 
+         if preprocess_only:
+             return
+ 
+         if not force_cpu:
+             if torch.cuda.is_available():
+                 self.device = torch.device('cuda:0')
+             try:
+                 if torch.backends.mps.is_built() and torch.backends.mps.is_available():
+                     # MPS inference fails on GitHub runners as of 2025.08. This is
+                     # independent of model size. So, we disable MPS when running in GHA.
+                     if is_running_in_gha():
+                         print('GitHub actions detected, bypassing MPS backend')
+                     else:
+                         print('Using MPS device')
+                         self.device = 'mps'
+             except AttributeError:
+                 pass
+ 
+         # AddaxAI depends on this printout, don't remove it
+         print('PTDetector using device {}'.format(str(self.device).lower()))
+ 
+         try:
+             self.model = PTDetector._load_model(model_path,
+                                                 device=self.device,
+                                                 compatibility_mode=self.compatibility_mode)
+ 
+         except Exception as e:
+             # In a very esoteric scenario where an old version of YOLOv5 is used to run
+             # newer models, we run into an issue because the "Model" class became
+             # "DetectionModel". New YOLOv5 code handles this case by just setting them
+             # to be the same, so doing that externally doesn't seem *that* rude.
+             if "Can't get attribute 'DetectionModel'" in str(e):
+                 print('Forward-compatibility issue detected, patching')
+                 from models import yolo  # type: ignore
+                 yolo.DetectionModel = yolo.Model
+                 self.model = PTDetector._load_model(model_path,
+                                                     device=self.device,
+                                                     compatibility_mode=self.compatibility_mode,
+                                                     verbose=verbose)
+             else:
+                 raise
+ 
+         if (self.device != 'cpu'):
+             if verbose:
+                 print('Sending model to GPU')
+             self.model.to(self.device)
+ 
+ 
+     @staticmethod
+     def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
+ 
+         if verbose:
+             print(f'Using PyTorch version {torch.__version__}')
+ 
+         # I get quirky errors when loading YOLOv5 models on MPS hardware using
+         # map_location, but this is the recommended method, so I'm using it everywhere
+         # other than MPS devices.
+         use_map_location = (device != 'mps')
+ 
+         if use_map_location:
+             try:
+                 checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
+             # For a transitional period, we want to support torch 1.1x, where the weights_only
+             # parameter doesn't exist
+             except Exception as e:
+                 if "'weights_only' is an invalid keyword" in str(e):
+                     checkpoint = torch.load(model_pt_path, map_location=device)
+                 else:
+                     raise
+         else:
+             try:
+                 checkpoint = torch.load(model_pt_path, weights_only=False)
+             # For a transitional period, we want to support torch 1.1x, where the weights_only
+             # parameter doesn't exist
+             except Exception as e:
+                 if "'weights_only' is an invalid keyword" in str(e):
+                     checkpoint = torch.load(model_pt_path)
+                 else:
+                     raise
+ 
+         # Compatibility fix that allows us to load older YOLOv5 models with
+         # newer versions of YOLOv5/PT
+         for m in checkpoint['model'].modules():
+             t = type(m)
+             if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
+                 m.recompute_scale_factor = None
+ 
+         # Calling .to(device) should no longer be necessary now that we're using map_location=device
+         # model = checkpoint['model'].float().fuse().eval().to(device)
+         model = checkpoint['model'].float().fuse().eval()
+ 
+         return model
+ 
+     # ...def _load_model(...)
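Since _load_model() is static, a checkpoint can be inspected without constructing a full PTDetector; a minimal sketch (filename hypothetical):

    # Returns the fused, eval-mode float model from checkpoint['model']
    model = PTDetector._load_model('md_v5a.0.0.pt', device='cpu', verbose=True)
    print(int(model.stride.max()))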
+ 
+ 
+     def preprocess_image(self,
+                          img_original,
+                          image_id='unknown',
+                          image_size=None,
+                          verbose=False):
+         """
+         Prepare an image for detection, including scaling and letterboxing.
+ 
+         Args:
+             img_original (Image or np.array): the image on which we should run the detector, with
+                 EXIF rotation already handled
+             image_id (str, optional): a path to identify the image; will be in the "file" field
+                 of the output object
+             image_size (int, optional): image size (long side) to use for inference, or None to
+                 use the default size specified at the time the model was loaded
+             verbose (bool, optional): enable additional debug output
+ 
+         Returns:
+             dict: dict with fields:
+                 - file (filename)
+                 - img_processed (the preprocessed np.array)
+                 - img_original (the input image before preprocessing, as an np.array)
+                 - img_original_pil (the input image before preprocessing, as a PIL Image)
+                 - target_shape (the 2D shape to which the image was resized during preprocessing)
+                 - scaling_shape (the 2D original size, for normalizing coordinates later)
+                 - letterbox_ratio (letterbox parameter used for normalizing coordinates later)
+                 - letterbox_pad (letterbox parameter used for normalizing coordinates later)
+         """
+ 
+         # Prepare return dict
+         result = {'file': image_id}
+ 
+         # Store the PIL version of the original image, the caller may want to use
+         # it for metadata extraction later.
+         img_original_pil = None
+ 
+         # If we were given a PIL image, rather than a numpy array
+         if not isinstance(img_original,np.ndarray):
+             img_original_pil = img_original
+             img_original = np.asarray(img_original)
+ 
+             # PIL images are RGB already
+             # img_original = img_original[:, :, ::-1]
+ 
+         # Save the original shape for scaling boxes later
+         scaling_shape = img_original.shape
+ 
+         # If the caller is requesting a specific target size...
+         if image_size is not None:
+ 
+             assert isinstance(image_size,int)
+ 
+             if not self.printed_image_size_warning:
+                 print('Using user-supplied image size {}'.format(image_size))
+                 self.printed_image_size_warning = True
+ 
+         # Otherwise resize to self.default_image_size
+         else:
+ 
+             image_size = self.default_image_size
+             self.printed_image_size_warning = False
+ 
+         # ...if the caller has specified an image size
+ 
+         # In "classic mode", we only do the letterboxing resize, we don't do an
+         # additional initial resizing operation
+         if 'classic' in self.compatibility_mode:
+ 
+             resize_ratio = 1.0
+ 
+         # Resize the image so the long side matches the target image size. This is not
+         # letterboxing (i.e., padding) yet, just resizing.
+         else:
+ 
+             use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
+ 
+             h,w = img_original.shape[:2]
+             resize_ratio = image_size / max(h,w)
+ 
+             # Only resize if we have to
+             if resize_ratio != 1:
+ 
+                 # Match what yolov5 does: use linear interpolation for upsizing;
+                 # area interpolation for downsizing
+                 if resize_ratio > 1:
+                     interpolation_method = cv2.INTER_LINEAR
+                 else:
+                     interpolation_method = cv2.INTER_AREA
+ 
+                 if use_ceil_for_resize:
+                     target_w = math.ceil(w * resize_ratio)
+                     target_h = math.ceil(h * resize_ratio)
+                 else:
+                     target_w = int(w * resize_ratio)
+                     target_h = int(h * resize_ratio)
+ 
+                 img_original = cv2.resize(
+                     img_original, (target_w, target_h),
+                     interpolation=interpolation_method)
+ 
+         if 'classic' in self.compatibility_mode:
+ 
+             letterbox_auto = True
+             letterbox_scaleup = True
+             target_shape = image_size
+ 
+         else:
+ 
+             letterbox_auto = False
+             letterbox_scaleup = False
+ 
+             # The padding to apply as a fraction of the stride size
+             pad = 0.5
+ 
+             # Resize to a multiple of the model stride
+             #
+             # This is how we would determine the stride if we knew the model had been loaded:
+             #
+             # model_stride = int(self.model.stride.max())
+             #
+             # ...but because we do this on preprocessing workers now, we try to avoid loading the model
+             # just for preprocessing, and we assume the stride was determined at the time the PTDetector
+             # object was created.
+             try:
+                 model_stride = int(self.model.stride.max())
+                 if model_stride != self.letterbox_stride:
+                     print('*** Warning: model stride is {}, stride at construction time was {} ***'.format(
+                         model_stride,self.letterbox_stride))
+             except Exception:
+                 pass
+ 
+             model_stride = self.letterbox_stride
+             max_dimension = max(img_original.shape)
+             normalized_shape = [img_original.shape[0] / max_dimension,
+                                 img_original.shape[1] / max_dimension]
+             target_shape = np.ceil(((np.array(normalized_shape) * image_size) / model_stride) + \
+                 pad).astype(int) * model_stride
+ 
+         # Now we letterbox, which is just padding, since we've already resized
+         img,letterbox_ratio,letterbox_pad = letterbox(img_original,
+                                                       new_shape=target_shape,
+                                                       stride=self.letterbox_stride,
+                                                       auto=letterbox_auto,
+                                                       scaleFill=False,
+                                                       scaleup=letterbox_scaleup)
+ 
+         result['img_processed'] = img
+         result['img_original'] = img_original
+         result['img_original_pil'] = img_original_pil
+         result['target_shape'] = target_shape
+         result['scaling_shape'] = scaling_shape
+         result['letterbox_ratio'] = letterbox_ratio
+         result['letterbox_pad'] = letterbox_pad
+         return result
+ 
+     # ...def preprocess_image(...)
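A minimal preprocessing sketch (image path hypothetical), assuming load_image() from megadetector.visualization.visualization_utils, which handles the EXIF rotation that preprocess_image() expects to have already happened:

    from megadetector.visualization import visualization_utils as vis_utils

    image = vis_utils.load_image('camera/img0001.jpg')
    info = detector.preprocess_image(image, image_id='camera/img0001.jpg')
    print(info['img_processed'].shape, info['letterbox_pad'])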
+ 
+ 
+     def generate_detections_one_batch(self,
+                                       img_original,
+                                       image_id=None,
+                                       detection_threshold=0.00001,
+                                       image_size=None,
+                                       augment=False,
+                                       verbose=False):
+         """
+         Run the detector on a batch of images.
+ 
+         Args:
+             img_original (list): list of images (PIL Image or np.array) on which we should run
+                 the detector, with EXIF rotation already handled, or a list of dicts representing
+                 preprocessed images with associated letterbox parameters
+             image_id (list or None): list of paths to identify the images; will be in the "file"
+                 field of the output objects. Ignored when img_original contains preprocessed dicts.
+             detection_threshold (float, optional): only detections above this confidence threshold
+                 will be included in the return value
+             image_size (int, optional): image size (long side) to use for inference, or None to
+                 use the default size specified at the time the model was loaded
+             augment (bool, optional): enable (implementation-specific) image augmentation
+             verbose (bool, optional): enable additional debug output
+ 
+         Returns:
+             list: a list of dictionaries, each with the following fields:
+                 - 'file' (filename, always present)
+                 - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+                 - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+                 - 'failure' (a failure string, or None if everything went fine)
+         """
+ 
+         # Validate inputs
+         if not isinstance(img_original, list):
+             raise ValueError('img_original must be a list for batch processing')
+ 
+         if len(img_original) == 0:
+             return []
+ 
+         # Check input consistency
+         if isinstance(img_original[0], dict):
+             # All items in img_original should be preprocessed dicts
+             for i, img in enumerate(img_original):
+                 if not isinstance(img, dict):
+                     raise ValueError(f'Mixed input types in batch: item {i} is not a dict, but item 0 is a dict')
+         else:
+             # All items in img_original should be PIL/numpy images, and image_id should be a list of strings
+             if image_id is None:
+                 raise ValueError('image_id must be a list when img_original contains PIL/numpy images')
+             if not isinstance(image_id, list):
+                 raise ValueError('image_id must be a list for batch processing')
+             if len(image_id) != len(img_original):
+                 raise ValueError(
+                     'Length mismatch: img_original has {} items, image_id has {} items'.format(
+                         len(img_original),len(image_id)))
+             for i_img, img in enumerate(img_original):
+                 if isinstance(img, dict):
+                     raise ValueError(
+                         'Mixed input types in batch: item {} is a dict, but item 0 is not a dict'.format(
+                             i_img))
+ 
+         if detection_threshold is None:
+             detection_threshold = 0.0
+ 
+         batch_size = len(img_original)
+         results = [None] * batch_size
+ 
+         # Preprocess all images, handling failures
+         preprocessed_images = []
+         preprocessing_failed_indices = set()
+ 
+         for i_img, img in enumerate(img_original):
+ 
+             try:
+                 if isinstance(img, dict):
+                     # Already preprocessed
+                     image_info = img
+                     current_image_id = image_info['file']
+                 else:
+                     # Need to preprocess
+                     current_image_id = image_id[i_img]
+                     image_info = self.preprocess_image(
+                         img_original=img,
+                         image_id=current_image_id,
+                         image_size=image_size,
+                         verbose=verbose)
+ 
+                 preprocessed_images.append((i_img, image_info, current_image_id))
+ 
+             except Exception as e:
+                 print('Warning: preprocessing failed for image {}: {}'.format(
+                     image_id[i_img] if image_id else f'index_{i_img}', str(e)))
+ 
+                 preprocessing_failed_indices.add(i_img)
+                 current_image_id = image_id[i_img] if image_id else f'index_{i_img}'
+                 results[i_img] = {
+                     'file': current_image_id,
+                     'detections': None,
+                     'failure': FAILURE_IMAGE_OPEN
+                 }
+ 
+         # ...for each image in this batch
+ 
+         # Group preprocessed images by actual processed image shape for batching
+         shape_groups = {}
+         for original_idx, image_info, current_image_id in preprocessed_images:
+             # Use the actual processed image shape for grouping, not target_shape
+             actual_shape = tuple(image_info['img_processed'].shape)
+             if actual_shape not in shape_groups:
+                 shape_groups[actual_shape] = []
+             shape_groups[actual_shape].append((original_idx, image_info, current_image_id))
+ 
+         # Process each shape group as a batch
+         for target_shape, group_items in shape_groups.items():
+ 
+             try:
+                 self._process_batch_group(group_items, results, detection_threshold, augment, verbose)
+             except Exception as e:
+                 # If inference fails for the entire batch, mark all images in this batch as failed
+                 print('Warning: batch inference failed for shape {}: {}'.format(target_shape, str(e)))
+ 
+                 for original_idx, image_info, current_image_id in group_items:
+                     results[original_idx] = {
+                         'file': current_image_id,
+                         'detections': None,
+                         'failure': FAILURE_INFER
+                     }
+ 
+         # ...for each shape group
+ 
+         return results
+ 
+     # ...def generate_detections_one_batch(...)
1231
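+
+    # Usage sketch (illustrative only, not part of the shipped module): given a
+    # PTDetector instance [detector], the batch entry point can be called on a
+    # list of PIL images (filenames hypothetical):
+    #
+    #   from PIL import Image
+    #   filenames = ['a.jpg', 'b.jpg']
+    #   images = [Image.open(fn) for fn in filenames]
+    #   results = detector.generate_detections_one_batch(
+    #       img_original=images,
+    #       image_id=filenames,
+    #       detection_threshold=0.2)
+    #   assert len(results) == len(images)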
+
+
+    def _process_batch_group(self, group_items, results, detection_threshold, augment, verbose):
+        """
+        Process a group of images with the same target shape as a single batch.
+
+        Args:
+            group_items (list): List of (original_idx, image_info, current_image_id) tuples
+            results (list): Results list to populate (modified in place)
+            detection_threshold (float): Detection confidence threshold
+            augment (bool): Enable augmentation
+            verbose (bool): Enable verbose output
+
+        Returns:
+            None; for each item in group_items, a result dict with fields 'file',
+            'detections', and 'max_detection_conf' is written to the corresponding
+            index of [results].
+        """
+
+        if len(group_items) == 0:
+            return
+
+        # Extract batch data
+        batch_images = []
+        batch_metadata = []
+
+        # For each image in this batch...
+        for original_idx, image_info, current_image_id in group_items:
+
+            img = image_info['img_processed']
+
+            # Convert HWC to CHW and prepare tensor
+            img_tensor = img.transpose((2, 0, 1))
+            img_tensor = np.ascontiguousarray(img_tensor)
+            img_tensor = torch.from_numpy(img_tensor)
+            batch_images.append(img_tensor)
+
+            metadata = {
+                'original_idx': original_idx,
+                'current_image_id': current_image_id,
+                'scaling_shape': image_info['scaling_shape'],
+                'letterbox_pad': image_info['letterbox_pad'],
+                'img_original': image_info['img_original']
+            }
+            batch_metadata.append(metadata)
+
+        # ...for each image in this batch
+
+
+        # Stack images into a batch tensor
+        batch_tensor = torch.stack(batch_images)
+
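+        # At this point, N same-shape HWC images of shape (H, W, 3) have been
+        # stacked into a single (N, 3, H, W) tensor; e.g. (shapes hypothetical)
+        # four 1280x1280 inputs become a (4, 3, 1280, 1280) batch.
+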
+        batch_tensor = batch_tensor.float()
+        batch_tensor /= 255.0
+
+        batch_tensor = batch_tensor.to(self.device)
+        if self.half_precision:
+            batch_tensor = batch_tensor.half()
+
+        # Run the model on the batch
+        pred = self.model(batch_tensor, augment=augment)[0]
+
+        # Configure NMS parameters
+        if 'classic' in self.compatibility_mode:
+            nms_iou_thres = 0.45
+        else:
+            nms_iou_thres = 0.6
+
+        use_library_nms = False
+
+        # Model output format changed in recent ultralytics packages, and the nms implementation
+        # in this module hasn't been updated to handle that format yet.
+        if (yolo_model_type_imported is not None) and (yolo_model_type_imported == 'ultralytics'):
+            use_library_nms = True
+
+        if use_library_nms:
+            pred = non_max_suppression(prediction=pred,
+                                       conf_thres=detection_threshold,
+                                       iou_thres=nms_iou_thres,
+                                       agnostic=False,
+                                       multi_label=False)
+        else:
+            pred = nms(prediction=pred,
+                       conf_thres=detection_threshold,
+                       iou_thres=nms_iou_thres)
+
+        assert isinstance(pred, list)
+        assert len(pred) == len(batch_metadata), \
+            'Mismatch between prediction length {} and batch size {}'.format(
+                len(pred), len(batch_metadata))
+
+        # Process each image's detections
+        for i_image, det in enumerate(pred):
+
+            metadata = batch_metadata[i_image]
+            original_idx = metadata['original_idx']
+            current_image_id = metadata['current_image_id']
+            scaling_shape = metadata['scaling_shape']
+            letterbox_pad = metadata['letterbox_pad']
+            img_original = metadata['img_original']
+
+            detections = []
+            max_conf = 0.0
+
+            if len(det) > 0:
+
+                # Prepare scaling parameters
+                gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
+
+                if 'classic' in self.compatibility_mode:
+                    ratio = None
+                    ratio_pad = None
+                else:
+                    ratio = (img_original.shape[0]/scaling_shape[0],
+                             img_original.shape[1]/scaling_shape[1])
+                    ratio_pad = (ratio, letterbox_pad)
+
+                # Rescale boxes
+                if 'classic' in self.compatibility_mode:
+                    det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], img_original.shape).round()
+                else:
+                    det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()
+
+                # Process each detection
+                for *xyxy, conf, cls in reversed(det):
+                    if conf < detection_threshold:
+                        continue
+
+                    # Convert to YOLO format then to MD format
+                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
+                    api_box = ct_utils.convert_yolo_to_xywh(xywh)
+
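+                    # Worked example (numbers hypothetical): for a 1000x1000 image,
+                    # a pixel-space xyxy box [400, 450, 600, 650] becomes YOLO-format
+                    # [0.5, 0.55, 0.2, 0.2] (normalized center x/y, width, height),
+                    # which convert_yolo_to_xywh() turns into the MD-format box
+                    # [0.4, 0.45, 0.2, 0.2] (normalized x_min, y_min, width, height).
+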
+                    if 'classic' in self.compatibility_mode:
+                        api_box = ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
+                        conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
+                    else:
+                        api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
+                        conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
+
+                    if not self.use_model_native_classes:
+                        cls = int(cls.tolist()) + 1
+                        if cls not in (1, 2, 3):
+                            raise KeyError(f'{cls} is not a valid class.')
+                    else:
+                        cls = int(cls.tolist())
+
+                    detections.append({
+                        'category': str(cls),
+                        'conf': conf,
+                        'bbox': api_box
+                    })
+                    max_conf = max(max_conf, conf)
+
+                # ...for each detection
+
+            # ...if there are > 0 detections
+
+            # Store result for this image
+            results[original_idx] = {
+                'file': current_image_id,
+                'detections': detections,
+                'max_detection_conf': max_conf
+            }
+
+        # ...for each image
+
+    # ...def _process_batch_group(...)
+
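+    # A populated entry in [results] looks like the following (values
+    # hypothetical); per the MegaDetector convention, 'bbox' is in normalized
+    # [x_min, y_min, width, height] coordinates:
+    #
+    #   {'file': 'a.jpg',
+    #    'detections': [{'category': '1', 'conf': 0.97, 'bbox': [0.4, 0.45, 0.2, 0.2]}],
+    #    'max_detection_conf': 0.97}
+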
+    def generate_detections_one_image(self,
+                                      img_original,
+                                      image_id='unknown',
+                                      detection_threshold=0.00001,
+                                      image_size=None,
+                                      augment=False,
+                                      verbose=False):
+        """
+        Run the detector on a single image (wrapper around the batch function).
+
+        Args:
+            img_original (Image, np.array, or dict): the image on which we should run the detector, with
+                EXIF rotation already handled, or a dict representing a preprocessed image with associated
+                letterbox parameters
+            image_id (str, optional): a path to identify the image; will be in the "file" field
+                of the output object
+            detection_threshold (float, optional): only detections above this confidence threshold
+                will be included in the return value
+            image_size (int, optional): image size (long side) to use for inference, or None to
+                use the default size specified at the time the model was loaded
+            augment (bool, optional): enable (implementation-specific) image augmentation
+            verbose (bool, optional): enable additional debug output
+
+        Returns:
+            dict: a dictionary with the following fields:
+                - 'file' (filename, always present)
+                - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
+                - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
+                - 'failure' (a failure string, or None if everything went fine)
+        """
+
+        # Prepare batch inputs
+        if isinstance(img_original, dict):
+            batch_results = self.generate_detections_one_batch(
+                img_original=[img_original],
+                image_id=None,
+                detection_threshold=detection_threshold,
+                image_size=image_size,
+                augment=augment,
+                verbose=verbose)
+        else:
+            batch_results = self.generate_detections_one_batch(
+                img_original=[img_original],
+                image_id=[image_id],
+                detection_threshold=detection_threshold,
+                image_size=image_size,
+                augment=augment,
+                verbose=verbose)
+
+        # Return the single result
+        return batch_results[0]
+
+    # ...def generate_detections_one_image(...)
+
+# ...class PTDetector
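+
+# Usage sketch (illustrative only, not part of the shipped module): assuming
+# [detector] is a PTDetector instance and [image] is a PIL image with EXIF
+# rotation already handled, single-image inference looks like:
+#
+#   result = detector.generate_detections_one_image(
+#       image,
+#       image_id='camera01/IMG_0001.JPG',
+#       detection_threshold=0.2)
+#   if result.get('failure') is None:
+#       for det in result['detections']:
+#           print(det['category'], det['conf'], det['bbox'])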