megadetector 5.0.11__py3-none-any.whl → 5.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (201):
  1. megadetector/api/__init__.py +0 -0
  2. megadetector/api/batch_processing/__init__.py +0 -0
  3. megadetector/api/batch_processing/api_core/__init__.py +0 -0
  4. megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
  5. megadetector/api/batch_processing/api_core/batch_service/score.py +439 -0
  6. megadetector/api/batch_processing/api_core/server.py +294 -0
  7. megadetector/api/batch_processing/api_core/server_api_config.py +98 -0
  8. megadetector/api/batch_processing/api_core/server_app_config.py +55 -0
  9. megadetector/api/batch_processing/api_core/server_batch_job_manager.py +220 -0
  10. megadetector/api/batch_processing/api_core/server_job_status_table.py +152 -0
  11. megadetector/api/batch_processing/api_core/server_orchestration.py +360 -0
  12. megadetector/api/batch_processing/api_core/server_utils.py +92 -0
  13. megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
  14. megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +46 -0
  15. megadetector/api/batch_processing/api_support/__init__.py +0 -0
  16. megadetector/api/batch_processing/api_support/summarize_daily_activity.py +152 -0
  17. megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
  18. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  19. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  20. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  21. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +126 -0
  22. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  23. megadetector/api/synchronous/__init__.py +0 -0
  24. megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  25. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +152 -0
  26. megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +266 -0
  27. megadetector/api/synchronous/api_core/animal_detection_api/config.py +35 -0
  28. megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
  29. megadetector/api/synchronous/api_core/tests/load_test.py +110 -0
  30. megadetector/classification/__init__.py +0 -0
  31. megadetector/classification/aggregate_classifier_probs.py +108 -0
  32. megadetector/classification/analyze_failed_images.py +227 -0
  33. megadetector/classification/cache_batchapi_outputs.py +198 -0
  34. megadetector/classification/create_classification_dataset.py +627 -0
  35. megadetector/classification/crop_detections.py +516 -0
  36. megadetector/classification/csv_to_json.py +226 -0
  37. megadetector/classification/detect_and_crop.py +855 -0
  38. megadetector/classification/efficientnet/__init__.py +9 -0
  39. megadetector/classification/efficientnet/model.py +415 -0
  40. megadetector/classification/efficientnet/utils.py +610 -0
  41. megadetector/classification/evaluate_model.py +520 -0
  42. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  43. megadetector/classification/json_to_azcopy_list.py +63 -0
  44. megadetector/classification/json_validator.py +699 -0
  45. megadetector/classification/map_classification_categories.py +276 -0
  46. megadetector/classification/merge_classification_detection_output.py +506 -0
  47. megadetector/classification/prepare_classification_script.py +194 -0
  48. megadetector/classification/prepare_classification_script_mc.py +228 -0
  49. megadetector/classification/run_classifier.py +287 -0
  50. megadetector/classification/save_mislabeled.py +110 -0
  51. megadetector/classification/train_classifier.py +827 -0
  52. megadetector/classification/train_classifier_tf.py +725 -0
  53. megadetector/classification/train_utils.py +323 -0
  54. megadetector/data_management/__init__.py +0 -0
  55. megadetector/data_management/annotations/__init__.py +0 -0
  56. megadetector/data_management/annotations/annotation_constants.py +34 -0
  57. megadetector/data_management/camtrap_dp_to_coco.py +239 -0
  58. megadetector/data_management/cct_json_utils.py +395 -0
  59. megadetector/data_management/cct_to_md.py +176 -0
  60. megadetector/data_management/cct_to_wi.py +289 -0
  61. megadetector/data_management/coco_to_labelme.py +272 -0
  62. megadetector/data_management/coco_to_yolo.py +662 -0
  63. megadetector/data_management/databases/__init__.py +0 -0
  64. megadetector/data_management/databases/add_width_and_height_to_db.py +33 -0
  65. megadetector/data_management/databases/combine_coco_camera_traps_files.py +206 -0
  66. megadetector/data_management/databases/integrity_check_json_db.py +477 -0
  67. megadetector/data_management/databases/subset_json_db.py +115 -0
  68. megadetector/data_management/generate_crops_from_cct.py +149 -0
  69. megadetector/data_management/get_image_sizes.py +189 -0
  70. megadetector/data_management/importers/add_nacti_sizes.py +52 -0
  71. megadetector/data_management/importers/add_timestamps_to_icct.py +79 -0
  72. megadetector/data_management/importers/animl_results_to_md_results.py +158 -0
  73. megadetector/data_management/importers/auckland_doc_test_to_json.py +373 -0
  74. megadetector/data_management/importers/auckland_doc_to_json.py +201 -0
  75. megadetector/data_management/importers/awc_to_json.py +191 -0
  76. megadetector/data_management/importers/bellevue_to_json.py +273 -0
  77. megadetector/data_management/importers/cacophony-thermal-importer.py +796 -0
  78. megadetector/data_management/importers/carrizo_shrubfree_2018.py +269 -0
  79. megadetector/data_management/importers/carrizo_trail_cam_2017.py +289 -0
  80. megadetector/data_management/importers/cct_field_adjustments.py +58 -0
  81. megadetector/data_management/importers/channel_islands_to_cct.py +913 -0
  82. megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +180 -0
  83. megadetector/data_management/importers/eMammal/eMammal_helpers.py +249 -0
  84. megadetector/data_management/importers/eMammal/make_eMammal_json.py +223 -0
  85. megadetector/data_management/importers/ena24_to_json.py +276 -0
  86. megadetector/data_management/importers/filenames_to_json.py +386 -0
  87. megadetector/data_management/importers/helena_to_cct.py +283 -0
  88. megadetector/data_management/importers/idaho-camera-traps.py +1407 -0
  89. megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +294 -0
  90. megadetector/data_management/importers/jb_csv_to_json.py +150 -0
  91. megadetector/data_management/importers/mcgill_to_json.py +250 -0
  92. megadetector/data_management/importers/missouri_to_json.py +490 -0
  93. megadetector/data_management/importers/nacti_fieldname_adjustments.py +79 -0
  94. megadetector/data_management/importers/noaa_seals_2019.py +181 -0
  95. megadetector/data_management/importers/pc_to_json.py +365 -0
  96. megadetector/data_management/importers/plot_wni_giraffes.py +123 -0
  97. megadetector/data_management/importers/prepare-noaa-fish-data-for-lila.py +359 -0
  98. megadetector/data_management/importers/prepare_zsl_imerit.py +131 -0
  99. megadetector/data_management/importers/rspb_to_json.py +356 -0
  100. megadetector/data_management/importers/save_the_elephants_survey_A.py +320 -0
  101. megadetector/data_management/importers/save_the_elephants_survey_B.py +329 -0
  102. megadetector/data_management/importers/snapshot_safari_importer.py +758 -0
  103. megadetector/data_management/importers/snapshot_safari_importer_reprise.py +665 -0
  104. megadetector/data_management/importers/snapshot_serengeti_lila.py +1067 -0
  105. megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +150 -0
  106. megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +153 -0
  107. megadetector/data_management/importers/sulross_get_exif.py +65 -0
  108. megadetector/data_management/importers/timelapse_csv_set_to_json.py +490 -0
  109. megadetector/data_management/importers/ubc_to_json.py +399 -0
  110. megadetector/data_management/importers/umn_to_json.py +507 -0
  111. megadetector/data_management/importers/wellington_to_json.py +263 -0
  112. megadetector/data_management/importers/wi_to_json.py +442 -0
  113. megadetector/data_management/importers/zamba_results_to_md_results.py +181 -0
  114. megadetector/data_management/labelme_to_coco.py +547 -0
  115. megadetector/data_management/labelme_to_yolo.py +272 -0
  116. megadetector/data_management/lila/__init__.py +0 -0
  117. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +97 -0
  118. megadetector/data_management/lila/add_locations_to_nacti.py +147 -0
  119. megadetector/data_management/lila/create_lila_blank_set.py +558 -0
  120. megadetector/data_management/lila/create_lila_test_set.py +152 -0
  121. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  122. megadetector/data_management/lila/download_lila_subset.py +178 -0
  123. megadetector/data_management/lila/generate_lila_per_image_labels.py +516 -0
  124. megadetector/data_management/lila/get_lila_annotation_counts.py +170 -0
  125. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  126. megadetector/data_management/lila/lila_common.py +300 -0
  127. megadetector/data_management/lila/test_lila_metadata_urls.py +132 -0
  128. megadetector/data_management/ocr_tools.py +874 -0
  129. megadetector/data_management/read_exif.py +681 -0
  130. megadetector/data_management/remap_coco_categories.py +84 -0
  131. megadetector/data_management/remove_exif.py +66 -0
  132. megadetector/data_management/resize_coco_dataset.py +189 -0
  133. megadetector/data_management/wi_download_csv_to_coco.py +246 -0
  134. megadetector/data_management/yolo_output_to_md_output.py +441 -0
  135. megadetector/data_management/yolo_to_coco.py +676 -0
  136. megadetector/detection/__init__.py +0 -0
  137. megadetector/detection/detector_training/__init__.py +0 -0
  138. megadetector/detection/detector_training/model_main_tf2.py +114 -0
  139. megadetector/detection/process_video.py +702 -0
  140. megadetector/detection/pytorch_detector.py +341 -0
  141. megadetector/detection/run_detector.py +779 -0
  142. megadetector/detection/run_detector_batch.py +1219 -0
  143. megadetector/detection/run_inference_with_yolov5_val.py +917 -0
  144. megadetector/detection/run_tiled_inference.py +934 -0
  145. megadetector/detection/tf_detector.py +189 -0
  146. megadetector/detection/video_utils.py +606 -0
  147. megadetector/postprocessing/__init__.py +0 -0
  148. megadetector/postprocessing/add_max_conf.py +64 -0
  149. megadetector/postprocessing/categorize_detections_by_size.py +163 -0
  150. megadetector/postprocessing/combine_api_outputs.py +249 -0
  151. megadetector/postprocessing/compare_batch_results.py +958 -0
  152. megadetector/postprocessing/convert_output_format.py +396 -0
  153. megadetector/postprocessing/load_api_results.py +195 -0
  154. megadetector/postprocessing/md_to_coco.py +310 -0
  155. megadetector/postprocessing/md_to_labelme.py +330 -0
  156. megadetector/postprocessing/merge_detections.py +401 -0
  157. megadetector/postprocessing/postprocess_batch_results.py +1902 -0
  158. megadetector/postprocessing/remap_detection_categories.py +170 -0
  159. megadetector/postprocessing/render_detection_confusion_matrix.py +660 -0
  160. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +211 -0
  161. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +83 -0
  162. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1631 -0
  163. megadetector/postprocessing/separate_detections_into_folders.py +730 -0
  164. megadetector/postprocessing/subset_json_detector_output.py +696 -0
  165. megadetector/postprocessing/top_folders_to_bottom.py +223 -0
  166. megadetector/taxonomy_mapping/__init__.py +0 -0
  167. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  168. megadetector/taxonomy_mapping/map_new_lila_datasets.py +150 -0
  169. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +142 -0
  170. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +590 -0
  171. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  172. megadetector/taxonomy_mapping/simple_image_download.py +219 -0
  173. megadetector/taxonomy_mapping/species_lookup.py +834 -0
  174. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  175. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  176. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  177. megadetector/utils/__init__.py +0 -0
  178. megadetector/utils/azure_utils.py +178 -0
  179. megadetector/utils/ct_utils.py +612 -0
  180. megadetector/utils/directory_listing.py +246 -0
  181. megadetector/utils/md_tests.py +968 -0
  182. megadetector/utils/path_utils.py +1044 -0
  183. megadetector/utils/process_utils.py +157 -0
  184. megadetector/utils/sas_blob_utils.py +509 -0
  185. megadetector/utils/split_locations_into_train_val.py +228 -0
  186. megadetector/utils/string_utils.py +92 -0
  187. megadetector/utils/url_utils.py +323 -0
  188. megadetector/utils/write_html_image_list.py +225 -0
  189. megadetector/visualization/__init__.py +0 -0
  190. megadetector/visualization/plot_utils.py +293 -0
  191. megadetector/visualization/render_images_with_thumbnails.py +275 -0
  192. megadetector/visualization/visualization_utils.py +1536 -0
  193. megadetector/visualization/visualize_db.py +550 -0
  194. megadetector/visualization/visualize_detector_output.py +405 -0
  195. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/METADATA +1 -1
  196. megadetector-5.0.12.dist-info/RECORD +199 -0
  197. megadetector-5.0.12.dist-info/top_level.txt +1 -0
  198. megadetector-5.0.11.dist-info/RECORD +0 -5
  199. megadetector-5.0.11.dist-info/top_level.txt +0 -1
  200. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/LICENSE +0 -0
  201. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/WHEEL +0 -0
