megadetector 5.0.27-py3-none-any.whl → 5.0.29-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic.

Files changed (176)
  1. megadetector/api/batch_processing/api_core/batch_service/score.py +4 -5
  2. megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +1 -1
  3. megadetector/api/batch_processing/api_support/summarize_daily_activity.py +1 -1
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +2 -2
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +1 -1
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +1 -1
  7. megadetector/api/synchronous/api_core/tests/load_test.py +2 -3
  8. megadetector/classification/aggregate_classifier_probs.py +3 -3
  9. megadetector/classification/analyze_failed_images.py +5 -5
  10. megadetector/classification/cache_batchapi_outputs.py +5 -5
  11. megadetector/classification/create_classification_dataset.py +11 -12
  12. megadetector/classification/crop_detections.py +10 -10
  13. megadetector/classification/csv_to_json.py +8 -8
  14. megadetector/classification/detect_and_crop.py +13 -15
  15. megadetector/classification/evaluate_model.py +7 -7
  16. megadetector/classification/identify_mislabeled_candidates.py +6 -6
  17. megadetector/classification/json_to_azcopy_list.py +1 -1
  18. megadetector/classification/json_validator.py +29 -32
  19. megadetector/classification/map_classification_categories.py +9 -9
  20. megadetector/classification/merge_classification_detection_output.py +12 -9
  21. megadetector/classification/prepare_classification_script.py +19 -19
  22. megadetector/classification/prepare_classification_script_mc.py +23 -23
  23. megadetector/classification/run_classifier.py +4 -4
  24. megadetector/classification/save_mislabeled.py +6 -6
  25. megadetector/classification/train_classifier.py +1 -1
  26. megadetector/classification/train_classifier_tf.py +9 -9
  27. megadetector/classification/train_utils.py +10 -10
  28. megadetector/data_management/annotations/annotation_constants.py +1 -1
  29. megadetector/data_management/camtrap_dp_to_coco.py +45 -45
  30. megadetector/data_management/cct_json_utils.py +101 -101
  31. megadetector/data_management/cct_to_md.py +49 -49
  32. megadetector/data_management/cct_to_wi.py +33 -33
  33. megadetector/data_management/coco_to_labelme.py +75 -75
  34. megadetector/data_management/coco_to_yolo.py +189 -189
  35. megadetector/data_management/databases/add_width_and_height_to_db.py +3 -2
  36. megadetector/data_management/databases/combine_coco_camera_traps_files.py +38 -38
  37. megadetector/data_management/databases/integrity_check_json_db.py +202 -188
  38. megadetector/data_management/databases/subset_json_db.py +33 -33
  39. megadetector/data_management/generate_crops_from_cct.py +38 -38
  40. megadetector/data_management/get_image_sizes.py +54 -49
  41. megadetector/data_management/labelme_to_coco.py +130 -124
  42. megadetector/data_management/labelme_to_yolo.py +78 -72
  43. megadetector/data_management/lila/create_lila_blank_set.py +81 -83
  44. megadetector/data_management/lila/create_lila_test_set.py +32 -31
  45. megadetector/data_management/lila/create_links_to_md_results_files.py +18 -18
  46. megadetector/data_management/lila/download_lila_subset.py +21 -24
  47. megadetector/data_management/lila/generate_lila_per_image_labels.py +91 -91
  48. megadetector/data_management/lila/get_lila_annotation_counts.py +30 -30
  49. megadetector/data_management/lila/get_lila_image_counts.py +22 -22
  50. megadetector/data_management/lila/lila_common.py +70 -70
  51. megadetector/data_management/lila/test_lila_metadata_urls.py +13 -14
  52. megadetector/data_management/mewc_to_md.py +339 -340
  53. megadetector/data_management/ocr_tools.py +258 -252
  54. megadetector/data_management/read_exif.py +232 -223
  55. megadetector/data_management/remap_coco_categories.py +26 -26
  56. megadetector/data_management/remove_exif.py +31 -20
  57. megadetector/data_management/rename_images.py +187 -187
  58. megadetector/data_management/resize_coco_dataset.py +41 -41
  59. megadetector/data_management/speciesnet_to_md.py +41 -41
  60. megadetector/data_management/wi_download_csv_to_coco.py +55 -55
  61. megadetector/data_management/yolo_output_to_md_output.py +117 -120
  62. megadetector/data_management/yolo_to_coco.py +195 -188
  63. megadetector/detection/change_detection.py +831 -0
  64. megadetector/detection/process_video.py +341 -338
  65. megadetector/detection/pytorch_detector.py +308 -266
  66. megadetector/detection/run_detector.py +186 -166
  67. megadetector/detection/run_detector_batch.py +366 -364
  68. megadetector/detection/run_inference_with_yolov5_val.py +328 -325
  69. megadetector/detection/run_tiled_inference.py +312 -253
  70. megadetector/detection/tf_detector.py +24 -24
  71. megadetector/detection/video_utils.py +291 -283
  72. megadetector/postprocessing/add_max_conf.py +15 -11
  73. megadetector/postprocessing/categorize_detections_by_size.py +44 -44
  74. megadetector/postprocessing/classification_postprocessing.py +808 -311
  75. megadetector/postprocessing/combine_batch_outputs.py +20 -21
  76. megadetector/postprocessing/compare_batch_results.py +528 -517
  77. megadetector/postprocessing/convert_output_format.py +97 -97
  78. megadetector/postprocessing/create_crop_folder.py +220 -147
  79. megadetector/postprocessing/detector_calibration.py +173 -168
  80. megadetector/postprocessing/generate_csv_report.py +508 -0
  81. megadetector/postprocessing/load_api_results.py +25 -22
  82. megadetector/postprocessing/md_to_coco.py +129 -98
  83. megadetector/postprocessing/md_to_labelme.py +89 -83
  84. megadetector/postprocessing/md_to_wi.py +40 -40
  85. megadetector/postprocessing/merge_detections.py +87 -114
  86. megadetector/postprocessing/postprocess_batch_results.py +319 -302
  87. megadetector/postprocessing/remap_detection_categories.py +36 -36
  88. megadetector/postprocessing/render_detection_confusion_matrix.py +205 -199
  89. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +57 -57
  90. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +27 -28
  91. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +702 -677
  92. megadetector/postprocessing/separate_detections_into_folders.py +226 -211
  93. megadetector/postprocessing/subset_json_detector_output.py +265 -262
  94. megadetector/postprocessing/top_folders_to_bottom.py +45 -45
  95. megadetector/postprocessing/validate_batch_results.py +70 -70
  96. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +52 -52
  97. megadetector/taxonomy_mapping/map_new_lila_datasets.py +15 -15
  98. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +14 -14
  99. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +66 -69
  100. megadetector/taxonomy_mapping/retrieve_sample_image.py +16 -16
  101. megadetector/taxonomy_mapping/simple_image_download.py +8 -8
  102. megadetector/taxonomy_mapping/species_lookup.py +33 -33
  103. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +14 -14
  104. megadetector/taxonomy_mapping/taxonomy_graph.py +11 -11
  105. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +13 -13
  106. megadetector/utils/azure_utils.py +22 -22
  107. megadetector/utils/ct_utils.py +1019 -200
  108. megadetector/utils/directory_listing.py +21 -77
  109. megadetector/utils/gpu_test.py +22 -22
  110. megadetector/utils/md_tests.py +541 -518
  111. megadetector/utils/path_utils.py +1511 -406
  112. megadetector/utils/process_utils.py +41 -41
  113. megadetector/utils/sas_blob_utils.py +53 -49
  114. megadetector/utils/split_locations_into_train_val.py +73 -60
  115. megadetector/utils/string_utils.py +147 -26
  116. megadetector/utils/url_utils.py +463 -173
  117. megadetector/utils/wi_utils.py +2629 -2868
  118. megadetector/utils/write_html_image_list.py +137 -137
  119. megadetector/visualization/plot_utils.py +21 -21
  120. megadetector/visualization/render_images_with_thumbnails.py +37 -73
  121. megadetector/visualization/visualization_utils.py +424 -404
  122. megadetector/visualization/visualize_db.py +197 -190
  123. megadetector/visualization/visualize_detector_output.py +126 -98
  124. {megadetector-5.0.27.dist-info → megadetector-5.0.29.dist-info}/METADATA +6 -3
  125. megadetector-5.0.29.dist-info/RECORD +163 -0
  126. {megadetector-5.0.27.dist-info → megadetector-5.0.29.dist-info}/WHEEL +1 -1
  127. megadetector/data_management/importers/add_nacti_sizes.py +0 -52
  128. megadetector/data_management/importers/add_timestamps_to_icct.py +0 -79
  129. megadetector/data_management/importers/animl_results_to_md_results.py +0 -158
  130. megadetector/data_management/importers/auckland_doc_test_to_json.py +0 -373
  131. megadetector/data_management/importers/auckland_doc_to_json.py +0 -201
  132. megadetector/data_management/importers/awc_to_json.py +0 -191
  133. megadetector/data_management/importers/bellevue_to_json.py +0 -272
  134. megadetector/data_management/importers/cacophony-thermal-importer.py +0 -793
  135. megadetector/data_management/importers/carrizo_shrubfree_2018.py +0 -269
  136. megadetector/data_management/importers/carrizo_trail_cam_2017.py +0 -289
  137. megadetector/data_management/importers/cct_field_adjustments.py +0 -58
  138. megadetector/data_management/importers/channel_islands_to_cct.py +0 -913
  139. megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
  140. megadetector/data_management/importers/eMammal/eMammal_helpers.py +0 -249
  141. megadetector/data_management/importers/eMammal/make_eMammal_json.py +0 -223
  142. megadetector/data_management/importers/ena24_to_json.py +0 -276
  143. megadetector/data_management/importers/filenames_to_json.py +0 -386
  144. megadetector/data_management/importers/helena_to_cct.py +0 -283
  145. megadetector/data_management/importers/idaho-camera-traps.py +0 -1407
  146. megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
  147. megadetector/data_management/importers/import_desert_lion_conservation_camera_traps.py +0 -387
  148. megadetector/data_management/importers/jb_csv_to_json.py +0 -150
  149. megadetector/data_management/importers/mcgill_to_json.py +0 -250
  150. megadetector/data_management/importers/missouri_to_json.py +0 -490
  151. megadetector/data_management/importers/nacti_fieldname_adjustments.py +0 -79
  152. megadetector/data_management/importers/noaa_seals_2019.py +0 -181
  153. megadetector/data_management/importers/osu-small-animals-to-json.py +0 -364
  154. megadetector/data_management/importers/pc_to_json.py +0 -365
  155. megadetector/data_management/importers/plot_wni_giraffes.py +0 -123
  156. megadetector/data_management/importers/prepare_zsl_imerit.py +0 -131
  157. megadetector/data_management/importers/raic_csv_to_md_results.py +0 -416
  158. megadetector/data_management/importers/rspb_to_json.py +0 -356
  159. megadetector/data_management/importers/save_the_elephants_survey_A.py +0 -320
  160. megadetector/data_management/importers/save_the_elephants_survey_B.py +0 -329
  161. megadetector/data_management/importers/snapshot_safari_importer.py +0 -758
  162. megadetector/data_management/importers/snapshot_serengeti_lila.py +0 -1067
  163. megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
  164. megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
  165. megadetector/data_management/importers/sulross_get_exif.py +0 -65
  166. megadetector/data_management/importers/timelapse_csv_set_to_json.py +0 -490
  167. megadetector/data_management/importers/ubc_to_json.py +0 -399
  168. megadetector/data_management/importers/umn_to_json.py +0 -507
  169. megadetector/data_management/importers/wellington_to_json.py +0 -263
  170. megadetector/data_management/importers/wi_to_json.py +0 -442
  171. megadetector/data_management/importers/zamba_results_to_md_results.py +0 -180
  172. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +0 -101
  173. megadetector/data_management/lila/add_locations_to_nacti.py +0 -151
  174. megadetector-5.0.27.dist-info/RECORD +0 -208
  175. {megadetector-5.0.27.dist-info → megadetector-5.0.29.dist-info}/licenses/LICENSE +0 -0
  176. {megadetector-5.0.27.dist-info → megadetector-5.0.29.dist-info}/top_level.txt +0 -0
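Before reading the per-file diff below, it can help to confirm which of the two wheels is actually installed in the current environment. A minimal, standard-library-only sketch (not part of the diff):

# Report the installed megadetector version, so you know which side of the
# 5.0.27 -> 5.0.29 comparison your environment corresponds to.
from importlib.metadata import PackageNotFoundError, version

try:
    print('Installed megadetector version:', version('megadetector'))
except PackageNotFoundError:
    print('megadetector is not installed in this environment')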
@@ -17,6 +17,7 @@ import shutil
  import traceback
  import uuid
  import json
+ import inspect

  import cv2
  import torch
@@ -27,6 +28,7 @@ from megadetector.detection.run_detector import \
  get_detector_version_from_model_file, \
  known_models
  from megadetector.utils.ct_utils import parse_bool_string
+ from megadetector.utils.ct_utils import to_bool
  from megadetector.utils import ct_utils

  # We support a few ways of accessing the YOLOv5 dependencies:
@@ -54,7 +56,7 @@ def _get_model_type_for_model(model_file,
  verbose=False):
  """
  Determine the model type (i.e., the inference library we need to use) for a .pt file.
-
+
  Args:
  model_file (str): the model file to read
  prefer_model_type_source (str, optional): how should we handle the (very unlikely)
@@ -64,28 +66,28 @@ def _get_model_type_for_model(model_file,
  default_model_type (str, optional): return value for the case where we can't find
  appropriate metadata in the file or in the global table.
  verbose (bool, optional): enable additional debug output
-
+
  Returns:
  str: the model type indicated for this model
  """
-
+
  model_info = read_metadata_from_megadetector_model_file(model_file)
-
+
  # Check whether the model file itself specified a model type
  model_type_from_model_file_metadata = None
-
- if model_info is not None and 'model_type' in model_info:
+
+ if model_info is not None and 'model_type' in model_info:
  model_type_from_model_file_metadata = model_info['model_type']
  if verbose:
  print('Parsed model type {} from model {}'.format(
  model_type_from_model_file_metadata,
  model_file))
-
+
  model_type_from_model_version = None
-
+
  # Check whether this is a known model version with a specific model type
  model_version_from_file = get_detector_version_from_model_file(model_file)
-
+
  if model_version_from_file is not None and model_version_from_file in known_models:
  model_info = known_models[model_version_from_file]
  if 'model_type' in model_info:
@@ -93,35 +95,35 @@ def _get_model_type_for_model(model_file,
  if verbose:
  print('Parsed model type {} from global metadata'.format(model_type_from_model_version))
  else:
- model_type_from_model_version = None
-
+ model_type_from_model_version = None
+
  if model_type_from_model_file_metadata is None and \
  model_type_from_model_version is None:
  if verbose:
  print('Could not determine model type for {}, assuming {}'.format(
  model_file,default_model_type))
  model_type = default_model_type
-
+
  elif model_type_from_model_file_metadata is not None and \
  model_type_from_model_version is not None:
  if model_type_from_model_version == model_type_from_model_file_metadata:
  model_type = model_type_from_model_file_metadata
  else:
- print('Waring: model type from model version is {}, from file metadata is {}'.format(
+ print('Warning: model type from model version is {}, from file metadata is {}'.format(
  model_type_from_model_version,model_type_from_model_file_metadata))
  if prefer_model_type_source == 'table':
  model_type = model_type_from_model_file_metadata
  else:
  model_type = model_type_from_model_version
-
+
  elif model_type_from_model_file_metadata is not None:
-
+
  model_type = model_type_from_model_file_metadata
-
+
  elif model_type_from_model_version is not None:
-
+
  model_type = model_type_from_model_version
-
+
  return model_type

  # ...def _get_model_type_for_model(...)
@@ -134,7 +136,7 @@ def _initialize_yolo_imports_for_model(model_file,
  verbose=False):
  """
  Initialize the appropriate YOLO imports for a model file.
-
+
  Args:
  model_file (str): The model file for which we're loading support
  prefer_model_type_source (str, optional): how should we handle the (very unlikely)
@@ -142,18 +144,17 @@ def _initialize_yolo_imports_for_model(model_file,
  type table says something else. Should be "table" (trust the table) or "file"
  (trust the file).
  default_model_type (str, optional): return value for the case where we can't find
- appropriate metadata in the file or in the global table.
+ appropriate metadata in the file or in the global table.
  detector_options (dict, optional): dictionary of detector options that mean
  different things to different models
- verbose (bool, optional): enable additonal debug output
-
+ verbose (bool, optional): enable additional debug output
+
  Returns:
  str: the model type for which we initialized support
  """
-
-
+
  global yolo_model_type_imported
-
+
  if detector_options is not None and 'model_type' in detector_options:
  model_type = detector_options['model_type']
  print('Model type {} provided in detector options'.format(model_type))
@@ -161,7 +162,7 @@ def _initialize_yolo_imports_for_model(model_file,
  model_type = _get_model_type_for_model(model_file,
  prefer_model_type_source=prefer_model_type_source,
  default_model_type=default_model_type)
-
+
  if yolo_model_type_imported is not None:
  if model_type == yolo_model_type_imported:
  print('Bypassing imports for model type {}'.format(model_type))
@@ -169,27 +170,27 @@ def _initialize_yolo_imports_for_model(model_file,
  else:
  print('Previously set up imports for model type {}, re-importing as {}'.format(
  yolo_model_type_imported,model_type))
-
+
  _initialize_yolo_imports(model_type,verbose=verbose)
-
+
  return model_type


  def _clean_yolo_imports(verbose=False):
  """
  Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
- of another YOLO library version. The reason we jump through all these hoops, rather than
+ of another YOLO library version. The reason we jump through all these hoops, rather than
  just, e.g., handling different libraries in different modules, is that we need to make sure
  *pickle* sees the right version of modules during module loading, including modules we don't
- load directly (i.e., every module loaded within a YOLO library), and the only way I know to
+ load directly (i.e., every module loaded within a YOLO library), and the only way I know to
  do that is to remove all the "wrong" versions from sys.modules and sys.path.
-
+
  Args:
  verbose (bool, optional): enable additional debug output
  """
-
+
  modules_to_delete = []
- for module_name in sys.modules.keys():
+ for module_name in sys.modules.keys():
  module = sys.modules[module_name]
  try:
  module_file = module.__file__.replace('\\','/')
@@ -199,24 +200,24 @@ def _clean_yolo_imports(verbose=False):
  for token in tokens:
  if 'yolov5' in token or 'yolov9' in token or 'ultralytics' in token:
  modules_to_delete.append(module_name)
- break
+ break
  except Exception:
  pass
-
+
  for module_name in modules_to_delete:
  if module_name in sys.modules.keys():
  module_file = module.__file__.replace('\\','/')
  if verbose:
  print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
  del sys.modules[module_name]
-
+
  paths_to_delete = []
-
+
  for p in sys.path:
  if p.endswith('yolov5') or p.endswith('yolov9') or p.endswith('ultralytics'):
  print('clean_yolo_imports: removing {} from path'.format(p))
  paths_to_delete.append(p)
-
+
  for p in paths_to_delete:
  sys.path.remove(p)

@@ -226,54 +227,69 @@ def _clean_yolo_imports(verbose=False):
  def _initialize_yolo_imports(model_type='yolov5',
  allow_fallback_import=True,
  force_reimport=False,
- verbose=False):
+ verbose=True):
  """
- Imports required functions from one or more yolo libraries (yolov5, yolov9,
+ Imports required functions from one or more yolo libraries (yolov5, yolov9,
  ultralytics, targeting support for [model_type]).
-
+
  Args:
  model_type (str): The model type for which we're loading support
- allow_fallback_import (bool, optional): If we can't import from the package for
- which we're trying to load support, fall back to "import utils". This is
+ allow_fallback_import (bool, optional): If we can't import from the package for
+ which we're trying to load support, fall back to "import utils". This is
  typically used when the right support library is on the current PYTHONPATH.
- force_reimport (bool, optional): import the appropriate libraries even if the
+ force_reimport (bool, optional): import the appropriate libraries even if the
  requested model type matches the current initialization state
- verbose (bool, optional): include additonal debug output
-
+ verbose (bool, optional): include additional debug output
+
  Returns:
  str: the model type for which we initialized support
  """
-
+
+ # When running in pytest, the megadetector 'utils' module is put in the global
+ # namespace, which creates conflicts with yolov5; remove it from the global
+ # namespsace.
+ if ('PYTEST_CURRENT_TEST' in os.environ):
+ print('*** pytest detected ***')
+ if ('utils' in sys.modules):
+ utils_module = sys.modules['utils']
+ if hasattr(utils_module, '__file__') and 'megadetector' in str(utils_module.__file__):
+ print(f"Removing conflicting utils module: {utils_module.__file__}")
+ sys.modules.pop('utils', None)
+ # Also remove any submodules
+ to_remove = [name for name in sys.modules if name.startswith('utils.')]
+ for name in to_remove:
+ sys.modules.pop(name, None)
+
  global yolo_model_type_imported
-
+
  if model_type is None:
  model_type = 'yolov5'
-
+
  # The point of this function is to make the appropriate version
  # of the following functions available at module scope
  global non_max_suppression
  global xyxy2xywh
  global letterbox
  global scale_coords
-
+
  if yolo_model_type_imported is not None:
- if yolo_model_type_imported == model_type:
+ if (yolo_model_type_imported == model_type) and (not force_reimport):
  print('Bypassing imports for YOLO model type {}'.format(model_type))
  return
  else:
  _clean_yolo_imports()
-
+
  try_yolov5_import = (model_type == 'yolov5')
  try_yolov9_import = (model_type == 'yolov9')
  try_ultralytics_import = (model_type == 'ultralytics')
-
+
  utils_imported = False
-
+
  # First try importing from the yolov5 package; this is how the pip
  # package finds YOLOv5 utilities.
  if try_yolov5_import and not utils_imported:
-
- try:
+
+ try:
  from yolov5.utils.general import non_max_suppression, xyxy2xywh # noqa
  from yolov5.utils.augmentations import letterbox # noqa
  try:
@@ -283,109 +299,127 @@ def _initialize_yolo_imports(model_type='yolov5',
  utils_imported = True
  if verbose:
  print('Imported utils from YOLOv5 package')
-
- except Exception as e: # noqa
-
+
+ except Exception as e: # noqa
  # print('yolov5 module import failed: {}'.format(e))
- # print(traceback.format_exc())
+ # print(traceback.format_exc())
  pass
-
+
  # Next try importing from the yolov9 package
  if try_yolov9_import and not utils_imported:
-
+
  try:
-
+
  from yolov9.utils.general import non_max_suppression, xyxy2xywh # noqa
  from yolov9.utils.augmentations import letterbox # noqa
  from yolov9.utils.general import scale_boxes as scale_coords # noqa
  utils_imported = True
  if verbose:
  print('Imported utils from YOLOv9 package')
-
+
  except Exception as e: # noqa
-
+
  # print('yolov9 module import failed: {}'.format(e))
  # print(traceback.format_exc())
  pass
-
- # If we haven't succeeded yet, import from the ultralytics package
+
+ # If we haven't succeeded yet, import from the ultralytics package
  if try_ultralytics_import and not utils_imported:
-
+
  try:
-
+
  import ultralytics # noqa
-
+
  except Exception:
-
+
  print('It looks like you are trying to run a model that requires the ultralytics package, '
  'but the ultralytics package is not installed, but . For licensing reasons, this '
  'is not installed by default with the MegaDetector Python package. Run '
  '"pip install ultralytics" to install it, and try again.')
  raise
-
+
  try:
-
+
  from ultralytics.utils.ops import non_max_suppression # noqa
  from ultralytics.utils.ops import xyxy2xywh # noqa
-
+
  # In the ultralytics package, scale_boxes and scale_coords both exist;
  # we want scale_boxes.
- #
+ #
  # from ultralytics.utils.ops import scale_coords # noqa
  from ultralytics.utils.ops import scale_boxes as scale_coords # noqa
  from ultralytics.data.augment import LetterBox
-
- # letterbox() became a LetterBox class in the ultralytics package. Create a
+
+ # letterbox() became a LetterBox class in the ultralytics package. Create a
  # backwards-compatible letterbox function wrapper that wraps the class up.
- def letterbox(img,new_shape,auto=False,scaleFill=False,scaleup=True,center=True,stride=32): # noqa
-
- L = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,scaleup=scaleup,center=center,stride=stride)
- letterbox_result = L(image=img)
-
+ def letterbox(img,new_shape,auto=False,scaleFill=False, #noqa
+ scaleup=True,center=True,stride=32):
+
+ # Ultralytics changed the "scaleFill" parameter to "scale_fill", we want to support
+ # both conventions.
+ use_old_scalefill_arg = False
+ try:
+ sig = inspect.signature(LetterBox.__init__)
+ if 'scaleFill' in sig.parameters:
+ use_old_scalefill_arg = True
+ except Exception:
+ pass
+
+ if use_old_scalefill_arg:
+ if verbose:
+ print('Using old scaleFill calling convention')
+ letterbox_transformer = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,
+ scaleup=scaleup,center=center,stride=stride)
+ else:
+ letterbox_transformer = LetterBox(new_shape,auto=auto,scale_fill=scaleFill,
+ scaleup=scaleup,center=center,stride=stride)
+
+ letterbox_result = letterbox_transformer(image=img)
+
  if isinstance(new_shape,int):
  new_shape = [new_shape,new_shape]
-
+
  # The letterboxing is done, we just need to reverse-engineer what it did
  shape = img.shape[:2]
-
+
  r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  if not scaleup:
  r = min(r, 1.0)
  ratio = r, r
-
+
  new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
  if auto:
  dw, dh = np.mod(dw, stride), np.mod(dh, stride)
- elif scaleFill:
+ elif scaleFill:
  dw, dh = 0.0, 0.0
  new_unpad = (new_shape[1], new_shape[0])
  ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])
-
+
  dw /= 2
  dh /= 2
  pad = (dw,dh)
-
+
  return [letterbox_result,ratio,pad]
-
+
  utils_imported = True
  if verbose:
  print('Imported utils from ultralytics package')
-
+
  except Exception:
-
+
  # print('Ultralytics module import failed')
  pass
-
+
  # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
  if (not utils_imported) and allow_fallback_import:
-
+
  try:
-
+
  # import pre- and post-processing functions from the YOLOv5 repo
  from utils.general import non_max_suppression, xyxy2xywh # noqa
  from utils.augmentations import letterbox # noqa
-
+
  # scale_coords() is scale_boxes() in some YOLOv5 versions
  try:
  from utils.general import scale_coords # noqa
@@ -395,58 +429,58 @@ def _initialize_yolo_imports(model_type='yolov5',
  imported_file = sys.modules[scale_coords.__module__].__file__
  if verbose:
  print('Imported utils from {}'.format(imported_file))
-
+
  except ModuleNotFoundError as e:
-
+
  raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e)))
-
+
  assert utils_imported, 'YOLO utils import error'
-
+
  yolo_model_type_imported = model_type
  if verbose:
  print('Prepared YOLO imports for model type {}'.format(model_type))
-
+
  return model_type

  # ...def _initialize_yolo_imports(...)
-
+

  #%% Model metadata functions

- def add_metadata_to_megadetector_model_file(model_file_in,
+ def add_metadata_to_megadetector_model_file(model_file_in,
  model_file_out,
- metadata,
+ metadata,
  destination_path='megadetector_info.json'):
  """
- Adds a .json file to the specified MegaDetector model file containing metadata used
+ Adds a .json file to the specified MegaDetector model file containing metadata used
  by this module. Always over-writes the output file.
-
+
  Args:
  model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
  model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
  May be the same as model_file_in.
  metadata (dict): The metadata dict to add to the output model file
- destination_path (str, optional): The relative path within the main folder of the
- model archive where we should write the metadata. This is not relative to the root
+ destination_path (str, optional): The relative path within the main folder of the
+ model archive where we should write the metadata. This is not relative to the root
  of the archive, it's relative to the one and only folder at the root of the archive
- (this is a PyTorch convention).
+ (this is a PyTorch convention).
  """
-
+
  tmp_base = os.path.join(tempfile.gettempdir(),'md_metadata')
  os.makedirs(tmp_base,exist_ok=True)
  metadata_tmp_file_relative = 'megadetector_info_' + str(uuid.uuid1()) + '.json'
- metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
+ metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)

  with open(metadata_tmp_file_abs,'w') as f:
  json.dump(metadata,f,indent=1)
-
+
  # Copy the input file to the output file
  shutil.copyfile(model_file_in,model_file_out)

  # Write metadata to the output file
  with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:
-
- # Torch doesn't like anything in the root folder of the zipfile, so we put
+
+ # Torch doesn't like anything in the root folder of the zipfile, so we put
  # it in the one and only folder.
  names = zipf.namelist()
  root_folders = set()
@@ -456,9 +490,9 @@ def add_metadata_to_megadetector_model_file(model_file_in,
  assert len(root_folders) == 1,\
  'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
  root_folder = next(iter(root_folders))
-
- zipf.write(metadata_tmp_file_abs,
- root_folder + '/' + destination_path,
+
+ zipf.write(metadata_tmp_file_abs,
+ root_folder + '/' + destination_path,
  compresslevel=9,
  compress_type=zipfile.ZIP_DEFLATED)

@@ -466,7 +500,7 @@ def add_metadata_to_megadetector_model_file(model_file_in,
  os.remove(metadata_tmp_file_abs)
  except Exception as e:
  print('Warning: error deleting file {}: {}'.format(metadata_tmp_file_abs,str(e)))
-
+
  # ...def add_metadata_to_megadetector_model_file(...)


@@ -475,22 +509,22 @@ def read_metadata_from_megadetector_model_file(model_file,
  verbose=False):
  """
  Reads custom MegaDetector metadata from a modified MegaDetector model file.
-
+
  Args:
  model_file (str): The model filename to read, typically .pt (.zip is also sensible)
- relative_path (str, optional): The relative path within the main folder of the model
- archive from which we should read the metadata. This is not relative to the root
+ relative_path (str, optional): The relative path within the main folder of the model
+ archive from which we should read the metadata. This is not relative to the root
  of the archive, it's relative to the one and only folder at the root of the archive
- (this is a PyTorch convention).
-
+ (this is a PyTorch convention).
+
  Returns:
- object: Whatever we read from the metadata file, always a dict in practice. Returns
- None if we failed to read the specified metadata file.
+ object: whatever we read from the metadata file, always a dict in practice. Returns
+ None if we failed to read the specified metadata file.
  """
-
+
  with zipfile.ZipFile(model_file,'r') as zipf:
-
- # Torch doesn't like anything in the root folder of the zipfile, so we put
+
+ # Torch doesn't like anything in the root folder of the zipfile, so we put
  # it in the one and only folder.
  names = zipf.namelist()
  root_folders = set()
@@ -498,17 +532,19 @@ def read_metadata_from_megadetector_model_file(model_file,
  root_folder = name.split('/')[0]
  root_folders.add(root_folder)
  if len(root_folders) != 1:
- print('Warning: this archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?')
+ print('Warning: this archive does not have exactly one folder at the top level; ' + \
+ 'are you sure it\'s a Torch model file?')
  return None
  root_folder = next(iter(root_folders))
-
+
  metadata_file = root_folder + '/' + relative_path
  if metadata_file not in names:
  # This is the case for MDv5a and MDv5b
  if verbose:
- print('Warning: could not find metadata file {} in zip archive'.format(metadata_file))
+ print('Warning: could not find metadata file {} in zip archive {}'.format(
+ metadata_file,os.path.basename(model_file)))
  return None
-
+
  try:
  path = zipfile.Path(zipf,metadata_file)
  contents = path.read_text()
@@ -516,9 +552,9 @@ def read_metadata_from_megadetector_model_file(model_file,
  except Exception as e:
  print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
  return None
-
+
  return d
-
+
  # ...def read_metadata_from_megadetector_model_file(...)


@@ -526,27 +562,33 @@ def read_metadata_from_megadetector_model_file(model_file,

  default_compatibility_mode = 'classic'

- # This is a useful hack when I want to verify that my test driver (md_tests.py) is
- # correctly forcing a specific compabitility mode (I use "classic-test" in that case)
+ # This is a useful hack when I want to verify that my test driver (md_tests.py) is
+ # correctly forcing a specific compatibility mode (I use "classic-test" in that case)
  require_non_default_compatibility_mode = False

  class PTDetector:
-
+ """
+ Class that runs a PyTorch-based MegaDetector model.
+ """
+
  def __init__(self, model_path, detector_options=None, verbose=False):
+
+ if verbose:
+ print('Initializing PTDetector (verbose)')

  # Set up the import environment for this model, unloading previous
  # YOLO library versions if necessary.
  _initialize_yolo_imports_for_model(model_path,
  detector_options=detector_options,
  verbose=verbose)
-
+
  # Parse options specific to this detector family
  force_cpu = False
- use_model_native_classes = False
+ use_model_native_classes = False
  compatibility_mode = default_compatibility_mode
-
+
  if detector_options is not None:
-
+
  if 'force_cpu' in detector_options:
  force_cpu = parse_bool_string(detector_options['force_cpu'])
  if 'use_model_native_classes' in detector_options:
@@ -554,27 +596,27 @@ class PTDetector:
  if 'compatibility_mode' in detector_options:
  if detector_options['compatibility_mode'] is None:
  compatibility_mode = default_compatibility_mode
- else:
- compatibility_mode = detector_options['compatibility_mode']
-
+ else:
+ compatibility_mode = detector_options['compatibility_mode']
+
  if require_non_default_compatibility_mode:
-
+
  print('### DEBUG: requiring non-default compatibility mode ###')
  assert compatibility_mode != 'classic'
  assert compatibility_mode != 'default'
-
+
  preprocess_only = False
  if (detector_options is not None) and \
  ('preprocess_only' in detector_options) and \
  (detector_options['preprocess_only']):
  preprocess_only = True
-
+
  if verbose or (not preprocess_only):
  print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))
-
+
  model_metadata = read_metadata_from_megadetector_model_file(model_path)
-
- #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
+
+ #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
  #: aspect ratio".
  if model_metadata is not None and 'image_size' in model_metadata:
  self.default_image_size = model_metadata['image_size']
@@ -582,38 +624,38 @@ class PTDetector:
  print('Loaded image size {} from model metadata'.format(self.default_image_size))
  else:
  self.default_image_size = 1280
-
+
  #: Either a string ('cpu','cuda:0') or a torch.device()
  self.device = 'cpu'
-
+
  #: Have we already printed a warning about using a non-standard image size?
  #:
  #: :meta private:
  self.printed_image_size_warning = False
-
+
  #: If this is False, we assume the underlying model is producing class indices in the
  #: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
- #: MD classes (1,2,3) before generating output. If this is True, we use whatever
+ #: MD classes (1,2,3) before generating output. If this is True, we use whatever
  #: indices the model provides
- self.use_model_native_classes = use_model_native_classes
-
+ self.use_model_native_classes = use_model_native_classes
+
  #: This allows us to maintain backwards compatibility across a set of changes to the
- #: way this class does inference. Currently should start with either "default" or
+ #: way this class does inference. Currently should start with either "default" or
  #: "classic".
  self.compatibility_mode = compatibility_mode
-
+
  #: Stride size passed to YOLOv5's letterbox() function
  self.letterbox_stride = 32
-
+
  if 'classic' in self.compatibility_mode:
  self.letterbox_stride = 64
-
+
  #: Use half-precision inference... fixed by the model, generally don't mess with this
  self.half_precision = False
-
+
  if preprocess_only:
  return
-
+
  if not force_cpu:
  if torch.cuda.is_available():
  self.device = torch.device('cuda:0')
@@ -623,10 +665,10 @@ class PTDetector:
  except AttributeError:
  pass
  try:
- self.model = PTDetector._load_model(model_path,
- device=self.device,
+ self.model = PTDetector._load_model(model_path,
+ device=self.device,
  compatibility_mode=self.compatibility_mode)
-
+
  except Exception as e:
  # In a very esoteric scenario where an old version of YOLOv5 is used to run
  # newer models, we run into an issue because the "Model" class became
@@ -636,21 +678,21 @@ class PTDetector:
  print('Forward-compatibility issue detected, patching')
  from models import yolo
  yolo.DetectionModel = yolo.Model
- self.model = PTDetector._load_model(model_path,
+ self.model = PTDetector._load_model(model_path,
  device=self.device,
  compatibility_mode=self.compatibility_mode,
- verbose=verbose)
+ verbose=verbose)
  else:
  raise
  if (self.device != 'cpu'):
  if verbose:
  print('Sending model to GPU')
  self.model.to(self.device)
-
+

  @staticmethod
  def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
-
+
  if verbose:
  print(f'Using PyTorch version {torch.__version__}')

@@ -661,12 +703,12 @@ class PTDetector:
  # very slight changes to the output, which always make me nervous, so I'm not
  # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
  if 'classic' in compatibility_mode:
- use_map_location = (device != 'mps')
+ use_map_location = (device != 'mps')
  else:
  use_map_location = False
-
+
  if use_map_location:
- try:
+ try:
  checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
  # For a transitional period, we want to support torch 1.1x, where the weights_only
  # parameter doesn't exist
@@ -682,31 +724,31 @@ class PTDetector:
  # parameter doesn't exist
  except Exception as e:
  if "'weights_only' is an invalid keyword" in str(e):
- checkpoint = torch.load(model_pt_path)
+ checkpoint = torch.load(model_pt_path)
  else:
  raise
-
- # Compatibility fix that allows us to load older YOLOv5 models with
+
+ # Compatibility fix that allows us to load older YOLOv5 models with
  # newer versions of YOLOv5/PT
  for m in checkpoint['model'].modules():
  t = type(m)
  if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
  m.recompute_scale_factor = None
-
- if use_map_location:
- model = checkpoint['model'].float().fuse().eval()
+
+ if use_map_location:
+ model = checkpoint['model'].float().fuse().eval()
  else:
- model = checkpoint['model'].float().fuse().eval().to(device)
-
+ model = checkpoint['model'].float().fuse().eval().to(device)
+
  return model

  # ...def _load_model(...)
-

- def generate_detections_one_image(self,
- img_original,
- image_id='unknown',
- detection_threshold=0.00001,
+
+ def generate_detections_one_image(self,
+ img_original,
+ image_id='unknown',
+ detection_threshold=0.00001,
  image_size=None,
  skip_image_resizing=False,
  augment=False,
@@ -716,16 +758,16 @@ class PTDetector:
  Applies the detector to an image.

  Args:
- img_original (Image): the PIL Image object (or numpy array) on which we should run the
+ img_original (Image): the PIL Image object (or numpy array) on which we should run the
  detector, with EXIF rotation already handled
- image_id (str, optional): a path to identify the image; will be in the "file" field
+ image_id (str, optional): a path to identify the image; will be in the "file" field
  of the output object
- detection_threshold (float, optional): only detections above this confidence threshold
+ detection_threshold (float, optional): only detections above this confidence threshold
  will be included in the return value
- image_size (tuple, optional): image size to use for inference, only mess with this if
+ image_size (tuple, optional): image size to use for inference, only mess with this if
  (a) you're using a model other than MegaDetector or (b) you know what you're getting into
- skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
- external resizing), only mess with this if (a) you're using a model other than MegaDetector
+ skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
+ external resizing), only mess with this if (a) you're using a model other than MegaDetector
  or (b) you know what you're getting into
  augment (bool, optional): enable (implementation-specific) image augmentation
  preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
@@ -747,17 +789,17 @@ class PTDetector:
  assert 'classic' in self.compatibility_mode, \
  'Standalone preprocessing only supported in "classic" mode'
  assert not skip_image_resizing, \
- 'skip_image_resizing and preprocess_only are exclusive'
-
+ 'skip_image_resizing and preprocess_only are exclusive'
+
  if detection_threshold is None:
-
+
  detection_threshold = 0
-
+
  try:
-
+
  # If the caller wants us to skip all the resizing operations...
  if skip_image_resizing:
-
+
  if isinstance(img_original,dict):
  image_info = img_original
  img = image_info['img_processed']
@@ -768,107 +810,107 @@ class PTDetector:
  img_original_pil = image_info['img_original_pil']
  else:
  img = img_original
-
+
  else:
-
+
  img_original_pil = None
  # If we were given a PIL image
-
+
  if not isinstance(img_original,np.ndarray):
- img_original_pil = img_original
+ img_original_pil = img_original
  img_original = np.asarray(img_original)
-
+
  # PIL images are RGB already
  # img_original = img_original[:, :, ::-1]
-
+
  # Save the original shape for scaling boxes later
  scaling_shape = img_original.shape
-
+
  # If the caller is requesting a specific target size...
  if image_size is not None:
-
+
  assert isinstance(image_size,int)
-
+
  if not self.printed_image_size_warning:
  print('Using user-supplied image size {}'.format(image_size))
- self.printed_image_size_warning = True
-
+ self.printed_image_size_warning = True
+
  # Otherwise resize to self.default_image_size
  else:
-
+
  image_size = self.default_image_size
  self.printed_image_size_warning = False
-
+
  # ...if the caller has specified an image size
-
+
  # In "classic mode", we only do the letterboxing resize, we don't do an
  # additional initial resizing operation
  if 'classic' in self.compatibility_mode:
-
+
  resize_ratio = 1.0
-
- # Resize the image so the long side matches the target image size. This is not
+
+ # Resize the image so the long side matches the target image size. This is not
  # letterboxing (i.e., padding) yet, just resizing.
  else:
-
+
  use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
-
+
  h,w = img_original.shape[:2]
  resize_ratio = image_size / max(h,w)
-
+
  # Only resize if we have to
  if resize_ratio != 1:
-
- # Match what yolov5 does: use linear interpolation for upsizing;
+
+ # Match what yolov5 does: use linear interpolation for upsizing;
  # area interpolation for downsizing
  if resize_ratio > 1:
  interpolation_method = cv2.INTER_LINEAR
  else:
- interpolation_method = cv2.INTER_AREA
-
+ interpolation_method = cv2.INTER_AREA
+
  if use_ceil_for_resize:
  target_w = math.ceil(w * resize_ratio)
  target_h = math.ceil(h * resize_ratio)
  else:
  target_w = int(w * resize_ratio)
  target_h = int(h * resize_ratio)
-
+
  img_original = cv2.resize(
  img_original, (target_w, target_h),
  interpolation=interpolation_method)

  if 'classic' in self.compatibility_mode:
-
+
  letterbox_auto = True
  letterbox_scaleup = True
  target_shape = image_size
-
+
  else:
-
+
  letterbox_auto = False
  letterbox_scaleup = False
-
+
  # The padding to apply as a fraction of the stride size
  pad = 0.5
-
+
  model_stride = int(self.model.stride.max())
-
+
  max_dimension = max(img_original.shape)
  normalized_shape = [img_original.shape[0] / max_dimension,
  img_original.shape[1] / max_dimension]
  target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
  pad).astype(int) * model_stride
-
+
  # Now we letterbox, which is just padding, since we've already resized.
- img,letterbox_ratio,letterbox_pad = letterbox(img_original,
+ img,letterbox_ratio,letterbox_pad = letterbox(img_original,
  new_shape=target_shape,
- stride=self.letterbox_stride,
+ stride=self.letterbox_stride,
  auto=letterbox_auto,
  scaleFill=False,
  scaleup=letterbox_scaleup)
-
+
  if preprocess_only:
-
+
  assert 'file' in result
  result['img_processed'] = img
  result['img_original'] = img_original
@@ -878,14 +920,14 @@ class PTDetector:
  result['letterbox_ratio'] = letterbox_ratio
  result['letterbox_pad'] = letterbox_pad
  return result
-
+
  # ...are we doing resizing here, or were images already resized?
-
+
  # Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
  # so we don't need to mess with the color channels.
  #
  # TODO, this could be moved into the preprocessing loop
-
+
  img = img.transpose((2, 0, 1)) # [::-1]
  img = np.ascontiguousarray(img)
  img = torch.from_numpy(img)
@@ -893,8 +935,8 @@ class PTDetector:
  img = img.half() if self.half_precision else img.float()
  img /= 255

- # In practice this is always true
- if len(img.shape) == 3:
+ # In practice this is always true
+ if len(img.shape) == 3:
  img = torch.unsqueeze(img, 0)

  # Run the model
@@ -908,19 +950,19 @@ class PTDetector:
  else:
  nms_conf_thres = detection_threshold # 0.01
  nms_iou_thres = 0.6
- nms_agnostic = False
+ nms_agnostic = False
  nms_multi_label = True
-
+
  # As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
  #
- # Send predictions back to the CPU for NMS.
+ # Send predictions back to the CPU for NMS.
  if self.device == 'mps':
  pred_nms = pred.cpu()
  else:
  pred_nms = pred
-
+
  # NMS
- pred = non_max_suppression(prediction=pred_nms,
+ pred = non_max_suppression(prediction=pred_nms,
  conf_thres=nms_conf_thres,
  iou_thres=nms_iou_thres,
  agnostic=nms_agnostic,
@@ -930,26 +972,26 @@ class PTDetector:
  gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]

  if 'classic' in self.compatibility_mode:
-
+
  ratio = None
  ratio_pad = None
-
+
  else:
-
+
  # letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
  #
  # ratio is a 2-tuple specifying the scaling that was applied to each dimension.
  #
  # The scale_boxes function expects a 2-tuple with these things combined.
  ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
- ratio_pad = (ratio, letterbox_pad)
-
+ ratio_pad = (ratio, letterbox_pad)
+
  # This is a loop over detection batches, which will always be length 1 in our case,
  # since we're not doing batch inference.
  #
  # det = pred[0]
  #
- # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
+ # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
  # in descending order by confidence.
  #
  # Columns are:
@@ -959,15 +1001,15 @@ class PTDetector:
  # At this point, these are *non*-normalized values, referring to the size at which we
  # ran inference (img.shape).
  for det in pred:
-
+
  if len(det) == 0:
  continue
-
+
  # Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
  if 'classic' in self.compatibility_mode:
-
+
  det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
-
+
  else:
  # After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
  # original pixel dimension of the image, followed by the class and confidence
@@ -975,14 +1017,14 @@ class PTDetector:

  # Loop over detections
  for *xyxy, conf, cls in reversed(det):
-
+
  if conf < detection_threshold:
  continue
-
+
  # Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
  xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()

- # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
+ # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
  # left/top/w/h (i.e., MD format)
  api_box = ct_utils.convert_yolo_to_xywh(xywh)

@@ -991,11 +1033,11 @@ class PTDetector:
  conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
  else:
  api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
- conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
-
+ conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
+
  if not self.use_model_native_classes:
- # The MegaDetector output format's categories start at 1, but all YOLO-based
- # MD models have category numbers starting at 0.
+ # The MegaDetector output format's categories start at 1, but all YOLO-based
+ # MD models have category numbers starting at 0.
  cls = int(cls.tolist()) + 1
  if cls not in (1, 2, 3):
  raise KeyError(f'{cls} is not a valid class.')
@@ -1008,15 +1050,15 @@ class PTDetector:
  'bbox': api_box
  })
  max_conf = max(max_conf, conf)
-
+
  # ...for each detection in this batch
-
+
  # ...for each detection batch (always one iteration)

  # ...try
-
+
  except Exception as e:
-
+
  result['failure'] = FAILURE_INFER
  print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
  # traceback.print_exc(e)
@@ -1039,12 +1081,12 @@ class PTDetector:
  if __name__ == '__main__':

  pass
-
+
  #%%
-
+
  import os #noqa
  from megadetector.visualization import visualization_utils as vis_utils
-
+
  model_file = os.environ['MDV5A']
  im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')