megadetector 5.0.28__py3-none-any.whl → 10.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (197)
  1. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +2 -2
  2. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +1 -1
  3. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +1 -1
  4. megadetector/classification/aggregate_classifier_probs.py +3 -3
  5. megadetector/classification/analyze_failed_images.py +5 -5
  6. megadetector/classification/cache_batchapi_outputs.py +5 -5
  7. megadetector/classification/create_classification_dataset.py +11 -12
  8. megadetector/classification/crop_detections.py +10 -10
  9. megadetector/classification/csv_to_json.py +8 -8
  10. megadetector/classification/detect_and_crop.py +13 -15
  11. megadetector/classification/efficientnet/model.py +8 -8
  12. megadetector/classification/efficientnet/utils.py +6 -5
  13. megadetector/classification/evaluate_model.py +7 -7
  14. megadetector/classification/identify_mislabeled_candidates.py +6 -6
  15. megadetector/classification/json_to_azcopy_list.py +1 -1
  16. megadetector/classification/json_validator.py +29 -32
  17. megadetector/classification/map_classification_categories.py +9 -9
  18. megadetector/classification/merge_classification_detection_output.py +12 -9
  19. megadetector/classification/prepare_classification_script.py +19 -19
  20. megadetector/classification/prepare_classification_script_mc.py +26 -26
  21. megadetector/classification/run_classifier.py +4 -4
  22. megadetector/classification/save_mislabeled.py +6 -6
  23. megadetector/classification/train_classifier.py +1 -1
  24. megadetector/classification/train_classifier_tf.py +9 -9
  25. megadetector/classification/train_utils.py +10 -10
  26. megadetector/data_management/annotations/annotation_constants.py +1 -2
  27. megadetector/data_management/camtrap_dp_to_coco.py +79 -46
  28. megadetector/data_management/cct_json_utils.py +103 -103
  29. megadetector/data_management/cct_to_md.py +49 -49
  30. megadetector/data_management/cct_to_wi.py +33 -33
  31. megadetector/data_management/coco_to_labelme.py +75 -75
  32. megadetector/data_management/coco_to_yolo.py +210 -193
  33. megadetector/data_management/databases/add_width_and_height_to_db.py +86 -12
  34. megadetector/data_management/databases/combine_coco_camera_traps_files.py +40 -40
  35. megadetector/data_management/databases/integrity_check_json_db.py +228 -200
  36. megadetector/data_management/databases/subset_json_db.py +33 -33
  37. megadetector/data_management/generate_crops_from_cct.py +88 -39
  38. megadetector/data_management/get_image_sizes.py +54 -49
  39. megadetector/data_management/labelme_to_coco.py +133 -125
  40. megadetector/data_management/labelme_to_yolo.py +159 -73
  41. megadetector/data_management/lila/create_lila_blank_set.py +81 -83
  42. megadetector/data_management/lila/create_lila_test_set.py +32 -31
  43. megadetector/data_management/lila/create_links_to_md_results_files.py +18 -18
  44. megadetector/data_management/lila/download_lila_subset.py +21 -24
  45. megadetector/data_management/lila/generate_lila_per_image_labels.py +365 -107
  46. megadetector/data_management/lila/get_lila_annotation_counts.py +35 -33
  47. megadetector/data_management/lila/get_lila_image_counts.py +22 -22
  48. megadetector/data_management/lila/lila_common.py +73 -70
  49. megadetector/data_management/lila/test_lila_metadata_urls.py +28 -19
  50. megadetector/data_management/mewc_to_md.py +344 -340
  51. megadetector/data_management/ocr_tools.py +262 -255
  52. megadetector/data_management/read_exif.py +249 -227
  53. megadetector/data_management/remap_coco_categories.py +90 -28
  54. megadetector/data_management/remove_exif.py +81 -21
  55. megadetector/data_management/rename_images.py +187 -187
  56. megadetector/data_management/resize_coco_dataset.py +588 -120
  57. megadetector/data_management/speciesnet_to_md.py +41 -41
  58. megadetector/data_management/wi_download_csv_to_coco.py +55 -55
  59. megadetector/data_management/yolo_output_to_md_output.py +248 -122
  60. megadetector/data_management/yolo_to_coco.py +333 -191
  61. megadetector/detection/change_detection.py +832 -0
  62. megadetector/detection/process_video.py +340 -337
  63. megadetector/detection/pytorch_detector.py +358 -278
  64. megadetector/detection/run_detector.py +399 -186
  65. megadetector/detection/run_detector_batch.py +404 -377
  66. megadetector/detection/run_inference_with_yolov5_val.py +340 -327
  67. megadetector/detection/run_tiled_inference.py +257 -249
  68. megadetector/detection/tf_detector.py +24 -24
  69. megadetector/detection/video_utils.py +332 -295
  70. megadetector/postprocessing/add_max_conf.py +19 -11
  71. megadetector/postprocessing/categorize_detections_by_size.py +45 -45
  72. megadetector/postprocessing/classification_postprocessing.py +468 -433
  73. megadetector/postprocessing/combine_batch_outputs.py +23 -23
  74. megadetector/postprocessing/compare_batch_results.py +590 -525
  75. megadetector/postprocessing/convert_output_format.py +106 -102
  76. megadetector/postprocessing/create_crop_folder.py +347 -147
  77. megadetector/postprocessing/detector_calibration.py +173 -168
  78. megadetector/postprocessing/generate_csv_report.py +508 -499
  79. megadetector/postprocessing/load_api_results.py +48 -27
  80. megadetector/postprocessing/md_to_coco.py +133 -102
  81. megadetector/postprocessing/md_to_labelme.py +107 -90
  82. megadetector/postprocessing/md_to_wi.py +40 -40
  83. megadetector/postprocessing/merge_detections.py +92 -114
  84. megadetector/postprocessing/postprocess_batch_results.py +319 -301
  85. megadetector/postprocessing/remap_detection_categories.py +91 -38
  86. megadetector/postprocessing/render_detection_confusion_matrix.py +214 -205
  87. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +57 -57
  88. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +27 -28
  89. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +704 -679
  90. megadetector/postprocessing/separate_detections_into_folders.py +226 -211
  91. megadetector/postprocessing/subset_json_detector_output.py +265 -262
  92. megadetector/postprocessing/top_folders_to_bottom.py +45 -45
  93. megadetector/postprocessing/validate_batch_results.py +70 -70
  94. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +52 -52
  95. megadetector/taxonomy_mapping/map_new_lila_datasets.py +18 -19
  96. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +54 -33
  97. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +67 -67
  98. megadetector/taxonomy_mapping/retrieve_sample_image.py +16 -16
  99. megadetector/taxonomy_mapping/simple_image_download.py +8 -8
  100. megadetector/taxonomy_mapping/species_lookup.py +156 -74
  101. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +14 -14
  102. megadetector/taxonomy_mapping/taxonomy_graph.py +10 -10
  103. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +13 -13
  104. megadetector/utils/ct_utils.py +1049 -211
  105. megadetector/utils/directory_listing.py +21 -77
  106. megadetector/utils/gpu_test.py +22 -22
  107. megadetector/utils/md_tests.py +632 -529
  108. megadetector/utils/path_utils.py +1520 -431
  109. megadetector/utils/process_utils.py +41 -41
  110. megadetector/utils/split_locations_into_train_val.py +62 -62
  111. megadetector/utils/string_utils.py +148 -27
  112. megadetector/utils/url_utils.py +489 -176
  113. megadetector/utils/wi_utils.py +2658 -2526
  114. megadetector/utils/write_html_image_list.py +137 -137
  115. megadetector/visualization/plot_utils.py +34 -30
  116. megadetector/visualization/render_images_with_thumbnails.py +39 -74
  117. megadetector/visualization/visualization_utils.py +487 -435
  118. megadetector/visualization/visualize_db.py +232 -198
  119. megadetector/visualization/visualize_detector_output.py +82 -76
  120. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/METADATA +5 -2
  121. megadetector-10.0.0.dist-info/RECORD +139 -0
  122. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/WHEEL +1 -1
  123. megadetector/api/batch_processing/api_core/__init__.py +0 -0
  124. megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
  125. megadetector/api/batch_processing/api_core/batch_service/score.py +0 -439
  126. megadetector/api/batch_processing/api_core/server.py +0 -294
  127. megadetector/api/batch_processing/api_core/server_api_config.py +0 -97
  128. megadetector/api/batch_processing/api_core/server_app_config.py +0 -55
  129. megadetector/api/batch_processing/api_core/server_batch_job_manager.py +0 -220
  130. megadetector/api/batch_processing/api_core/server_job_status_table.py +0 -149
  131. megadetector/api/batch_processing/api_core/server_orchestration.py +0 -360
  132. megadetector/api/batch_processing/api_core/server_utils.py +0 -88
  133. megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
  134. megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
  135. megadetector/api/batch_processing/api_support/__init__.py +0 -0
  136. megadetector/api/batch_processing/api_support/summarize_daily_activity.py +0 -152
  137. megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
  138. megadetector/api/synchronous/__init__.py +0 -0
  139. megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  140. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +0 -151
  141. megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -263
  142. megadetector/api/synchronous/api_core/animal_detection_api/config.py +0 -35
  143. megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
  144. megadetector/api/synchronous/api_core/tests/load_test.py +0 -110
  145. megadetector/data_management/importers/add_nacti_sizes.py +0 -52
  146. megadetector/data_management/importers/add_timestamps_to_icct.py +0 -79
  147. megadetector/data_management/importers/animl_results_to_md_results.py +0 -158
  148. megadetector/data_management/importers/auckland_doc_test_to_json.py +0 -373
  149. megadetector/data_management/importers/auckland_doc_to_json.py +0 -201
  150. megadetector/data_management/importers/awc_to_json.py +0 -191
  151. megadetector/data_management/importers/bellevue_to_json.py +0 -272
  152. megadetector/data_management/importers/cacophony-thermal-importer.py +0 -793
  153. megadetector/data_management/importers/carrizo_shrubfree_2018.py +0 -269
  154. megadetector/data_management/importers/carrizo_trail_cam_2017.py +0 -289
  155. megadetector/data_management/importers/cct_field_adjustments.py +0 -58
  156. megadetector/data_management/importers/channel_islands_to_cct.py +0 -913
  157. megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
  158. megadetector/data_management/importers/eMammal/eMammal_helpers.py +0 -249
  159. megadetector/data_management/importers/eMammal/make_eMammal_json.py +0 -223
  160. megadetector/data_management/importers/ena24_to_json.py +0 -276
  161. megadetector/data_management/importers/filenames_to_json.py +0 -386
  162. megadetector/data_management/importers/helena_to_cct.py +0 -283
  163. megadetector/data_management/importers/idaho-camera-traps.py +0 -1407
  164. megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
  165. megadetector/data_management/importers/import_desert_lion_conservation_camera_traps.py +0 -387
  166. megadetector/data_management/importers/jb_csv_to_json.py +0 -150
  167. megadetector/data_management/importers/mcgill_to_json.py +0 -250
  168. megadetector/data_management/importers/missouri_to_json.py +0 -490
  169. megadetector/data_management/importers/nacti_fieldname_adjustments.py +0 -79
  170. megadetector/data_management/importers/noaa_seals_2019.py +0 -181
  171. megadetector/data_management/importers/osu-small-animals-to-json.py +0 -364
  172. megadetector/data_management/importers/pc_to_json.py +0 -365
  173. megadetector/data_management/importers/plot_wni_giraffes.py +0 -123
  174. megadetector/data_management/importers/prepare_zsl_imerit.py +0 -131
  175. megadetector/data_management/importers/raic_csv_to_md_results.py +0 -416
  176. megadetector/data_management/importers/rspb_to_json.py +0 -356
  177. megadetector/data_management/importers/save_the_elephants_survey_A.py +0 -320
  178. megadetector/data_management/importers/save_the_elephants_survey_B.py +0 -329
  179. megadetector/data_management/importers/snapshot_safari_importer.py +0 -758
  180. megadetector/data_management/importers/snapshot_serengeti_lila.py +0 -1067
  181. megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
  182. megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
  183. megadetector/data_management/importers/sulross_get_exif.py +0 -65
  184. megadetector/data_management/importers/timelapse_csv_set_to_json.py +0 -490
  185. megadetector/data_management/importers/ubc_to_json.py +0 -399
  186. megadetector/data_management/importers/umn_to_json.py +0 -507
  187. megadetector/data_management/importers/wellington_to_json.py +0 -263
  188. megadetector/data_management/importers/wi_to_json.py +0 -442
  189. megadetector/data_management/importers/zamba_results_to_md_results.py +0 -180
  190. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +0 -101
  191. megadetector/data_management/lila/add_locations_to_nacti.py +0 -151
  192. megadetector/utils/azure_utils.py +0 -178
  193. megadetector/utils/sas_blob_utils.py +0 -509
  194. megadetector-5.0.28.dist-info/RECORD +0 -209
  195. /megadetector/{api/batch_processing/__init__.py → __init__.py} +0 -0
  196. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/licenses/LICENSE +0 -0
  197. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/top_level.txt +0 -0

megadetector/detection/pytorch_detector.py

@@ -17,6 +17,7 @@ import shutil
 import traceback
 import uuid
 import json
+import inspect
 
 import cv2
 import torch
@@ -54,7 +55,7 @@ def _get_model_type_for_model(model_file,
                               verbose=False):
     """
     Determine the model type (i.e., the inference library we need to use) for a .pt file.
-
+
     Args:
         model_file (str): the model file to read
         prefer_model_type_source (str, optional): how should we handle the (very unlikely)
@@ -64,28 +65,28 @@ def _get_model_type_for_model(model_file,
         default_model_type (str, optional): return value for the case where we can't find
             appropriate metadata in the file or in the global table.
         verbose (bool, optional): enable additional debug output
-
+
     Returns:
         str: the model type indicated for this model
     """
-
+
     model_info = read_metadata_from_megadetector_model_file(model_file)
-
+
     # Check whether the model file itself specified a model type
     model_type_from_model_file_metadata = None
-
-    if model_info is not None and 'model_type' in model_info:
+
+    if model_info is not None and 'model_type' in model_info:
         model_type_from_model_file_metadata = model_info['model_type']
         if verbose:
             print('Parsed model type {} from model {}'.format(
                 model_type_from_model_file_metadata,
                 model_file))
-
+
     model_type_from_model_version = None
-
+
     # Check whether this is a known model version with a specific model type
     model_version_from_file = get_detector_version_from_model_file(model_file)
-
+
     if model_version_from_file is not None and model_version_from_file in known_models:
         model_info = known_models[model_version_from_file]
         if 'model_type' in model_info:
@@ -93,15 +94,15 @@ def _get_model_type_for_model(model_file,
             if verbose:
                 print('Parsed model type {} from global metadata'.format(model_type_from_model_version))
         else:
-            model_type_from_model_version = None
-
+            model_type_from_model_version = None
+
     if model_type_from_model_file_metadata is None and \
        model_type_from_model_version is None:
         if verbose:
             print('Could not determine model type for {}, assuming {}'.format(
                 model_file,default_model_type))
         model_type = default_model_type
-
+
     elif model_type_from_model_file_metadata is not None and \
          model_type_from_model_version is not None:
         if model_type_from_model_version == model_type_from_model_file_metadata:
@@ -113,15 +114,15 @@ def _get_model_type_for_model(model_file,
             model_type = model_type_from_model_file_metadata
         else:
             model_type = model_type_from_model_version
-
+
     elif model_type_from_model_file_metadata is not None:
-
+
         model_type = model_type_from_model_file_metadata
-
+
     elif model_type_from_model_version is not None:
-
+
         model_type = model_type_from_model_version
-
+
     return model_type
 
 # ...def _get_model_type_for_model(...)
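
The resolution logic above combines two metadata sources (file metadata and the global version table). As a quick reference, the decision rule reduces to this minimal sketch; the names here are illustrative, not the package's API:

    # Hypothetical, simplified restatement of the precedence rule implemented
    # by _get_model_type_for_model() above.
    def resolve_model_type(type_from_file, type_from_table,
                           prefer='table', default='yolov5'):
        # Neither source knows the model type: fall back to the default
        if type_from_file is None and type_from_table is None:
            return default
        # Both sources answered: if they disagree, 'prefer' breaks the tie
        if type_from_file is not None and type_from_table is not None:
            if type_from_file == type_from_table:
                return type_from_file
            return type_from_file if prefer == 'file' else type_from_table
        # Exactly one source answered: use it
        return type_from_file if type_from_file is not None else type_from_table

    assert resolve_model_type(None, None) == 'yolov5'
    assert resolve_model_type('ultralytics', 'yolov5', prefer='file') == 'ultralytics'
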
@@ -134,7 +135,7 @@ def _initialize_yolo_imports_for_model(model_file,
                                        verbose=False):
     """
     Initialize the appropriate YOLO imports for a model file.
-
+
     Args:
         model_file (str): The model file for which we're loading support
         prefer_model_type_source (str, optional): how should we handle the (very unlikely)
@@ -142,18 +143,17 @@ def _initialize_yolo_imports_for_model(model_file,
             type table says something else. Should be "table" (trust the table) or "file"
             (trust the file).
         default_model_type (str, optional): return value for the case where we can't find
-            appropriate metadata in the file or in the global table.
+            appropriate metadata in the file or in the global table.
         detector_options (dict, optional): dictionary of detector options that mean
             different things to different models
         verbose (bool, optional): enable additional debug output
-
+
     Returns:
         str: the model type for which we initialized support
     """
-
-
+
     global yolo_model_type_imported
-
+
     if detector_options is not None and 'model_type' in detector_options:
         model_type = detector_options['model_type']
         print('Model type {} provided in detector options'.format(model_type))
@@ -161,7 +161,7 @@ def _initialize_yolo_imports_for_model(model_file,
         model_type = _get_model_type_for_model(model_file,
                                                prefer_model_type_source=prefer_model_type_source,
                                                default_model_type=default_model_type)
-
+
     if yolo_model_type_imported is not None:
         if model_type == yolo_model_type_imported:
             print('Bypassing imports for model type {}'.format(model_type))
@@ -169,54 +169,92 @@ def _initialize_yolo_imports_for_model(model_file,
         else:
             print('Previously set up imports for model type {}, re-importing as {}'.format(
                 yolo_model_type_imported,model_type))
-
+
     _initialize_yolo_imports(model_type,verbose=verbose)
-
+
     return model_type
 
 
-def _clean_yolo_imports(verbose=False):
+def _clean_yolo_imports(verbose=False,aggressive_cleanup=False):
     """
     Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
-    of another YOLO library version. The reason we jump through all these hoops, rather than
+    of another YOLO library version. The reason we jump through all these hoops, rather than
     just, e.g., handling different libraries in different modules, is that we need to make sure
     *pickle* sees the right version of modules during module loading, including modules we don't
-    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
+    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
     do that is to remove all the "wrong" versions from sys.modules and sys.path.
-
+
     Args:
         verbose (bool, optional): enable additional debug output
+        aggressive_cleanup (bool, optional): err on the side of removing modules,
+            at least by ignoring whether they are/aren't in a site-packages folder.
+            By default, only modules in a folder that includes "site-packages" will
+            be considered for unloading.
     """
-
+
     modules_to_delete = []
-    for module_name in sys.modules.keys():
+
+    for module_name in sys.modules.keys():
+
         module = sys.modules[module_name]
+        if not hasattr(module,'__file__') or (module.__file__ is None):
+            continue
         try:
             module_file = module.__file__.replace('\\','/')
-            if 'site-packages' not in module_file:
-                continue
-            tokens = module_file.split('/')[-4:]
-            for token in tokens:
-                if 'yolov5' in token or 'yolov9' in token or 'ultralytics' in token:
+            if not aggressive_cleanup:
+                if 'site-packages' not in module_file:
+                    continue
+            tokens = module_file.split('/')
+
+            # For local path imports, a module filename that should be unloaded might
+            # look like:
+            #
+            # c:/git/yolov9/models/common.py
+            #
+            # For pip imports, a module filename that should be unloaded might look like:
+            #
+            # c:/users/user/miniforge3/envs/megadetector/lib/site-packages/yolov9/utils/__init__.py
+            first_token_to_check = len(tokens) - 4
+            for i_token,token in enumerate(tokens):
+                if i_token < first_token_to_check:
+                    continue
+                # Don't remove anything based on the environment name, which
+                # always follows "envs" in the path
+                if (i_token > 1) and (tokens[i_token-1] == 'envs'):
+                    continue
+                if ('yolov5' in token) or ('yolov9' in token) or ('ultralytics' in token):
+                    if verbose:
+                        print('Module {} ({}) looks deletable'.format(module_name,module_file))
                     modules_to_delete.append(module_name)
-                    break
-        except Exception:
+                    break
+        except Exception as e:
+            if verbose:
+                print('Exception during module review: {}'.format(str(e)))
             pass
-
+
+    # ...for each module in the global namespace
+
     for module_name in modules_to_delete:
+
         if module_name in sys.modules.keys():
-            module_file = module.__file__.replace('\\','/')
             if verbose:
-                print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
+                try:
+                    module = sys.modules[module_name]
+                    module_file = module.__file__.replace('\\','/')
+                    print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
+                except Exception:
+                    pass
             del sys.modules[module_name]
-
+
+    # ...for each module we want to remove from the global namespace
+
     paths_to_delete = []
-
+
     for p in sys.path:
         if p.endswith('yolov5') or p.endswith('yolov9') or p.endswith('ultralytics'):
             print('clean_yolo_imports: removing {} from path'.format(p))
             paths_to_delete.append(p)
-
+
     for p in paths_to_delete:
         sys.path.remove(p)
 
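
For readers unfamiliar with this trick, here is a self-contained sketch of the unload-then-reimport pattern that _clean_yolo_imports() enables: deleting entries from sys.modules forces the next import statement to re-execute the module. The stdlib 'json' module stands in for a YOLO package purely so the sketch runs anywhere:

    import sys

    import json
    original_module = sys.modules['json']

    # Remove the module (and any submodules) from the module cache
    for name in [m for m in list(sys.modules) if m == 'json' or m.startswith('json.')]:
        del sys.modules[name]

    import json  # re-executes the module and installs a fresh module object
    assert sys.modules['json'] is not original_module
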
@@ -228,52 +266,67 @@ def _initialize_yolo_imports(model_type='yolov5',
                              force_reimport=False,
                              verbose=False):
     """
-    Imports required functions from one or more yolo libraries (yolov5, yolov9,
+    Imports required functions from one or more yolo libraries (yolov5, yolov9,
     ultralytics, targeting support for [model_type]).
-
+
     Args:
         model_type (str): The model type for which we're loading support
-        allow_fallback_import (bool, optional): If we can't import from the package for
-            which we're trying to load support, fall back to "import utils". This is
+        allow_fallback_import (bool, optional): If we can't import from the package for
+            which we're trying to load support, fall back to "import utils". This is
             typically used when the right support library is on the current PYTHONPATH.
-        force_reimport (bool, optional): import the appropriate libraries even if the
+        force_reimport (bool, optional): import the appropriate libraries even if the
             requested model type matches the current initialization state
         verbose (bool, optional): include additional debug output
-
+
     Returns:
         str: the model type for which we initialized support
     """
-
+
+    # When running in pytest, the megadetector 'utils' module is put in the global
+    # namespace, which creates conflicts with yolov5; remove it from the global
+    # namespace.
+    if ('PYTEST_CURRENT_TEST' in os.environ):
+        print('*** pytest detected ***')
+        if ('utils' in sys.modules):
+            utils_module = sys.modules['utils']
+            if hasattr(utils_module, '__file__') and 'megadetector' in str(utils_module.__file__):
+                print(f"Removing conflicting utils module: {utils_module.__file__}")
+                sys.modules.pop('utils', None)
+                # Also remove any submodules
+                to_remove = [name for name in sys.modules if name.startswith('utils.')]
+                for name in to_remove:
+                    sys.modules.pop(name, None)
+
     global yolo_model_type_imported
-
+
     if model_type is None:
         model_type = 'yolov5'
-
+
     # The point of this function is to make the appropriate version
     # of the following functions available at module scope
     global non_max_suppression
     global xyxy2xywh
     global letterbox
     global scale_coords
-
+
     if yolo_model_type_imported is not None:
-        if yolo_model_type_imported == model_type:
+        if (yolo_model_type_imported == model_type) and (not force_reimport):
             print('Bypassing imports for YOLO model type {}'.format(model_type))
             return
         else:
             _clean_yolo_imports()
-
+
     try_yolov5_import = (model_type == 'yolov5')
     try_yolov9_import = (model_type == 'yolov9')
     try_ultralytics_import = (model_type == 'ultralytics')
-
+
     utils_imported = False
-
+
     # First try importing from the yolov5 package; this is how the pip
     # package finds YOLOv5 utilities.
     if try_yolov5_import and not utils_imported:
-
-        try:
+
+        try:
             from yolov5.utils.general import non_max_suppression, xyxy2xywh # noqa
             from yolov5.utils.augmentations import letterbox # noqa
             try:
@@ -283,109 +336,127 @@ def _initialize_yolo_imports(model_type='yolov5',
             utils_imported = True
             if verbose:
                 print('Imported utils from YOLOv5 package')
-
-        except Exception as e: # noqa
-
+
+        except Exception as e: # noqa
             # print('yolov5 module import failed: {}'.format(e))
-            # print(traceback.format_exc())
+            # print(traceback.format_exc())
             pass
-
+
     # Next try importing from the yolov9 package
     if try_yolov9_import and not utils_imported:
-
+
         try:
-
+
             from yolov9.utils.general import non_max_suppression, xyxy2xywh # noqa
             from yolov9.utils.augmentations import letterbox # noqa
             from yolov9.utils.general import scale_boxes as scale_coords # noqa
             utils_imported = True
             if verbose:
                 print('Imported utils from YOLOv9 package')
-
+
         except Exception as e: # noqa
-
+
             # print('yolov9 module import failed: {}'.format(e))
             # print(traceback.format_exc())
             pass
-
-    # If we haven't succeeded yet, import from the ultralytics package
+
+    # If we haven't succeeded yet, import from the ultralytics package
     if try_ultralytics_import and not utils_imported:
-
+
         try:
-
-            import ultralytics # noqa
-
+
+            import ultralytics # type: ignore # noqa
+
         except Exception:
-
+
             print('It looks like you are trying to run a model that requires the ultralytics package, '
                   'but the ultralytics package is not installed. For licensing reasons, this '
                   'is not installed by default with the MegaDetector Python package. Run '
                   '"pip install ultralytics" to install it, and try again.')
             raise
-
+
         try:
-
-            from ultralytics.utils.ops import non_max_suppression # noqa
-            from ultralytics.utils.ops import xyxy2xywh # noqa
-
+
+            from ultralytics.utils.ops import non_max_suppression # type: ignore # noqa
+            from ultralytics.utils.ops import xyxy2xywh # type: ignore # noqa
+
             # In the ultralytics package, scale_boxes and scale_coords both exist;
             # we want scale_boxes.
-            #
+            #
             # from ultralytics.utils.ops import scale_coords # noqa
-            from ultralytics.utils.ops import scale_boxes as scale_coords # noqa
-            from ultralytics.data.augment import LetterBox
-
-            # letterbox() became a LetterBox class in the ultralytics package. Create a
+            from ultralytics.utils.ops import scale_boxes as scale_coords # type: ignore # noqa
+            from ultralytics.data.augment import LetterBox # type: ignore # noqa
+
+            # letterbox() became a LetterBox class in the ultralytics package. Create a
             # backwards-compatible letterbox function wrapper that wraps the class up.
-            def letterbox(img,new_shape,auto=False,scaleFill=False,scaleup=True,center=True,stride=32): # noqa
-
-                L = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,scaleup=scaleup,center=center,stride=stride)
-                letterbox_result = L(image=img)
-
+            def letterbox(img,new_shape,auto=False,scaleFill=False, #noqa
+                          scaleup=True,center=True,stride=32):
+
+                # Ultralytics changed the "scaleFill" parameter to "scale_fill", we want to support
+                # both conventions.
+                use_old_scalefill_arg = False
+                try:
+                    sig = inspect.signature(LetterBox.__init__)
+                    if 'scaleFill' in sig.parameters:
+                        use_old_scalefill_arg = True
+                except Exception:
+                    pass
+
+                if use_old_scalefill_arg:
+                    if verbose:
+                        print('Using old scaleFill calling convention')
+                    letterbox_transformer = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,
+                                                      scaleup=scaleup,center=center,stride=stride)
+                else:
+                    letterbox_transformer = LetterBox(new_shape,auto=auto,scale_fill=scaleFill,
+                                                      scaleup=scaleup,center=center,stride=stride)
+
+                letterbox_result = letterbox_transformer(image=img)
+
                 if isinstance(new_shape,int):
                     new_shape = [new_shape,new_shape]
-
+
                 # The letterboxing is done, we just need to reverse-engineer what it did
                 shape = img.shape[:2]
-
+
                 r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
                 if not scaleup:
                     r = min(r, 1.0)
                 ratio = r, r
-
+
                 new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
                 dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
                 if auto:
                     dw, dh = np.mod(dw, stride), np.mod(dh, stride)
-                elif scaleFill:
+                elif scaleFill:
                     dw, dh = 0.0, 0.0
                     new_unpad = (new_shape[1], new_shape[0])
                     ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])
-
+
                 dw /= 2
                 dh /= 2
                 pad = (dw,dh)
-
+
                 return [letterbox_result,ratio,pad]
-
+
             utils_imported = True
             if verbose:
                 print('Imported utils from ultralytics package')
-
+
         except Exception:
-
+
             # print('Ultralytics module import failed')
             pass
-
+
     # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
     if (not utils_imported) and allow_fallback_import:
-
+
         try:
-
+
             # import pre- and post-processing functions from the YOLOv5 repo
             from utils.general import non_max_suppression, xyxy2xywh # noqa
             from utils.augmentations import letterbox # noqa
-
+
             # scale_coords() is scale_boxes() in some YOLOv5 versions
             try:
                 from utils.general import scale_coords # noqa
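
The inspect.signature() probe in the letterbox wrapper above is a general technique for surviving keyword-argument renames across library versions. A standalone sketch; new_style() is a stand-in here, not a real ultralytics function:

    import inspect

    def new_style(*, scale_fill=False):
        return 'scale_fill={}'.format(scale_fill)

    def call_with_compat(fn, value):
        # Probe the installed signature and use whichever spelling it accepts
        params = inspect.signature(fn).parameters
        if 'scaleFill' in params:
            return fn(scaleFill=value)   # old calling convention
        return fn(scale_fill=value)      # new calling convention

    print(call_with_compat(new_style, True))  # scale_fill=True
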
@@ -395,58 +466,58 @@ def _initialize_yolo_imports(model_type='yolov5',
             imported_file = sys.modules[scale_coords.__module__].__file__
             if verbose:
                 print('Imported utils from {}'.format(imported_file))
-
+
         except ModuleNotFoundError as e:
-
+
             raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e)))
-
+
     assert utils_imported, 'YOLO utils import error'
-
+
     yolo_model_type_imported = model_type
     if verbose:
         print('Prepared YOLO imports for model type {}'.format(model_type))
-
+
     return model_type
 
 # ...def _initialize_yolo_imports(...)
-
 
 #%% Model metadata functions
 
-def add_metadata_to_megadetector_model_file(model_file_in,
+def add_metadata_to_megadetector_model_file(model_file_in,
                                             model_file_out,
-                                            metadata,
+                                            metadata,
                                             destination_path='megadetector_info.json'):
     """
-    Adds a .json file to the specified MegaDetector model file containing metadata used
+    Adds a .json file to the specified MegaDetector model file containing metadata used
     by this module. Always over-writes the output file.
-
+
     Args:
         model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
         model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
            May be the same as model_file_in.
        metadata (dict): The metadata dict to add to the output model file
-        destination_path (str, optional): The relative path within the main folder of the
-            model archive where we should write the metadata. This is not relative to the root
+        destination_path (str, optional): The relative path within the main folder of the
+            model archive where we should write the metadata. This is not relative to the root
            of the archive, it's relative to the one and only folder at the root of the archive
-            (this is a PyTorch convention).
+            (this is a PyTorch convention).
     """
-
+
     tmp_base = os.path.join(tempfile.gettempdir(),'md_metadata')
     os.makedirs(tmp_base,exist_ok=True)
     metadata_tmp_file_relative = 'megadetector_info_' + str(uuid.uuid1()) + '.json'
-    metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
+    metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
 
     with open(metadata_tmp_file_abs,'w') as f:
         json.dump(metadata,f,indent=1)
-
+
     # Copy the input file to the output file
     shutil.copyfile(model_file_in,model_file_out)
 
     # Write metadata to the output file
     with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:
-
-        # Torch doesn't like anything in the root folder of the zipfile, so we put
+
+        # Torch doesn't like anything in the root folder of the zipfile, so we put
         # it in the one and only folder.
         names = zipf.namelist()
         root_folders = set()
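
Since a .pt checkpoint is just a zip archive with a single top-level folder, the write path above amounts to the following sketch; the archive built here is a toy stand-in, not a real checkpoint:

    import json, os, tempfile, zipfile

    archive = os.path.join(tempfile.gettempdir(), 'toy_model.pt')
    with zipfile.ZipFile(archive, 'w') as zipf:
        zipf.writestr('toy_model/data.pkl', b'not a real checkpoint')

    # Append metadata under the single root folder (the PyTorch convention),
    # not at the archive root
    with zipfile.ZipFile(archive, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:
        root_folder = {n.split('/')[0] for n in zipf.namelist()}.pop()
        zipf.writestr(root_folder + '/megadetector_info.json',
                      json.dumps({'image_size': 1280}, indent=1))
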
@@ -456,9 +527,9 @@ def add_metadata_to_megadetector_model_file(model_file_in,
         assert len(root_folders) == 1,\
             'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
         root_folder = next(iter(root_folders))
-
-        zipf.write(metadata_tmp_file_abs,
-                   root_folder + '/' + destination_path,
+
+        zipf.write(metadata_tmp_file_abs,
+                   root_folder + '/' + destination_path,
                    compresslevel=9,
                    compress_type=zipfile.ZIP_DEFLATED)
 
@@ -466,7 +537,7 @@ def add_metadata_to_megadetector_model_file(model_file_in,
         os.remove(metadata_tmp_file_abs)
     except Exception as e:
         print('Warning: error deleting file {}: {}'.format(metadata_tmp_file_abs,str(e)))
-
+
 # ...def add_metadata_to_megadetector_model_file(...)
 
 
@@ -475,22 +546,23 @@ def read_metadata_from_megadetector_model_file(model_file,
                                                verbose=False):
     """
     Reads custom MegaDetector metadata from a modified MegaDetector model file.
-
+
     Args:
         model_file (str): The model filename to read, typically .pt (.zip is also sensible)
-        relative_path (str, optional): The relative path within the main folder of the model
-            archive from which we should read the metadata. This is not relative to the root
+        relative_path (str, optional): The relative path within the main folder of the model
+            archive from which we should read the metadata. This is not relative to the root
             of the archive, it's relative to the one and only folder at the root of the archive
-            (this is a PyTorch convention).
-
+            (this is a PyTorch convention).
+        verbose (bool, optional): enable additional debug output
+
     Returns:
-        object: Whatever we read from the metadata file, always a dict in practice. Returns
-            None if we failed to read the specified metadata file.
+        object: whatever we read from the metadata file, always a dict in practice. Returns
+            None if we failed to read the specified metadata file.
     """
-
+
     with zipfile.ZipFile(model_file,'r') as zipf:
-
-        # Torch doesn't like anything in the root folder of the zipfile, so we put
+
+        # Torch doesn't like anything in the root folder of the zipfile, so we put
         # it in the one and only folder.
         names = zipf.namelist()
         root_folders = set()
@@ -498,17 +570,19 @@ def read_metadata_from_megadetector_model_file(model_file,
             root_folder = name.split('/')[0]
             root_folders.add(root_folder)
         if len(root_folders) != 1:
-            print('Warning: this archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?')
+            print('Warning: this archive does not have exactly one folder at the top level; ' + \
+                  'are you sure it\'s a Torch model file?')
             return None
         root_folder = next(iter(root_folders))
-
+
         metadata_file = root_folder + '/' + relative_path
         if metadata_file not in names:
             # This is the case for MDv5a and MDv5b
             if verbose:
-                print('Warning: could not find metadata file {} in zip archive'.format(metadata_file))
+                print('Warning: could not find metadata file {} in zip archive {}'.format(
+                    metadata_file,os.path.basename(model_file)))
             return None
-
+
         try:
             path = zipfile.Path(zipf,metadata_file)
             contents = path.read_text()
@@ -516,9 +590,9 @@ def read_metadata_from_megadetector_model_file(model_file,
         except Exception as e:
             print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
             return None
-
+
     return d
-
+
 # ...def read_metadata_from_megadetector_model_file(...)
 
 
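
The corresponding read path, sketched against an in-memory toy archive; zipfile.Path and the single-root-folder convention are exactly what the function above relies on:

    import io, json, zipfile

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, 'w') as zipf:
        zipf.writestr('toy_model/megadetector_info.json',
                      json.dumps({'image_size': 1280}))

    with zipfile.ZipFile(buf, 'r') as zipf:
        # Find the one and only top-level folder, then read the JSON inside it
        root_folders = {n.split('/')[0] for n in zipf.namelist()}
        assert len(root_folders) == 1
        metadata_file = next(iter(root_folders)) + '/megadetector_info.json'
        d = json.loads(zipfile.Path(zipf, metadata_file).read_text())

    print(d['image_size'])  # 1280
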
@@ -526,27 +600,33 @@ def read_metadata_from_megadetector_model_file(model_file,
 
 default_compatibility_mode = 'classic'
 
-# This is a useful hack when I want to verify that my test driver (md_tests.py) is
+# This is a useful hack when I want to verify that my test driver (md_tests.py) is
 # correctly forcing a specific compatibility mode (I use "classic-test" in that case)
 require_non_default_compatibility_mode = False
 
 class PTDetector:
-
+    """
+    Class that runs a PyTorch-based MegaDetector model.
+    """
+
     def __init__(self, model_path, detector_options=None, verbose=False):
-
+
+        if verbose:
+            print('Initializing PTDetector (verbose)')
+
         # Set up the import environment for this model, unloading previous
         # YOLO library versions if necessary.
         _initialize_yolo_imports_for_model(model_path,
                                            detector_options=detector_options,
                                            verbose=verbose)
-
+
         # Parse options specific to this detector family
         force_cpu = False
-        use_model_native_classes = False
+        use_model_native_classes = False
         compatibility_mode = default_compatibility_mode
-
+
         if detector_options is not None:
-
+
             if 'force_cpu' in detector_options:
                 force_cpu = parse_bool_string(detector_options['force_cpu'])
             if 'use_model_native_classes' in detector_options:
@@ -554,66 +634,66 @@ class PTDetector:
             if 'compatibility_mode' in detector_options:
                 if detector_options['compatibility_mode'] is None:
                     compatibility_mode = default_compatibility_mode
-                else:
-                    compatibility_mode = detector_options['compatibility_mode']
-
+                else:
+                    compatibility_mode = detector_options['compatibility_mode']
+
         if require_non_default_compatibility_mode:
-
+
             print('### DEBUG: requiring non-default compatibility mode ###')
             assert compatibility_mode != 'classic'
             assert compatibility_mode != 'default'
-
+
         preprocess_only = False
         if (detector_options is not None) and \
            ('preprocess_only' in detector_options) and \
            (detector_options['preprocess_only']):
             preprocess_only = True
-
+
         if verbose or (not preprocess_only):
             print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))
-
+
         model_metadata = read_metadata_from_megadetector_model_file(model_path)
-
-        #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
+
+        #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
         #: aspect ratio".
         if model_metadata is not None and 'image_size' in model_metadata:
             self.default_image_size = model_metadata['image_size']
-            if verbose:
-                print('Loaded image size {} from model metadata'.format(self.default_image_size))
+            print('Loaded image size {} from model metadata'.format(self.default_image_size))
         else:
+            print('No image size available in model metadata, defaulting to 1280')
             self.default_image_size = 1280
-
+
         #: Either a string ('cpu','cuda:0') or a torch.device()
         self.device = 'cpu'
-
+
         #: Have we already printed a warning about using a non-standard image size?
         #:
         #: :meta private:
         self.printed_image_size_warning = False
-
+
         #: If this is False, we assume the underlying model is producing class indices in the
         #: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
-        #: MD classes (1,2,3) before generating output. If this is True, we use whatever
+        #: MD classes (1,2,3) before generating output. If this is True, we use whatever
         #: indices the model provides
-        self.use_model_native_classes = use_model_native_classes
-
+        self.use_model_native_classes = use_model_native_classes
+
         #: This allows us to maintain backwards compatibility across a set of changes to the
-        #: way this class does inference. Currently should start with either "default" or
+        #: way this class does inference. Currently should start with either "default" or
         #: "classic".
         self.compatibility_mode = compatibility_mode
-
+
         #: Stride size passed to YOLOv5's letterbox() function
         self.letterbox_stride = 32
-
+
         if 'classic' in self.compatibility_mode:
             self.letterbox_stride = 64
-
+
         #: Use half-precision inference... fixed by the model, generally don't mess with this
         self.half_precision = False
-
+
         if preprocess_only:
             return
-
+
         if not force_cpu:
             if torch.cuda.is_available():
                 self.device = torch.device('cuda:0')
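
For reference, detector_options is a plain dict. A hypothetical example covering the keys this constructor inspects (force_cpu, use_model_native_classes, compatibility_mode, plus preprocess_only and the model_type key read during import setup); the model filename is a placeholder:

    detector_options = {
        'force_cpu': 'false',               # parsed with parse_bool_string()
        'use_model_native_classes': 'false',
        'compatibility_mode': 'classic',
    }
    # detector = PTDetector('md_v5a.0.0.pt', detector_options=detector_options)
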
@@ -623,10 +703,10 @@ class PTDetector:
             except AttributeError:
                 pass
         try:
-            self.model = PTDetector._load_model(model_path,
-                                                device=self.device,
+            self.model = PTDetector._load_model(model_path,
+                                                device=self.device,
                                                 compatibility_mode=self.compatibility_mode)
-
+
         except Exception as e:
             # In a very esoteric scenario where an old version of YOLOv5 is used to run
             # newer models, we run into an issue because the "Model" class became
@@ -636,21 +716,21 @@ class PTDetector:
                 print('Forward-compatibility issue detected, patching')
                 from models import yolo
                 yolo.DetectionModel = yolo.Model
-                self.model = PTDetector._load_model(model_path,
+                self.model = PTDetector._load_model(model_path,
                                                     device=self.device,
                                                     compatibility_mode=self.compatibility_mode,
-                                                    verbose=verbose)
+                                                    verbose=verbose)
             else:
                 raise
         if (self.device != 'cpu'):
             if verbose:
                 print('Sending model to GPU')
             self.model.to(self.device)
-
+
 
     @staticmethod
     def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
-
+
         if verbose:
             print(f'Using PyTorch version {torch.__version__}')
 
@@ -661,12 +741,12 @@ class PTDetector:
         # very slight changes to the output, which always make me nervous, so I'm not
         # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
         if 'classic' in compatibility_mode:
-            use_map_location = (device != 'mps')
+            use_map_location = (device != 'mps')
         else:
             use_map_location = False
-
+
         if use_map_location:
-            try:
+            try:
                 checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
             # For a transitional period, we want to support torch 1.1x, where the weights_only
             # parameter doesn't exist
@@ -682,31 +762,31 @@ class PTDetector:
             # parameter doesn't exist
             except Exception as e:
                 if "'weights_only' is an invalid keyword" in str(e):
-                    checkpoint = torch.load(model_pt_path)
+                    checkpoint = torch.load(model_pt_path)
                 else:
                     raise
-
-        # Compatibility fix that allows us to load older YOLOv5 models with
+
+        # Compatibility fix that allows us to load older YOLOv5 models with
         # newer versions of YOLOv5/PT
         for m in checkpoint['model'].modules():
             t = type(m)
             if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
                 m.recompute_scale_factor = None
-
-        if use_map_location:
-            model = checkpoint['model'].float().fuse().eval()
+
+        if use_map_location:
+            model = checkpoint['model'].float().fuse().eval()
         else:
-            model = checkpoint['model'].float().fuse().eval().to(device)
-
+            model = checkpoint['model'].float().fuse().eval().to(device)
+
         return model
 
     # ...def _load_model(...)
-
 
-    def generate_detections_one_image(self,
-                                      img_original,
-                                      image_id='unknown',
-                                      detection_threshold=0.00001,
+
+    def generate_detections_one_image(self,
+                                      img_original,
+                                      image_id='unknown',
+                                      detection_threshold=0.00001,
                                       image_size=None,
                                       skip_image_resizing=False,
                                       augment=False,
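
The checkpoint-loading fallback in the two hunks above follows a simple pattern: pass weights_only=False on modern PyTorch, and retry without the keyword on older (1.1x) versions where it doesn't exist. A minimal sketch, assuming torch is installed and checkpoint_path points at a real file:

    import torch

    def load_checkpoint(checkpoint_path, device='cpu'):
        try:
            return torch.load(checkpoint_path, map_location=device, weights_only=False)
        except Exception as e:
            # Older torch raises TypeError-style errors for the unknown keyword
            if "'weights_only' is an invalid keyword" in str(e):
                return torch.load(checkpoint_path, map_location=device)
            raise
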
@@ -716,16 +796,16 @@ class PTDetector:
         Applies the detector to an image.
 
         Args:
-            img_original (Image): the PIL Image object (or numpy array) on which we should run the
+            img_original (Image): the PIL Image object (or numpy array) on which we should run the
                 detector, with EXIF rotation already handled
-            image_id (str, optional): a path to identify the image; will be in the "file" field
+            image_id (str, optional): a path to identify the image; will be in the "file" field
                 of the output object
-            detection_threshold (float, optional): only detections above this confidence threshold
+            detection_threshold (float, optional): only detections above this confidence threshold
                 will be included in the return value
-            image_size (tuple, optional): image size to use for inference, only mess with this if
+            image_size (int, optional): image size to use for inference, only mess with this if
                 (a) you're using a model other than MegaDetector or (b) you know what you're getting into
-            skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
-                external resizing), only mess with this if (a) you're using a model other than MegaDetector
+            skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
+                external resizing), only mess with this if (a) you're using a model other than MegaDetector
                 or (b) you know what you're getting into
             augment (bool, optional): enable (implementation-specific) image augmentation
             preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
@@ -747,17 +827,17 @@ class PTDetector:
             assert 'classic' in self.compatibility_mode, \
                 'Standalone preprocessing only supported in "classic" mode'
             assert not skip_image_resizing, \
-                'skip_image_resizing and preprocess_only are exclusive'
-
+                'skip_image_resizing and preprocess_only are exclusive'
+
         if detection_threshold is None:
-
+
             detection_threshold = 0
-
+
         try:
-
+
             # If the caller wants us to skip all the resizing operations...
             if skip_image_resizing:
-
+
                 if isinstance(img_original,dict):
                     image_info = img_original
                     img = image_info['img_processed']
@@ -768,107 +848,107 @@ class PTDetector:
                     img_original_pil = image_info['img_original_pil']
                 else:
                     img = img_original
-
+
             else:
-
+
                 img_original_pil = None
                 # If we were given a PIL image
-
+
                 if not isinstance(img_original,np.ndarray):
-                    img_original_pil = img_original
+                    img_original_pil = img_original
                     img_original = np.asarray(img_original)
-
+
                 # PIL images are RGB already
                 # img_original = img_original[:, :, ::-1]
-
+
                 # Save the original shape for scaling boxes later
                 scaling_shape = img_original.shape
-
+
                 # If the caller is requesting a specific target size...
                 if image_size is not None:
-
+
                     assert isinstance(image_size,int)
-
+
                     if not self.printed_image_size_warning:
                         print('Using user-supplied image size {}'.format(image_size))
-                        self.printed_image_size_warning = True
-
+                        self.printed_image_size_warning = True
+
                 # Otherwise resize to self.default_image_size
                 else:
-
+
                     image_size = self.default_image_size
                     self.printed_image_size_warning = False
-
+
                 # ...if the caller has specified an image size
-
+
                 # In "classic mode", we only do the letterboxing resize, we don't do an
                 # additional initial resizing operation
                 if 'classic' in self.compatibility_mode:
-
+
                     resize_ratio = 1.0
-
-                # Resize the image so the long side matches the target image size. This is not
+
+                # Resize the image so the long side matches the target image size. This is not
                 # letterboxing (i.e., padding) yet, just resizing.
                 else:
-
+
                     use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
-
+
                     h,w = img_original.shape[:2]
                     resize_ratio = image_size / max(h,w)
-
+
                     # Only resize if we have to
                     if resize_ratio != 1:
-
-                        # Match what yolov5 does: use linear interpolation for upsizing;
+
+                        # Match what yolov5 does: use linear interpolation for upsizing;
                         # area interpolation for downsizing
                         if resize_ratio > 1:
                             interpolation_method = cv2.INTER_LINEAR
                         else:
-                            interpolation_method = cv2.INTER_AREA
-
+                            interpolation_method = cv2.INTER_AREA
+
                         if use_ceil_for_resize:
                             target_w = math.ceil(w * resize_ratio)
                             target_h = math.ceil(h * resize_ratio)
                         else:
                             target_w = int(w * resize_ratio)
                             target_h = int(h * resize_ratio)
-
+
                         img_original = cv2.resize(
                             img_original, (target_w, target_h),
                             interpolation=interpolation_method)
 
                 if 'classic' in self.compatibility_mode:
-
+
                     letterbox_auto = True
                     letterbox_scaleup = True
                     target_shape = image_size
-
+
                 else:
-
+
                     letterbox_auto = False
                     letterbox_scaleup = False
-
+
                     # The padding to apply as a fraction of the stride size
                     pad = 0.5
-
+
                     model_stride = int(self.model.stride.max())
-
+
                     max_dimension = max(img_original.shape)
                     normalized_shape = [img_original.shape[0] / max_dimension,
                                         img_original.shape[1] / max_dimension]
                     target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
                                            pad).astype(int) * model_stride
-
+
                 # Now we letterbox, which is just padding, since we've already resized.
-                img,letterbox_ratio,letterbox_pad = letterbox(img_original,
+                img,letterbox_ratio,letterbox_pad = letterbox(img_original,
                                                               new_shape=target_shape,
-                                                              stride=self.letterbox_stride,
+                                                              stride=self.letterbox_stride,
                                                               auto=letterbox_auto,
                                                               scaleFill=False,
                                                               scaleup=letterbox_scaleup)
-
+
             if preprocess_only:
-
+
                 assert 'file' in result
                 result['img_processed'] = img
                 result['img_original'] = img_original
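
To make the non-classic target-shape arithmetic above concrete, here is a worked example with a 1536x2048 image, a 1280-pixel target, and a stride-32 model:

    import numpy as np

    image_shape = (1536, 2048)   # h, w of an example image
    image_size = 1280
    model_stride = 32
    pad = 0.5

    # Scale the normalized shape to the requested size, then round up to a
    # stride multiple (plus half-stride padding), exactly as in the code above
    max_dimension = max(image_shape)
    normalized_shape = [image_shape[0] / max_dimension, image_shape[1] / max_dimension]
    target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
        pad).astype(int) * model_stride
    print(target_shape)  # [ 992 1312]
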
@@ -878,14 +958,14 @@ class PTDetector:
                 result['letterbox_ratio'] = letterbox_ratio
                 result['letterbox_pad'] = letterbox_pad
                 return result
-
+
             # ...are we doing resizing here, or were images already resized?
-
+
             # Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
             # so we don't need to mess with the color channels.
             #
             # TODO, this could be moved into the preprocessing loop
-
+
             img = img.transpose((2, 0, 1)) # [::-1]
             img = np.ascontiguousarray(img)
             img = torch.from_numpy(img)
@@ -893,8 +973,8 @@ class PTDetector:
             img = img.half() if self.half_precision else img.float()
             img /= 255
 
-            # In practice this is always true
-            if len(img.shape) == 3:
+            # In practice this is always true
+            if len(img.shape) == 3:
                 img = torch.unsqueeze(img, 0)
 
             # Run the model
@@ -908,19 +988,19 @@ class PTDetector:
             else:
                 nms_conf_thres = detection_threshold # 0.01
             nms_iou_thres = 0.6
-            nms_agnostic = False
+            nms_agnostic = False
             nms_multi_label = True
-
+
             # As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
             #
-            # Send predictions back to the CPU for NMS.
+            # Send predictions back to the CPU for NMS.
             if self.device == 'mps':
                 pred_nms = pred.cpu()
             else:
                 pred_nms = pred
-
+
             # NMS
-            pred = non_max_suppression(prediction=pred_nms,
+            pred = non_max_suppression(prediction=pred_nms,
                                        conf_thres=nms_conf_thres,
                                        iou_thres=nms_iou_thres,
                                        agnostic=nms_agnostic,
@@ -930,26 +1010,26 @@ class PTDetector:
             gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
 
             if 'classic' in self.compatibility_mode:
-
+
                 ratio = None
                 ratio_pad = None
-
+
             else:
-
+
                 # letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
                 #
                 # ratio is a 2-tuple specifying the scaling that was applied to each dimension.
                 #
                 # The scale_boxes function expects a 2-tuple with these things combined.
                 ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
-                ratio_pad = (ratio, letterbox_pad)
-
+                ratio_pad = (ratio, letterbox_pad)
+
             # This is a loop over detection batches, which will always be length 1 in our case,
             # since we're not doing batch inference.
             #
             # det = pred[0]
             #
-            # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
+            # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
             # in descending order by confidence.
             #
             # Columns are:
@@ -959,15 +1039,15 @@ class PTDetector:
             # At this point, these are *non*-normalized values, referring to the size at which we
             # ran inference (img.shape).
             for det in pred:
-
+
                 if len(det) == 0:
                     continue
-
+
                 # Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
                 if 'classic' in self.compatibility_mode:
-
+
                     det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
-
+
                 else:
                     # After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
                     # original pixel dimension of the image, followed by the class and confidence
@@ -975,14 +1055,14 @@ class PTDetector:
 
                 # Loop over detections
                 for *xyxy, conf, cls in reversed(det):
-
+
                     if conf < detection_threshold:
                         continue
-
+
                     # Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
                     xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
 
-                    # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
+                    # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
                     # left/top/w/h (i.e., MD format)
                     api_box = ct_utils.convert_yolo_to_xywh(xywh)
 
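
The box conversion above is a one-liner in plain Python; this sketch restates what the comment describes (and what ct_utils.convert_yolo_to_xywh() is used for here): YOLO boxes are normalized [x_center, y_center, w, h], while the MegaDetector output format wants normalized [x_min, y_min, w, h].

    def yolo_to_md_box(xywh):
        # [x_center, y_center, w, h] -> [x_min, y_min, w, h], all normalized
        x_center, y_center, w, h = xywh
        return [x_center - w / 2.0, y_center - h / 2.0, w, h]

    assert yolo_to_md_box([0.5, 0.5, 0.2, 0.4]) == [0.4, 0.3, 0.2, 0.4]
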
@@ -991,11 +1071,11 @@ class PTDetector:
                         conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
                     else:
                         api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
-                        conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
-
+                        conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
+
                     if not self.use_model_native_classes:
-                        # The MegaDetector output format's categories start at 1, but all YOLO-based
-                        # MD models have category numbers starting at 0.
+                        # The MegaDetector output format's categories start at 1, but all YOLO-based
+                        # MD models have category numbers starting at 0.
                         cls = int(cls.tolist()) + 1
                         if cls not in (1, 2, 3):
                             raise KeyError(f'{cls} is not a valid class.')
@@ -1008,15 +1088,15 @@ class PTDetector:
                         'bbox': api_box
                     })
                     max_conf = max(max_conf, conf)
-
+
                 # ...for each detection in this batch
-
+
             # ...for each detection batch (always one iteration)
 
         # ...try
-
+
         except Exception as e:
-
+
             result['failure'] = FAILURE_INFER
             print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
             # traceback.print_exc(e)
@@ -1039,12 +1119,12 @@ class PTDetector:
 if __name__ == '__main__':
 
     pass
-
+
     #%%
-
+
     import os #noqa
     from megadetector.visualization import visualization_utils as vis_utils
-
+
     model_file = os.environ['MDV5A']
     im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')
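
The diff is truncated partway through this interactive test cell. For context, a hedged sketch of how a cell like this exercises the detector end to end, assuming the package's load_detector() entry point and an MDV5A environment variable pointing at a MegaDetector v5a .pt file:

    import os
    from megadetector.visualization import visualization_utils as vis_utils
    from megadetector.detection import run_detector

    model_file = os.environ['MDV5A']
    im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')

    detector = run_detector.load_detector(model_file)
    image = vis_utils.load_image(im_file)
    result = detector.generate_detections_one_image(image, im_file)
    print(result['detections'])
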