@@ -0,0 +1,341 @@
1
+ """
2
+
3
+ pytorch_detector.py
4
+
5
+ Module to run MegaDetector v5, a PyTorch YOLOv5 animal detection model.
6
+
7
+ """
8
+
9
+ #%% Imports and constants
10
+
11
+ import sys
12
+ import torch
13
+ import numpy as np
14
+ import traceback
15
+ import builtins
16
+
17
+ from megadetector.detection.run_detector import CONF_DIGITS, COORD_DIGITS, FAILURE_INFER
18
+ from megadetector.utils import ct_utils
19
+
20
# We support a few ways of accessing the YOLOv5 dependencies:
#
# * The standard configuration as of 2023.09 expects that the YOLOv5 repo is checked
#   out and on the PYTHONPATH (import utils)
#
# * Supported but non-default (used for PyPI packaging):
#
#   pip install ultralytics-yolov5
#
# * Works, but not supported:
#
#   pip install yolov5
#
# * Unfinished:
#
#   pip install ultralytics
#
# If try_ultralytics_import is True, we'll try to import all YOLOv5 dependencies from
# ultralytics.utils and ultralytics.data. But as of 2023.11, this results in a "No
# module named 'models'" error when running MDv5, and there's no upside to this approach
# compared to using either of the YOLOv5 PyPI packages, so... punting on this for now.

utils_imported = False
try_yolov5_import = True

# See above; this should remain as "False" unless we update the MegaDetector .pt file
# to use more recent YOLOv5 namespace conventions.
try_ultralytics_import = False

# First try importing from the yolov5 package; this is how the pip
# package finds YOLOv5 utilities.
if try_yolov5_import and not utils_imported:

    try:
        from yolov5.utils.general import non_max_suppression, xyxy2xywh # noqa
        from yolov5.utils.augmentations import letterbox # noqa
        from yolov5.utils.general import scale_boxes as scale_coords # noqa
        utils_imported = True
        print('Imported YOLOv5 from YOLOv5 package')
    except Exception:
        # Deliberately best-effort: any failure here just means we fall through
        # to the next import strategy below.
        pass

# If we haven't succeeded yet, import from the ultralytics package
if try_ultralytics_import and not utils_imported:

    try:
        from ultralytics.utils.ops import non_max_suppression # noqa
        from ultralytics.utils.ops import xyxy2xywh # noqa
        from ultralytics.utils.ops import scale_coords # noqa
        from ultralytics.data.augment import LetterBox

        # letterbox() became a LetterBox class in the ultralytics package
        def letterbox(img,new_shape,stride,auto=True): # noqa
            """
            Adapter giving the ultralytics LetterBox class the list-returning call
            shape of YOLOv5's letterbox(); only element [0] (the image) is populated.
            """
            L = LetterBox(new_shape,stride=stride,auto=auto)
            letterbox_result = L(image=img)
            return [letterbox_result]
        utils_imported = True
        print('Imported YOLOv5 from ultralytics package')
    except Exception:
        # Deliberately best-effort: fall through to the path-based import below.
        pass

# If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
if not utils_imported:

    try:
        # import pre- and post-processing functions from the YOLOv5 repo
        from utils.general import non_max_suppression, xyxy2xywh # noqa
        from utils.augmentations import letterbox # noqa

        # scale_coords() became scale_boxes() in later YOLOv5 versions
        try:
            from utils.general import scale_coords # noqa
        except ImportError:
            from utils.general import scale_boxes as scale_coords
        utils_imported = True
        print('Imported YOLOv5 as utils.*')

    except ImportError as e:
        # Catch ImportError (the base class of ModuleNotFoundError) so that a
        # "utils" package that exists but is not the YOLOv5 repo (and therefore
        # lacks these names) also produces this message. ModuleNotFoundError is
        # still what gets raised, so existing callers' handlers keep working.
        raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e))) from e

assert utils_imported, 'YOLOv5 import error'

print(f'Using PyTorch version {torch.__version__}')
105
+
106
+
107
+ #%% Classes
108
+
109
class PTDetector:
    """
    Runs a YOLOv5-format PyTorch checkpoint (e.g. MegaDetector v5) on one image at a
    time and returns detections in MegaDetector output format.
    """

    #: Image size passed to YOLOv5's letterbox() function; 1280 means "1280 on the long side, preserving
    #: aspect ratio"
    #:
    #: :meta private:
    IMAGE_SIZE = 1280

    #: Stride size passed to YOLOv5's letterbox() function
    #:
    #: :meta private:
    STRIDE = 64

    def __init__(self, model_path, force_cpu=False, use_model_native_classes=False):
        """
        Loads the model and moves it to the selected device.

        Device selection: CUDA ('cuda:0') if available, then MPS if PyTorch was built
        with MPS support and it is available; otherwise the CPU.

        Args:
            model_path (str): path to a .pt checkpoint containing a 'model' entry
            force_cpu (bool, optional): if True, always run on the CPU
            use_model_native_classes (bool, optional): if True, report the model's raw
                class indices; otherwise add 1 to map to MegaDetector categories 1-3
        """

        self.device = 'cpu'
        if not force_cpu:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
            try:
                # is_built must be *called*; the bare function object is always truthy
                if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                    self.device = 'mps'
            except AttributeError:
                # PyTorch versions without torch.backends.mps
                pass
        self.model = PTDetector._load_model(model_path, self.device)
        if self.device != 'cpu':
            print('Sending model to GPU')
            self.model.to(self.device)

        self.printed_image_size_warning = False
        self.use_model_native_classes = use_model_native_classes


    @staticmethod
    def _load_model(model_pt_path, device):
        """
        Loads a YOLOv5 checkpoint and returns its model as float32, fused, in eval mode.

        Args:
            model_pt_path (str): path to the .pt checkpoint
            device (str or torch.device): device the model should end up on

        Returns:
            torch.nn.Module: the loaded model
        """

        # There are two very slightly different ways to load the model, (1) using the
        # map_location=device parameter to torch.load and (2) calling .to(device) after
        # loading the model. The former is what we did for a zillion years, but was not
        # supported on Apple silicon at the time this was written. Switching to the latter
        # causes very slight changes to the output, which always make me nervous, so I'm not
        # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
        use_map_location = (device != 'mps')

        # NOTE(security): torch.load unpickles arbitrary Python objects; only load
        # checkpoints from trusted sources.
        if use_map_location:
            checkpoint = torch.load(model_pt_path, map_location=device)
        else:
            checkpoint = torch.load(model_pt_path)

        # Compatibility fix that allows us to load older YOLOv5 models with
        # newer versions of YOLOv5/PT
        for m in checkpoint['model'].modules():
            if type(m) is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
                m.recompute_scale_factor = None

        model = checkpoint['model'].float().fuse().eval()
        if not use_map_location:
            model = model.to(device)

        return model

    def generate_detections_one_image(self, img_original, image_id='unknown',
                                      detection_threshold=0.00001, image_size=None,
                                      skip_image_resizing=False):
        """
        Applies the detector to an image.

        Args:
            img_original (Image): the PIL Image object with EXIF rotation taken into account
            image_id (str, optional): a path to identify the image; will be in the "file" field
                of the output object
            detection_threshold (float, optional): only detections above this confidence threshold
                will be included in the return value
            image_size (int or tuple, optional): image size to use for inference, only mess with
                this if (a) you're using a model other than MegaDetector or (b) you know what
                you're doing
            skip_image_resizing (bool, optional): whether to skip internal image resizing (and
                rely on external resizing)

        Returns:
            dict: a dictionary with the following fields:
                - 'file' (filename, always present)
                - 'max_detection_conf' (removed from MegaDetector output files by default, but generated here)
                - 'detections' (a list of detection objects containing keys 'category', 'conf', and 'bbox')
                - 'failure' (a failure string; only present if inference failed)

        Any exception raised during inference is caught; it is reported by setting
        'failure' to FAILURE_INFER rather than propagated to the caller.
        """

        result = {
            'file': image_id
        }
        detections = []
        max_conf = 0.0

        try:

            img_original = np.asarray(img_original)

            # padded resize
            target_size = PTDetector.IMAGE_SIZE

            # Image size can be an int (which translates to a square target size) or (h,w)
            if image_size is not None:

                # Explicit check rather than assert, so it survives python -O; the
                # resulting exception is reported as an inference failure below.
                if not (isinstance(image_size, int) or len(image_size) == 2):
                    raise ValueError('Illegal image size: {}'.format(image_size))

                if not self.printed_image_size_warning:
                    print('Warning: using user-supplied image size {}'.format(image_size))
                    self.printed_image_size_warning = True

                target_size = image_size

            else:

                self.printed_image_size_warning = False

            # ...if the caller has specified an image size

            if skip_image_resizing:
                img = img_original
            else:
                letterbox_result = letterbox(img_original, new_shape=target_size,
                                             stride=PTDetector.STRIDE, auto=True)
                img = letterbox_result[0]

            # HWC to CHW; PIL Image is RGB already
            img = img.transpose((2, 0, 1))
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img)
            img = img.to(self.device)
            img = img.float()
            img /= 255

            # In practice this is always true
            if len(img.shape) == 3:
                img = torch.unsqueeze(img, 0)

            # Inference only; don't build an autograd graph for the forward pass
            with torch.no_grad():
                pred = self.model(img)[0]

            # NMS
            if self.device == 'mps':
                # As of v1.13.0.dev20220824, nms is not implemented for MPS.
                #
                # Send prediction back to the CPU to fix.
                pred = non_max_suppression(prediction=pred.cpu(), conf_thres=detection_threshold)
            else:
                pred = non_max_suppression(prediction=pred, conf_thres=detection_threshold)

            # format detections/bounding boxes
            #
            # normalization gain whwh
            gn = torch.tensor(img_original.shape)[[1, 0, 1, 0]]

            # This is a loop over detection batches, which will always be length 1 in our case,
            # since we're not doing batch inference.
            for det in pred:

                if len(det):

                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()

                    for *xyxy, conf, cls in reversed(det):

                        # normalized center-x, center-y, width and height
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()

                        api_box = ct_utils.convert_yolo_to_xywh(xywh)

                        conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)

                        if not self.use_model_native_classes:
                            # MegaDetector output format's categories start at 1, but the MD
                            # model's categories start at 0.
                            cls = int(cls.tolist()) + 1
                            if cls not in (1, 2, 3):
                                raise KeyError(f'{cls} is not a valid class.')
                        else:
                            cls = int(cls.tolist())

                        detections.append({
                            'category': str(cls),
                            'conf': conf,
                            'bbox': ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
                        })
                        max_conf = max(max_conf, conf)

                    # ...for each detection in this batch

                # ...if this is a non-empty batch

            # ...for each detection batch

        # ...try

        except Exception as e:

            result['failure'] = FAILURE_INFER
            print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
            # print_exc()'s first positional parameter is 'limit', not an exception;
            # it reads the active exception itself, so call it with no arguments.
            traceback.print_exc()

        result['max_detection_conf'] = max_conf
        result['detections'] = detections

        return result

    # ...def generate_detections_one_image(...)

# ...class PTDetector
319
+
320
+
321
+ #%% Command-line driver
322
+
323
+ # For testing only... you don't really want to run this module directly.
324
+
325
if __name__ == '__main__':

    pass

    #%% Smoke test: run the detector on one sample image and dump the result

    import os
    from megadetector.visualization import visualization_utils as vis_utils

    # The MDv5a checkpoint location comes from the MDV5A environment variable
    sample_image_path = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')
    pt_detector = PTDetector(os.environ['MDV5A'])

    loaded_image = vis_utils.load_image(sample_image_path)
    detection_output = pt_detector.generate_detections_one_image(
        loaded_image, sample_image_path, detection_threshold=0.00001)
    print(detection_output)