megadetector-5.0.28-py3-none-any.whl → megadetector-10.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic; see the package's registry page for more details.

Files changed (197):
  1. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +2 -2
  2. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +1 -1
  3. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +1 -1
  4. megadetector/classification/aggregate_classifier_probs.py +3 -3
  5. megadetector/classification/analyze_failed_images.py +5 -5
  6. megadetector/classification/cache_batchapi_outputs.py +5 -5
  7. megadetector/classification/create_classification_dataset.py +11 -12
  8. megadetector/classification/crop_detections.py +10 -10
  9. megadetector/classification/csv_to_json.py +8 -8
  10. megadetector/classification/detect_and_crop.py +13 -15
  11. megadetector/classification/efficientnet/model.py +8 -8
  12. megadetector/classification/efficientnet/utils.py +6 -5
  13. megadetector/classification/evaluate_model.py +7 -7
  14. megadetector/classification/identify_mislabeled_candidates.py +6 -6
  15. megadetector/classification/json_to_azcopy_list.py +1 -1
  16. megadetector/classification/json_validator.py +29 -32
  17. megadetector/classification/map_classification_categories.py +9 -9
  18. megadetector/classification/merge_classification_detection_output.py +12 -9
  19. megadetector/classification/prepare_classification_script.py +19 -19
  20. megadetector/classification/prepare_classification_script_mc.py +26 -26
  21. megadetector/classification/run_classifier.py +4 -4
  22. megadetector/classification/save_mislabeled.py +6 -6
  23. megadetector/classification/train_classifier.py +1 -1
  24. megadetector/classification/train_classifier_tf.py +9 -9
  25. megadetector/classification/train_utils.py +10 -10
  26. megadetector/data_management/annotations/annotation_constants.py +1 -2
  27. megadetector/data_management/camtrap_dp_to_coco.py +79 -46
  28. megadetector/data_management/cct_json_utils.py +103 -103
  29. megadetector/data_management/cct_to_md.py +49 -49
  30. megadetector/data_management/cct_to_wi.py +33 -33
  31. megadetector/data_management/coco_to_labelme.py +75 -75
  32. megadetector/data_management/coco_to_yolo.py +210 -193
  33. megadetector/data_management/databases/add_width_and_height_to_db.py +86 -12
  34. megadetector/data_management/databases/combine_coco_camera_traps_files.py +40 -40
  35. megadetector/data_management/databases/integrity_check_json_db.py +228 -200
  36. megadetector/data_management/databases/subset_json_db.py +33 -33
  37. megadetector/data_management/generate_crops_from_cct.py +88 -39
  38. megadetector/data_management/get_image_sizes.py +54 -49
  39. megadetector/data_management/labelme_to_coco.py +133 -125
  40. megadetector/data_management/labelme_to_yolo.py +159 -73
  41. megadetector/data_management/lila/create_lila_blank_set.py +81 -83
  42. megadetector/data_management/lila/create_lila_test_set.py +32 -31
  43. megadetector/data_management/lila/create_links_to_md_results_files.py +18 -18
  44. megadetector/data_management/lila/download_lila_subset.py +21 -24
  45. megadetector/data_management/lila/generate_lila_per_image_labels.py +365 -107
  46. megadetector/data_management/lila/get_lila_annotation_counts.py +35 -33
  47. megadetector/data_management/lila/get_lila_image_counts.py +22 -22
  48. megadetector/data_management/lila/lila_common.py +73 -70
  49. megadetector/data_management/lila/test_lila_metadata_urls.py +28 -19
  50. megadetector/data_management/mewc_to_md.py +344 -340
  51. megadetector/data_management/ocr_tools.py +262 -255
  52. megadetector/data_management/read_exif.py +249 -227
  53. megadetector/data_management/remap_coco_categories.py +90 -28
  54. megadetector/data_management/remove_exif.py +81 -21
  55. megadetector/data_management/rename_images.py +187 -187
  56. megadetector/data_management/resize_coco_dataset.py +588 -120
  57. megadetector/data_management/speciesnet_to_md.py +41 -41
  58. megadetector/data_management/wi_download_csv_to_coco.py +55 -55
  59. megadetector/data_management/yolo_output_to_md_output.py +248 -122
  60. megadetector/data_management/yolo_to_coco.py +333 -191
  61. megadetector/detection/change_detection.py +832 -0
  62. megadetector/detection/process_video.py +340 -337
  63. megadetector/detection/pytorch_detector.py +358 -278
  64. megadetector/detection/run_detector.py +399 -186
  65. megadetector/detection/run_detector_batch.py +404 -377
  66. megadetector/detection/run_inference_with_yolov5_val.py +340 -327
  67. megadetector/detection/run_tiled_inference.py +257 -249
  68. megadetector/detection/tf_detector.py +24 -24
  69. megadetector/detection/video_utils.py +332 -295
  70. megadetector/postprocessing/add_max_conf.py +19 -11
  71. megadetector/postprocessing/categorize_detections_by_size.py +45 -45
  72. megadetector/postprocessing/classification_postprocessing.py +468 -433
  73. megadetector/postprocessing/combine_batch_outputs.py +23 -23
  74. megadetector/postprocessing/compare_batch_results.py +590 -525
  75. megadetector/postprocessing/convert_output_format.py +106 -102
  76. megadetector/postprocessing/create_crop_folder.py +347 -147
  77. megadetector/postprocessing/detector_calibration.py +173 -168
  78. megadetector/postprocessing/generate_csv_report.py +508 -499
  79. megadetector/postprocessing/load_api_results.py +48 -27
  80. megadetector/postprocessing/md_to_coco.py +133 -102
  81. megadetector/postprocessing/md_to_labelme.py +107 -90
  82. megadetector/postprocessing/md_to_wi.py +40 -40
  83. megadetector/postprocessing/merge_detections.py +92 -114
  84. megadetector/postprocessing/postprocess_batch_results.py +319 -301
  85. megadetector/postprocessing/remap_detection_categories.py +91 -38
  86. megadetector/postprocessing/render_detection_confusion_matrix.py +214 -205
  87. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +57 -57
  88. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +27 -28
  89. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +704 -679
  90. megadetector/postprocessing/separate_detections_into_folders.py +226 -211
  91. megadetector/postprocessing/subset_json_detector_output.py +265 -262
  92. megadetector/postprocessing/top_folders_to_bottom.py +45 -45
  93. megadetector/postprocessing/validate_batch_results.py +70 -70
  94. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +52 -52
  95. megadetector/taxonomy_mapping/map_new_lila_datasets.py +18 -19
  96. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +54 -33
  97. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +67 -67
  98. megadetector/taxonomy_mapping/retrieve_sample_image.py +16 -16
  99. megadetector/taxonomy_mapping/simple_image_download.py +8 -8
  100. megadetector/taxonomy_mapping/species_lookup.py +156 -74
  101. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +14 -14
  102. megadetector/taxonomy_mapping/taxonomy_graph.py +10 -10
  103. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +13 -13
  104. megadetector/utils/ct_utils.py +1049 -211
  105. megadetector/utils/directory_listing.py +21 -77
  106. megadetector/utils/gpu_test.py +22 -22
  107. megadetector/utils/md_tests.py +632 -529
  108. megadetector/utils/path_utils.py +1520 -431
  109. megadetector/utils/process_utils.py +41 -41
  110. megadetector/utils/split_locations_into_train_val.py +62 -62
  111. megadetector/utils/string_utils.py +148 -27
  112. megadetector/utils/url_utils.py +489 -176
  113. megadetector/utils/wi_utils.py +2658 -2526
  114. megadetector/utils/write_html_image_list.py +137 -137
  115. megadetector/visualization/plot_utils.py +34 -30
  116. megadetector/visualization/render_images_with_thumbnails.py +39 -74
  117. megadetector/visualization/visualization_utils.py +487 -435
  118. megadetector/visualization/visualize_db.py +232 -198
  119. megadetector/visualization/visualize_detector_output.py +82 -76
  120. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/METADATA +5 -2
  121. megadetector-10.0.0.dist-info/RECORD +139 -0
  122. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/WHEEL +1 -1
  123. megadetector/api/batch_processing/api_core/__init__.py +0 -0
  124. megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
  125. megadetector/api/batch_processing/api_core/batch_service/score.py +0 -439
  126. megadetector/api/batch_processing/api_core/server.py +0 -294
  127. megadetector/api/batch_processing/api_core/server_api_config.py +0 -97
  128. megadetector/api/batch_processing/api_core/server_app_config.py +0 -55
  129. megadetector/api/batch_processing/api_core/server_batch_job_manager.py +0 -220
  130. megadetector/api/batch_processing/api_core/server_job_status_table.py +0 -149
  131. megadetector/api/batch_processing/api_core/server_orchestration.py +0 -360
  132. megadetector/api/batch_processing/api_core/server_utils.py +0 -88
  133. megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
  134. megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
  135. megadetector/api/batch_processing/api_support/__init__.py +0 -0
  136. megadetector/api/batch_processing/api_support/summarize_daily_activity.py +0 -152
  137. megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
  138. megadetector/api/synchronous/__init__.py +0 -0
  139. megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  140. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +0 -151
  141. megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -263
  142. megadetector/api/synchronous/api_core/animal_detection_api/config.py +0 -35
  143. megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
  144. megadetector/api/synchronous/api_core/tests/load_test.py +0 -110
  145. megadetector/data_management/importers/add_nacti_sizes.py +0 -52
  146. megadetector/data_management/importers/add_timestamps_to_icct.py +0 -79
  147. megadetector/data_management/importers/animl_results_to_md_results.py +0 -158
  148. megadetector/data_management/importers/auckland_doc_test_to_json.py +0 -373
  149. megadetector/data_management/importers/auckland_doc_to_json.py +0 -201
  150. megadetector/data_management/importers/awc_to_json.py +0 -191
  151. megadetector/data_management/importers/bellevue_to_json.py +0 -272
  152. megadetector/data_management/importers/cacophony-thermal-importer.py +0 -793
  153. megadetector/data_management/importers/carrizo_shrubfree_2018.py +0 -269
  154. megadetector/data_management/importers/carrizo_trail_cam_2017.py +0 -289
  155. megadetector/data_management/importers/cct_field_adjustments.py +0 -58
  156. megadetector/data_management/importers/channel_islands_to_cct.py +0 -913
  157. megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
  158. megadetector/data_management/importers/eMammal/eMammal_helpers.py +0 -249
  159. megadetector/data_management/importers/eMammal/make_eMammal_json.py +0 -223
  160. megadetector/data_management/importers/ena24_to_json.py +0 -276
  161. megadetector/data_management/importers/filenames_to_json.py +0 -386
  162. megadetector/data_management/importers/helena_to_cct.py +0 -283
  163. megadetector/data_management/importers/idaho-camera-traps.py +0 -1407
  164. megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
  165. megadetector/data_management/importers/import_desert_lion_conservation_camera_traps.py +0 -387
  166. megadetector/data_management/importers/jb_csv_to_json.py +0 -150
  167. megadetector/data_management/importers/mcgill_to_json.py +0 -250
  168. megadetector/data_management/importers/missouri_to_json.py +0 -490
  169. megadetector/data_management/importers/nacti_fieldname_adjustments.py +0 -79
  170. megadetector/data_management/importers/noaa_seals_2019.py +0 -181
  171. megadetector/data_management/importers/osu-small-animals-to-json.py +0 -364
  172. megadetector/data_management/importers/pc_to_json.py +0 -365
  173. megadetector/data_management/importers/plot_wni_giraffes.py +0 -123
  174. megadetector/data_management/importers/prepare_zsl_imerit.py +0 -131
  175. megadetector/data_management/importers/raic_csv_to_md_results.py +0 -416
  176. megadetector/data_management/importers/rspb_to_json.py +0 -356
  177. megadetector/data_management/importers/save_the_elephants_survey_A.py +0 -320
  178. megadetector/data_management/importers/save_the_elephants_survey_B.py +0 -329
  179. megadetector/data_management/importers/snapshot_safari_importer.py +0 -758
  180. megadetector/data_management/importers/snapshot_serengeti_lila.py +0 -1067
  181. megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
  182. megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
  183. megadetector/data_management/importers/sulross_get_exif.py +0 -65
  184. megadetector/data_management/importers/timelapse_csv_set_to_json.py +0 -490
  185. megadetector/data_management/importers/ubc_to_json.py +0 -399
  186. megadetector/data_management/importers/umn_to_json.py +0 -507
  187. megadetector/data_management/importers/wellington_to_json.py +0 -263
  188. megadetector/data_management/importers/wi_to_json.py +0 -442
  189. megadetector/data_management/importers/zamba_results_to_md_results.py +0 -180
  190. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +0 -101
  191. megadetector/data_management/lila/add_locations_to_nacti.py +0 -151
  192. megadetector/utils/azure_utils.py +0 -178
  193. megadetector/utils/sas_blob_utils.py +0 -509
  194. megadetector-5.0.28.dist-info/RECORD +0 -209
  195. /megadetector/{api/batch_processing/__init__.py → __init__.py} +0 -0
  196. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/licenses/LICENSE +0 -0
  197. {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/top_level.txt +0 -0
@@ -2,15 +2,15 @@
2
2
 
3
3
  run_detector.py
4
4
 
5
- Module to run an animal detection model on images. The main function in this script also renders
5
+ Module to run an animal detection model on images. The main function in this script also renders
6
6
  the predicted bounding boxes on images and saves the resulting images (with bounding boxes).
7
7
 
8
8
  **This script is not a good way to process lots of images**. It does not produce a useful
9
9
  output format, and it does not facilitate checkpointing the results so if it crashes you
10
- would have to start from scratch. **If you want to run a detector on lots of images, you should
10
+ would have to start from scratch. **If you want to run a detector on lots of images, you should
11
11
  check out run_detector_batch.py**.
12
12
 
13
- That said, this script (run_detector.py) is a good way to test our detector on a handful of images
13
+ That said, this script (run_detector.py) is a good way to test our detector on a handful of images
14
14
  and get super-satisfying, graphical results.
15
15
 
16
16
  If you would like to *not* use the GPU on the machine, set the environment
@@ -32,6 +32,7 @@ import time
32
32
  import json
33
33
  import warnings
34
34
  import tempfile
35
+ import zipfile
35
36
 
36
37
  import humanfriendly
37
38
  from tqdm import tqdm
@@ -40,6 +41,7 @@ from megadetector.utils import path_utils as path_utils
40
41
  from megadetector.visualization import visualization_utils as vis_utils
41
42
  from megadetector.utils.url_utils import download_url
42
43
  from megadetector.utils.ct_utils import parse_kvp_list
44
+ from megadetector.utils.path_utils import compute_file_hash
43
45
 
44
46
  # ignoring all "PIL cannot read EXIF metainfo for the images" warnings
45
47
  warnings.filterwarnings('ignore', '(Possibly )?corrupt EXIF data', UserWarning)
@@ -81,29 +83,66 @@ USE_MODEL_NATIVE_CLASSES = False
81
83
  #
82
84
  # Order matters here.
83
85
  model_string_to_model_version = {
86
+
87
+ # Specific model versions that might be expressed in a variety of ways
84
88
  'mdv2':'v2.0.0',
85
89
  'mdv3':'v3.0.0',
86
90
  'mdv4':'v4.1.0',
87
- 'mdv5a':'v5a.0.0',
88
- 'mdv5b':'v5b.0.0',
91
+ 'mdv5a':'v5a.0.1',
92
+ 'mdv5b':'v5b.0.1',
93
+
89
94
  'v2':'v2.0.0',
90
95
  'v3':'v3.0.0',
91
96
  'v4':'v4.1.0',
92
97
  'v4.1':'v4.1.0',
93
- 'v5a.0.0':'v5a.0.0',
94
- 'v5b.0.0':'v5b.0.0',
98
+ 'v5a.0.0':'v5a.0.1',
99
+ 'v5b.0.0':'v5b.0.1',
100
+
101
+ 'md1000-redwood':'v1000.0.0-redwood',
102
+ 'md1000-cedar':'v1000.0.0-cedar',
103
+ 'md1000-larch':'v1000.0.0-larch',
104
+ 'md1000-sorrel':'v1000.0.0-sorrel',
105
+ 'md1000-spruce':'v1000.0.0-spruce',
106
+
107
+ 'mdv1000-redwood':'v1000.0.0-redwood',
108
+ 'mdv1000-cedar':'v1000.0.0-cedar',
109
+ 'mdv1000-larch':'v1000.0.0-larch',
110
+ 'mdv1000-sorrel':'v1000.0.0-sorrel',
111
+ 'mdv1000-spruce':'v1000.0.0-spruce',
112
+
113
+ 'v1000-redwood':'v1000.0.0-redwood',
114
+ 'v1000-cedar':'v1000.0.0-cedar',
115
+ 'v1000-larch':'v1000.0.0-larch',
116
+ 'v1000-sorrel':'v1000.0.0-sorrel',
117
+ 'v1000-spruce':'v1000.0.0-spruce',
118
+
119
+ # Arguably less specific model versions
95
120
  'redwood':'v1000.0.0-redwood',
96
121
  'spruce':'v1000.0.0-spruce',
97
122
  'cedar':'v1000.0.0-cedar',
98
123
  'larch':'v1000.0.0-larch',
99
- 'default':'v5a.0.0',
100
- 'default-model':'v5a.0.0',
101
- 'megadetector':'v5a.0.0'
124
+
125
+ # Opinionated defaults
126
+ 'mdv5':'v5a.0.1',
127
+ 'md5':'v5a.0.1',
128
+ 'mdv1000':'v1000.0.0-redwood',
129
+ 'md1000':'v1000.0.0-redwood',
130
+ 'default':'v5a.0.1',
131
+ 'megadetector':'v5a.0.1',
102
132
  }
103
133
 
104
- model_url_base = 'http://localhost:8181/'
134
+ # python -m http.server 8181
135
+ model_url_base = 'https://github.com/agentmorris/MegaDetector/releases/download/v1000.0/'
105
136
  assert model_url_base.endswith('/')
106
137
 
138
+ if os.environ.get('MD_MODEL_URL_BASE') is not None:
139
+ model_url_base = os.environ['MD_MODEL_URL_BASE']
140
+ print('Model URL base provided via environment variable: {}'.format(
141
+ model_url_base
142
+ ))
143
+ if not model_url_base.endswith('/'):
144
+ model_url_base += '/'
145
+
107
146
  # Maps canonical model version numbers to metadata
108
147
  known_models = {
109
148
  'v2.0.0':
@@ -137,7 +176,8 @@ known_models = {
137
176
  'conservative_detection_threshold':0.05,
138
177
  'image_size':1280,
139
178
  'model_type':'yolov5',
140
- 'normalized_typical_inference_speed':1.0
179
+ 'normalized_typical_inference_speed':1.0,
180
+ 'md5':'ec1d7603ec8cf642d6e0cd008ba2be8c'
141
181
  },
142
182
  'v5b.0.0':
143
183
  {
@@ -146,29 +186,58 @@ known_models = {
146
186
  'conservative_detection_threshold':0.05,
147
187
  'image_size':1280,
148
188
  'model_type':'yolov5',
149
- 'normalized_typical_inference_speed':1.0
189
+ 'normalized_typical_inference_speed':1.0,
190
+ 'md5':'bc235e73f53c5c95e66ea0d1b2cbf542'
191
+ },
192
+ 'v5a.0.1':
193
+ {
194
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5a.0.1.pt',
195
+ 'typical_detection_threshold':0.2,
196
+ 'conservative_detection_threshold':0.05,
197
+ 'image_size':1280,
198
+ 'model_type':'yolov5',
199
+ 'normalized_typical_inference_speed':1.0,
200
+ 'md5':'60f8e7ec1308554df258ed1f4040bc4f'
201
+ },
202
+ 'v5b.0.1':
203
+ {
204
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.1.pt',
205
+ 'typical_detection_threshold':0.2,
206
+ 'conservative_detection_threshold':0.05,
207
+ 'image_size':1280,
208
+ 'model_type':'yolov5',
209
+ 'normalized_typical_inference_speed':1.0,
210
+ 'md5':'f17ed6fedfac2e403606a08c89984905'
150
211
  },
151
-
152
- # Fake values for testing
153
212
  'v1000.0.0-redwood':
154
213
  {
155
- 'normalized_typical_inference_speed':2.0,
156
- 'url':model_url_base + 'md_v1000.0.0-redwood.pt'
214
+ 'url':model_url_base + 'md_v1000.0.0-redwood.pt',
215
+ 'normalized_typical_inference_speed':1.0,
216
+ 'md5':'74474b3aec9cf1a990da38b37ddf9197'
157
217
  },
158
218
  'v1000.0.0-spruce':
159
219
  {
160
- 'normalized_typical_inference_speed':3.0,
161
- 'url':model_url_base + 'md_v1000.0.0-spruce.pt'
220
+ 'url':model_url_base + 'md_v1000.0.0-spruce.pt',
221
+ 'normalized_typical_inference_speed':12.7,
222
+ 'md5':'1c9d1d2b3ba54931881471fdd508e6f2'
162
223
  },
163
224
  'v1000.0.0-larch':
164
225
  {
165
- 'normalized_typical_inference_speed':4.0,
166
- 'url':model_url_base + 'md_v1000.0.0-larch.pt'
226
+ 'url':model_url_base + 'md_v1000.0.0-larch.pt',
227
+ 'normalized_typical_inference_speed':2.4,
228
+ 'md5':'cab94ebd190c2278e12fb70ffd548b6d'
167
229
  },
168
230
  'v1000.0.0-cedar':
169
231
  {
170
- 'normalized_typical_inference_speed':5.0,
171
- 'url':model_url_base + 'md_v1000.0.0-cedar.pt'
232
+ 'url':model_url_base + 'md_v1000.0.0-cedar.pt',
233
+ 'normalized_typical_inference_speed':2.0,
234
+ 'md5':'3d6472c9b95ba687b59ebe255f7c576b'
235
+ },
236
+ 'v1000.0.0-sorrel':
237
+ {
238
+ 'url':model_url_base + 'md_v1000.0.0-sorrel.pt',
239
+ 'normalized_typical_inference_speed':7.0,
240
+ 'md5':'4339a2c8af7a381f18ded7ac2a4df03e'
172
241
  }
173
242
  }
174
243
 
@@ -180,7 +249,7 @@ DEFAULT_BOX_EXPANSION = 0
180
249
  DEFAULT_LABEL_FONT_SIZE = 16
181
250
  DETECTION_FILENAME_INSERT = '_detections'
182
251
 
183
- # Approximate inference speeds (in images per second) for MDv5 based on
252
+ # Approximate inference speeds (in images per second) for MDv5 based on
184
253
  # benchmarks, only used for reporting very coarse expectations about inference time.
185
254
  device_token_to_mdv5_inference_speed = {
186
255
  '4090':17.6,
@@ -192,9 +261,9 @@ device_token_to_mdv5_inference_speed = {
192
261
  # is around 3.5x faster than MDv4.
193
262
  'V100':2.79*3.5,
194
263
  '2080':2.3*3.5,
195
- '2060':1.6*3.5
264
+ '2060':1.6*3.5
196
265
  }
197
-
266
+
198
267
 
199
268
  #%% Utility functions
200
269
 
@@ -202,15 +271,15 @@ def get_detector_metadata_from_version_string(detector_version):
202
271
  """
203
272
  Given a MegaDetector version string (e.g. "v4.1.0"), returns the metadata for
204
273
  the model. Used for writing standard defaults to batch output files.
205
-
274
+
206
275
  Args:
207
276
  detector_version (str): a detection version string, e.g. "v4.1.0", which you
208
277
  can extract from a filename using get_detector_version_from_filename()
209
-
278
+
210
279
  Returns:
211
280
  dict: metadata for this model, suitable for writing to a MD output file
212
281
  """
213
-
282
+
214
283
  if detector_version not in known_models:
215
284
  print('Warning: no metadata for unknown detector version {}'.format(detector_version))
216
285
  default_detector_metadata = {
@@ -229,31 +298,32 @@ def get_detector_version_from_filename(detector_filename,
229
298
  accept_first_match=True,
230
299
  verbose=False):
231
300
  r"""
232
- Gets the canonical version number string of a detector from the model filename.
233
-
301
+ Gets the canonical version number string of a detector from the model filename.
302
+
234
303
  [detector_filename] will almost always end with one of the following:
235
-
304
+
236
305
  * megadetector_v2.pb
237
306
  * megadetector_v3.pb
238
- * megadetector_v4.1 (not produed by run_detector_batch.py, only found in output files from the deprecated Azure Batch API)
307
+ * megadetector_v4.1 (not produced by run_detector_batch.py, only found in output files from
308
+ the deprecated Azure Batch API)
239
309
  * md_v4.1.0.pb
240
310
  * md_v5a.0.0.pt
241
311
  * md_v5b.0.0.pt
242
-
243
- This function identifies the version number as "v2.0.0", "v3.0.0", "v4.1.0",
244
- "v4.1.0", "v5a.0.0", and "v5b.0.0", respectively. See known_models for the list
312
+
313
+ This function identifies the version number as "v2.0.0", "v3.0.0", "v4.1.0",
314
+ "v4.1.0", "v5a.0.0", and "v5b.0.0", respectively. See known_models for the list
245
315
  of valid version numbers.
246
-
316
+
247
317
  Args:
248
318
  detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
249
- accept_first_match (bool, optional): if multiple candidates match the filename, choose the
319
+ accept_first_match (bool, optional): if multiple candidates match the filename, choose the
250
320
  first one, otherwise returns the string "multiple"
251
321
  verbose (bool, optional): enable additional debug output
252
-
322
+
253
323
  Returns:
254
324
  str: a detector version string, e.g. "v5a.0.0", or "multiple" if I'm confused
255
325
  """
256
-
326
+
257
327
  fn = os.path.basename(detector_filename).lower()
258
328
  matches = []
259
329
  for s in model_string_to_model_version.keys():
@@ -268,117 +338,119 @@ def get_detector_version_from_filename(detector_filename,
268
338
  if verbose:
269
339
  print('Warning: multiple MegaDetector versions for model file {}:'.format(detector_filename))
270
340
  for s in matches:
271
- print(s)
341
+ print(s)
272
342
  return 'multiple'
273
343
  else:
274
344
  return model_string_to_model_version[matches[0]]
275
-
345
+
276
346
 
277
347
  def get_detector_version_from_model_file(detector_filename,verbose=False):
278
348
  """
279
- Gets the canonical detection version from a model file, preferably by reading it
349
+ Gets the canonical detection version from a model file, preferably by reading it
280
350
  from the file itself, otherwise based on the filename.
281
-
351
+
282
352
  Args:
283
- detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
353
+ detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
284
354
  verbose (bool, optional): enable additional debug output
285
-
355
+
286
356
  Returns:
287
357
  str: a canonical detector version string, e.g. "v5a.0.0", or "unknown"
288
358
  """
289
-
359
+
290
360
  # Try to extract a version string from the filename
291
361
  version_string_based_on_filename = get_detector_version_from_filename(
292
362
  detector_filename, verbose=verbose)
293
363
  if version_string_based_on_filename == 'unknown':
294
364
  version_string_based_on_filename = None
295
-
296
- # Try to extract a version string from the file itself; currently this is only
365
+
366
+ # Try to extract a version string from the file itself; currently this is only
297
367
  # a thing for PyTorch models
298
-
368
+
299
369
  version_string_based_on_model_file = None
300
-
370
+
301
371
  if detector_filename.endswith('.pt') or detector_filename.endswith('.zip'):
302
-
372
+
303
373
  from megadetector.detection.pytorch_detector import \
304
374
  read_metadata_from_megadetector_model_file
305
375
  metadata = read_metadata_from_megadetector_model_file(detector_filename,verbose=verbose)
306
-
376
+
307
377
  if metadata is not None and isinstance(metadata,dict):
308
-
378
+
309
379
  if 'metadata_format_version' not in metadata or \
310
380
  not isinstance(metadata['metadata_format_version'],float):
311
-
381
+
312
382
  print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
313
383
  'but it doesn\'t have a valid format version number')
314
-
384
+
315
385
  elif 'model_version_string' not in metadata or \
316
386
  not isinstance(metadata['model_version_string'],str):
317
-
387
+
318
388
  print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
319
389
  'but it doesn\'t have a format model version string')
320
-
390
+
321
391
  else:
322
-
392
+
323
393
  version_string_based_on_model_file = metadata['model_version_string']
324
-
394
+
325
395
  if version_string_based_on_model_file not in known_models:
326
- print('Warning: unknown model version {} specified in file {}'.format(
327
- version_string_based_on_model_file,detector_filename))
328
-
396
+ print('Warning: unknown model version:\n\n{}\n\n...specified in file:\n\n{}'.format(
397
+ version_string_based_on_model_file,os.path.basename(detector_filename)))
398
+
329
399
  # ...if there's metadata in this file
330
-
400
+
331
401
  # ...if this looks like a PyTorch file
332
-
402
+
333
403
  # If we got versions strings from the filename *and* the model file...
334
404
  if (version_string_based_on_filename is not None) and \
335
405
  (version_string_based_on_model_file is not None):
336
406
 
337
407
  if version_string_based_on_filename != version_string_based_on_model_file:
338
- print('Warning: model version string in file {} is {}, but the filename implies {}'.format(
339
- detector_filename,
408
+ print(
409
+ 'Warning: model version string in file:' + \
410
+ '\n\n{}\n\n...is:\n\n{}\n\n...but the filename implies:\n\n{}'.format(
411
+ os.path.basename(detector_filename),
340
412
  version_string_based_on_model_file,
341
413
  version_string_based_on_filename))
342
-
414
+
343
415
  return version_string_based_on_model_file
344
-
416
+
345
417
  # If we got version string from neither the filename nor the model file...
346
418
  if (version_string_based_on_filename is None) and \
347
419
  (version_string_based_on_model_file is None):
348
-
420
+
349
421
  print('Warning: could not determine model version string for model file {}'.format(
350
422
  detector_filename))
351
423
  return None
352
-
424
+
353
425
  elif version_string_based_on_filename is not None:
354
-
426
+
355
427
  return version_string_based_on_filename
356
-
428
+
357
429
  else:
358
-
430
+
359
431
  assert version_string_based_on_model_file is not None
360
432
  return version_string_based_on_model_file
361
-
433
+
362
434
  # ...def get_detector_version_from_model_file(...)
363
435
 
364
-
436
+
365
437
  def estimate_md_images_per_second(model_file, device_name=None):
366
438
  r"""
367
- Estimates how fast MegaDetector will run on a particular device, based on benchmarks.
368
- Defaults to querying the current device. Returns None if no data is available for the current
369
- card/model. Estimates only available for a small handful of GPUs. Uses an absurdly simple
439
+ Estimates how fast MegaDetector will run on a particular device, based on benchmarks.
440
+ Defaults to querying the current device. Returns None if no data is available for the current
441
+ card/model. Estimates only available for a small handful of GPUs. Uses an absurdly simple
370
442
  lookup approach, e.g. if the string "4090" appears in the device name, congratulations,
371
443
  you have an RTX 4090.
372
-
444
+
373
445
  Args:
374
446
  model_file (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
375
447
  device_name (str, optional): device name, e.g. blah-blah-4090-blah-blah
376
-
448
+
377
449
  Returns:
378
450
  float: the approximate number of images this model version can process on this
379
451
  device per second
380
452
  """
381
-
453
+
382
454
  if device_name is None:
383
455
  try:
384
456
  import torch
@@ -386,51 +458,51 @@ def estimate_md_images_per_second(model_file, device_name=None):
386
458
  except Exception as e:
387
459
  print('Error querying device name: {}'.format(e))
388
460
  return None
389
-
461
+
390
462
  # About how fast is this model compared to MDv5?
391
463
  model_version = get_detector_version_from_model_file(model_file)
392
-
464
+
393
465
  if model_version not in known_models.keys():
394
466
  print('Could not estimate inference speed: error determining model version for model file {}'.format(
395
467
  model_file))
396
468
  return None
397
-
469
+
398
470
  model_info = known_models[model_version]
399
-
471
+
400
472
  if 'normalized_typical_inference_speed' not in model_info or \
401
473
  model_info['normalized_typical_inference_speed'] is None:
402
474
  print('No speed ratio available for model type {}'.format(model_version))
403
475
  return None
404
-
476
+
405
477
  normalized_inference_speed = model_info['normalized_typical_inference_speed']
406
-
478
+
407
479
  # About how fast would MDv5 run on this device?
408
480
  mdv5_inference_speed = None
409
481
  for device_token in device_token_to_mdv5_inference_speed.keys():
410
482
  if device_token in device_name:
411
483
  mdv5_inference_speed = device_token_to_mdv5_inference_speed[device_token]
412
484
  break
413
-
485
+
414
486
  if mdv5_inference_speed is None:
415
487
  print('No baseline speed estimate available for device {}'.format(device_name))
416
488
  return None
417
-
489
+
418
490
  return normalized_inference_speed * mdv5_inference_speed
419
-
420
-
491
+
492
+
421
493
  def get_typical_confidence_threshold_from_results(results):
422
494
  """
423
495
  Given the .json data loaded from a MD results file, returns a typical confidence
424
496
  threshold based on the detector version.
425
-
497
+
426
498
  Args:
427
- results (dict or str): a dict of MD results, as it would be loaded from a MD results .json
499
+ results (dict or str): a dict of MD results, as it would be loaded from a MD results .json
428
500
  file, or a .json filename
429
-
501
+
430
502
  Returns:
431
503
  float: a sensible default threshold for this model
432
504
  """
433
-
505
+
434
506
  # Load results if necessary
435
507
  if isinstance(results,str):
436
508
  with open(results,'r') as f:
@@ -450,31 +522,31 @@ def get_typical_confidence_threshold_from_results(results):
450
522
  detector_metadata = get_detector_metadata_from_version_string(detector_version)
451
523
  default_threshold = detector_metadata['typical_detection_threshold']
452
524
 
453
- return default_threshold
525
+ return default_threshold
526
+
454
527
 
455
-
456
528
  def is_gpu_available(model_file):
457
529
  r"""
458
530
  Determines whether a GPU is available, importing PyTorch or TF depending on the extension
459
- of model_file. Does not actually load model_file, just uses that to determine how to check
531
+ of model_file. Does not actually load model_file, just uses that to determine how to check
460
532
  for GPU availability (PT vs. TF).
461
-
533
+
462
534
  Args:
463
535
  model_file (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
464
-
536
+
465
537
  Returns:
466
538
  bool: whether a GPU is available
467
539
  """
468
-
540
+
469
541
  if model_file.endswith('.pb'):
470
542
  import tensorflow.compat.v1 as tf
471
543
  gpu_available = tf.test.is_gpu_available()
472
544
  print('TensorFlow version:', tf.__version__)
473
- print('tf.test.is_gpu_available:', gpu_available)
545
+ print('tf.test.is_gpu_available:', gpu_available)
474
546
  return gpu_available
475
547
  if not model_file.endswith('.pt'):
476
548
  print('Warning: could not determine environment from model file name, assuming PyTorch')
477
-
549
+
478
550
  import torch
479
551
  gpu_available = torch.cuda.is_available()
480
552
  print('PyTorch reports {} available CUDA devices'.format(torch.cuda.device_count()))
@@ -487,16 +559,16 @@ def is_gpu_available(model_file):
487
559
  except AttributeError:
488
560
  pass
489
561
  return gpu_available
490
-
491
562
 
492
- def load_detector(model_file,
493
- force_cpu=False,
494
- force_model_download=False,
563
+
564
+ def load_detector(model_file,
565
+ force_cpu=False,
566
+ force_model_download=False,
495
567
  detector_options=None,
496
568
  verbose=False):
497
569
  r"""
498
570
  Loads a TF or PT detector, depending on the extension of model_file.
499
-
571
+
500
572
  Args:
501
573
  model_file (str): model filename (e.g. c:/x/z/md_v5a.0.0.pt) or known model
502
574
  name (e.g. "MDV5A")
@@ -505,21 +577,21 @@ def load_detector(model_file,
505
577
  force_model_download (bool, optional): force downloading the model file if
506
578
  a named model (e.g. "MDV5A") is supplied, even if the local file already
507
579
  exists
508
- detector_options (dict, optional): key/value pairs that are interpreted differently
580
+ detector_options (dict, optional): key/value pairs that are interpreted differently
509
581
  by different detectors
510
582
  verbose (bool, optional): enable additional debug output
511
-
583
+
512
584
  Returns:
513
585
  object: loaded detector object
514
586
  """
515
-
587
+
516
588
  # Possibly automatically download the model
517
- model_file = try_download_known_detector(model_file,
589
+ model_file = try_download_known_detector(model_file,
518
590
  force_download=force_model_download)
519
-
591
+
520
592
  if verbose:
521
593
  print('GPU available: {}'.format(is_gpu_available(model_file)))
522
-
594
+
523
595
  start_time = time.time()
524
596
 
525
597
  if model_file.endswith('.pb'):
@@ -531,9 +603,9 @@ def load_detector(model_file,
531
603
  detector = TFDetector(model_file, detector_options)
532
604
 
533
605
  elif model_file.endswith('.pt'):
534
-
606
+
535
607
  from megadetector.detection.pytorch_detector import PTDetector
536
-
608
+
537
609
  # Prepare options specific to the PTDetector class
538
610
  if detector_options is None:
539
611
  detector_options = {}
@@ -545,16 +617,16 @@ def load_detector(model_file,
545
617
  detector_options['force_cpu'] = force_cpu
546
618
  detector_options['use_model_native_classes'] = USE_MODEL_NATIVE_CLASSES
547
619
  detector = PTDetector(model_file, detector_options, verbose=verbose)
548
-
620
+
549
621
  else:
550
-
622
+
551
623
  raise ValueError('Unrecognized model format: {}'.format(model_file))
552
-
624
+
553
625
  elapsed = time.time() - start_time
554
-
626
+
555
627
  if verbose:
556
628
  print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
557
-
629
+
558
630
  return detector
559
631
 
560
632
  # ...def load_detector(...)
@@ -562,21 +634,22 @@ def load_detector(model_file,
562
634
 
563
635
  #%% Main function
564
636
 
565
- def load_and_run_detector(model_file,
637
+ def load_and_run_detector(model_file,
566
638
  image_file_names,
567
639
  output_dir,
568
640
  render_confidence_threshold=DEFAULT_RENDERING_CONFIDENCE_THRESHOLD,
569
- crop_images=False,
570
- box_thickness=DEFAULT_BOX_THICKNESS,
641
+ crop_images=False,
642
+ box_thickness=DEFAULT_BOX_THICKNESS,
571
643
  box_expansion=DEFAULT_BOX_EXPANSION,
572
644
  image_size=None,
573
645
  label_font_size=DEFAULT_LABEL_FONT_SIZE,
574
646
  augment=False,
575
647
  force_model_download=False,
576
- detector_options=None):
648
+ detector_options=None,
649
+ verbose=False):
577
650
  r"""
578
651
  Loads and runs a detector on target images, and visualizes the results.
579
-
652
+
580
653
  Args:
581
654
  model_file (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt, or a known model
582
655
  string, e.g. "MDV5A"
@@ -592,23 +665,28 @@ def load_and_run_detector(model_file,
592
665
  if (a) you're using a model other than MegaDetector or (b) you know what you're
593
666
  doing
594
667
  label_font_size (float, optional): font size to use for displaying class names
595
- and confidence values in the rendered images
668
+ and confidence values in the rendered images
596
669
  augment (bool, optional): enable (implementation-specific) image augmentation
597
670
  force_model_download (bool, optional): force downloading the model file if
598
671
  a named model (e.g. "MDV5A") is supplied, even if the local file already
599
672
  exists
600
- detector_options (dict, optional): key/value pairs that are interpreted differently
673
+ detector_options (dict, optional): key/value pairs that are interpreted differently
601
674
  by different detectors
675
+ verbose (bool, optional): enable additional debug output
602
676
  """
603
-
677
+
604
678
  if len(image_file_names) == 0:
605
679
  print('Warning: no files available')
606
680
  return
607
681
 
608
682
  # Possibly automatically download the model
609
- model_file = try_download_known_detector(model_file, force_download=force_model_download)
683
+ model_file = try_download_known_detector(model_file,
684
+ force_download=force_model_download,
685
+ verbose=verbose)
610
686
 
611
- detector = load_detector(model_file, detector_options=detector_options)
687
+ detector = load_detector(model_file,
688
+ detector_options=detector_options,
689
+ verbose=verbose)
612
690
 
613
691
  detection_results = []
614
692
  time_load = []
@@ -649,7 +727,7 @@ def load_and_run_detector(model_file,
649
727
 
650
728
  Returns: output file path
651
729
  """
652
-
730
+
653
731
  fn = os.path.basename(fn).lower()
654
732
  name, ext = os.path.splitext(fn)
655
733
  if crop_index >= 0:
@@ -665,7 +743,7 @@ def load_and_run_detector(model_file,
665
743
  return fn
666
744
 
667
745
  # ...def input_file_to_detection_file()
668
-
746
+
669
747
  for im_file in tqdm(image_file_names):
670
748
 
671
749
  try:
@@ -689,7 +767,7 @@ def load_and_run_detector(model_file,
689
767
  start_time = time.time()
690
768
 
691
769
  result = detector.generate_detections_one_image(
692
- image,
770
+ image,
693
771
  im_file,
694
772
  detection_threshold=DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
695
773
  image_size=image_size,
@@ -700,7 +778,8 @@ def load_and_run_detector(model_file,
700
778
  time_infer.append(elapsed)
701
779
 
702
780
  except Exception as e:
703
- print('An error occurred while running the detector on image {}: {}'.format(im_file, str(e)))
781
+ print('An error occurred while running the detector on image {}: {}'.format(
782
+ im_file, str(e)))
704
783
  continue
705
784
 
706
785
  try:
@@ -749,24 +828,73 @@ def load_and_run_detector(model_file,
749
828
  # ...def load_and_run_detector()
750
829
 
751
830
 
831
+ def _validate_zip_file(file_path, file_description='file'):
832
+ """
833
+ Validates that a .pt file is a valid zip file.
834
+
835
+ Args:
836
+ file_path (str): path to the file to validate
837
+ file_description (str): descriptive string for error messages
838
+
839
+ Returns:
840
+ bool: True if valid, False otherwise
841
+ """
842
+ try:
843
+ with zipfile.ZipFile(file_path, 'r') as zipf:
844
+ zipf.testzip()
845
+ return True
846
+ except (zipfile.BadZipFile, zipfile.LargeZipFile) as e:
847
+ print('{} {} appears to be corrupted (bad zip): {}'.format(
848
+ file_description.capitalize(), file_path, str(e)))
849
+ return False
850
+ except Exception as e:
851
+ print('Error validating {}: {}'.format(file_description, str(e)))
852
+ return False
853
+
854
+
855
def _validate_md5_hash(file_path, expected_hash, file_description='file'):
    """
    Validates that a file has the expected MD5 hash.

    The comparison is case-insensitive and ignores leading/trailing whitespace
    in both the expected and the computed hash strings.

    Args:
        file_path (str): path to the file to validate
        expected_hash (str): expected MD5 hash string
        file_description (str, optional): descriptive string for error messages,
            e.g. "downloaded model file"

    Returns:
        bool: True if the hash matches; False if it does not match or could not
        be computed
    """

    try:
        actual_hash = compute_file_hash(file_path, algorithm='md5').strip().lower()
        # Callers only check that expected_hash.strip() is non-empty, so stray
        # whitespace must not cause a spurious mismatch here
        expected_hash = expected_hash.strip().lower()
    except Exception as e:
        print('Error computing hash for {}: {}'.format(file_description, str(e)))
        return False

    if actual_hash != expected_hash:
        print('{} {} has incorrect hash. Expected: {}, Actual: {}'.format(
            file_description.capitalize(), file_path, expected_hash, actual_hash))
        return False

    return True
878
+
879
+
752
880
  def _download_model(model_name,force_download=False):
753
881
  """
754
882
  Downloads one of the known models to local temp space if it hasn't already been downloaded.
755
-
883
+
756
884
  Args:
757
885
  model_name (str): a known model string, e.g. "MDV5A". Returns None if this string is not
758
886
  a known model name.
759
- force_download (bool, optional): whether to download the model even if the local target
887
+ force_download (bool, optional): whether to download the model even if the local target
760
888
  file already exists
761
889
  """
762
-
763
- model_tempdir = os.path.join(tempfile.gettempdir(), 'megadetector_models')
890
+
891
+ model_tempdir = os.path.join(tempfile.gettempdir(), 'megadetector_models')
764
892
  os.makedirs(model_tempdir,exist_ok=True)
765
-
893
+
766
894
  # This is a lazy fix to an issue... if multiple users run this script, the
767
895
  # "megadetector_models" folder is owned by the first person who creates it, and others
768
896
  # can't write to it. I could create uniquely-named folders, but I philosophically prefer
769
- # to put all the individual UUID-named folders within a larger folder, so as to be a
897
+ # to put all the individual UUID-named folders within a larger folder, so as to be a
770
898
  # good tempdir citizen. So, the lazy fix is to make this world-writable.
771
899
  try:
772
900
  os.chmod(model_tempdir,0o777)
@@ -775,46 +903,125 @@ def _download_model(model_name,force_download=False):
775
903
  if model_name.lower() not in known_models:
776
904
  print('Unrecognized downloadable model {}'.format(model_name))
777
905
  return None
778
- url = known_models[model_name.lower()]['url']
906
+
907
+ model_info = known_models[model_name.lower()]
908
+ url = model_info['url']
779
909
  destination_filename = os.path.join(model_tempdir,url.split('/')[-1])
780
- local_file = download_url(url, destination_filename=destination_filename, progress_updater=None,
781
- force_download=force_download, verbose=True)
910
+
911
+ # Check whether the file already exists, in which case we want to validate it
912
+ if os.path.exists(destination_filename) and not force_download:
913
+
914
+ # Only validate .pt files, not .pb files
915
+ if destination_filename.endswith('.pt'):
916
+
917
+ is_valid = True
918
+
919
+ # Check whether the file is a valid zip file (.pt files are zip files in disguise)
920
+ if not _validate_zip_file(destination_filename,
921
+ 'existing model file'):
922
+ is_valid = False
923
+
924
+ # Check MD5 hash if available
925
+ if is_valid and \
926
+ ('md5' in model_info) and \
927
+ (model_info['md5'] is not None) and \
928
+ (len(model_info['md5'].strip()) > 0):
929
+
930
+ if not _validate_md5_hash(destination_filename, model_info['md5'],
931
+ 'existing model file'):
932
+ is_valid = False
933
+
934
+ # If validation failed, delete the corrupted file and re-download
935
+ if not is_valid:
936
+ print('Deleting corrupted model file and re-downloading: {}'.format(
937
+ destination_filename))
938
+ try:
939
+ os.remove(destination_filename)
940
+ # This should be a no-op at this point, but it can't hurt
941
+ force_download = True
942
+ except Exception as e:
943
+ print('Warning: failed to delete corrupted file {}: {}'.format(
944
+ destination_filename, str(e)))
945
+ # Continue with download attempt anyway, setting force_download to True
946
+ force_download = True
947
+ else:
948
+ print('Model {} already exists and is valid at {}'.format(
949
+ model_name, destination_filename))
950
+ return destination_filename
951
+
952
+ # Download the model
953
+ try:
954
+ local_file = download_url(url,
955
+ destination_filename=destination_filename,
956
+ progress_updater=None,
957
+ force_download=force_download,
958
+ verbose=True)
959
+ except Exception as e:
960
+ print('Error downloading model {} from {}: {}'.format(model_name, url, str(e)))
961
+ raise
962
+
963
+ # Validate the downloaded file if it's a .pt file
964
+ if local_file and local_file.endswith('.pt'):
965
+
966
+ # Check if the downloaded file is a valid zip file
967
+ if not _validate_zip_file(local_file, "downloaded model file"):
968
+ # Clean up the corrupted download
969
+ try:
970
+ os.remove(local_file)
971
+ except Exception:
972
+ pass
973
+ return None
974
+
975
+ # Check MD5 hash if available
976
+ if ('md5' in model_info) and \
977
+ (model_info['md5'] is not None) and \
978
+ (len(model_info['md5'].strip()) > 0):
979
+
980
+ if not _validate_md5_hash(local_file, model_info['md5'], "downloaded model file"):
981
+ # Clean up the corrupted download
982
+ try:
983
+ os.remove(local_file)
984
+ except Exception:
985
+ pass
986
+ return None
987
+
782
988
  print('Model {} available at {}'.format(model_name,local_file))
783
989
  return local_file
784
990
 
991
+ # ...def _download_model(...)
785
992
 
786
993
  def try_download_known_detector(detector_file,force_download=False,verbose=False):
787
994
  """
788
995
  Checks whether detector_file is really the name of a known model, in which case we will
789
996
  either read the actual filename from the corresponding environment variable or download
790
997
  (if necessary) to local temp space. Otherwise just returns the input string.
791
-
998
+
792
999
  Args:
793
1000
  detector_file (str): a known model string (e.g. "MDV5A"), or any other string (in which
794
1001
  case this function is a no-op)
795
- force_download (bool, optional): whether to download the model even if the local target
1002
+ force_download (bool, optional): whether to download the model even if the local target
796
1003
  file already exists
797
1004
  verbose (bool, optional): enable additional debug output
798
-
1005
+
799
1006
  Returns:
800
1007
  str: the local filename to which the model was downloaded, or the same string that
801
1008
  was passed in, if it's not recognized as a well-known model name
802
1009
  """
803
-
1010
+
804
1011
  model_string = detector_file.lower()
805
-
806
- # If this is a short model string (e.g. "MDV5A"), convert to a canonical version
1012
+
1013
+ # If this is a short model string (e.g. "MDV5A"), convert to a canonical version
807
1014
  # string (e.g. "v5a.0.0")
808
1015
  if model_string in model_string_to_model_version:
809
-
1016
+
810
1017
  if verbose:
811
1018
  print('Converting short string {} to canonical version string {}'.format(
812
1019
  model_string,
813
1020
  model_string_to_model_version[model_string]))
814
1021
  model_string = model_string_to_model_version[model_string]
815
-
1022
+
816
1023
  if model_string in known_models:
817
-
1024
+
818
1025
  if detector_file in os.environ:
819
1026
  fn = os.environ[detector_file]
820
1027
  print('Reading MD location from environment variable {}: {}'.format(
@@ -822,25 +1029,25 @@ def try_download_known_detector(detector_file,force_download=False,verbose=False
822
1029
  detector_file = fn
823
1030
  else:
824
1031
  detector_file = _download_model(model_string,force_download=force_download)
825
-
1032
+
826
1033
  return detector_file
827
-
828
-
829
-
1034
+
1035
+
1036
+
830
1037
 
831
1038
  #%% Command-line driver
832
1039
 
833
- def main():
1040
+ def main(): # noqa
834
1041
 
835
1042
  parser = argparse.ArgumentParser(
836
1043
  description='Module to run an animal detection model on images')
837
-
1044
+
838
1045
  parser.add_argument(
839
1046
  'detector_file',
840
1047
  help='Path detector model file (.pb or .pt). Can also be MDV4, MDV5A, or MDV5B to request automatic download.')
841
-
1048
+
842
1049
  # Must specify either an image file or a directory
843
- group = parser.add_mutually_exclusive_group(required=True)
1050
+ group = parser.add_mutually_exclusive_group(required=True)
844
1051
  group.add_argument(
845
1052
  '--image_file',
846
1053
  type=str,
@@ -851,98 +1058,103 @@ def main():
851
1058
  type=str,
852
1059
  default=None,
853
1060
  help='Directory to search for images, with optional recursion by adding --recursive')
854
-
1061
+
855
1062
  parser.add_argument(
856
1063
  '--recursive',
857
1064
  action='store_true',
858
1065
  help='Recurse into directories, only meaningful if using --image_dir')
859
-
1066
+
860
1067
  parser.add_argument(
861
1068
  '--output_dir',
862
1069
  type=str,
863
1070
  default=None,
864
1071
  help='Directory for output images (defaults to same as input)')
865
-
1072
+
866
1073
  parser.add_argument(
867
1074
  '--image_size',
868
1075
  type=int,
869
1076
  default=None,
870
1077
  help=('Force image resizing to a (square) integer size (not recommended to change this)'))
871
-
1078
+
872
1079
  parser.add_argument(
873
1080
  '--threshold',
874
1081
  type=float,
875
1082
  default=DEFAULT_RENDERING_CONFIDENCE_THRESHOLD,
876
- help=('Confidence threshold between 0 and 1.0; only render' +
1083
+ help=('Confidence threshold between 0 and 1.0; only render' +
877
1084
  ' boxes above this confidence (defaults to {})'.format(
878
1085
  DEFAULT_RENDERING_CONFIDENCE_THRESHOLD)))
879
-
1086
+
880
1087
  parser.add_argument(
881
1088
  '--crop',
882
1089
  default=False,
883
1090
  action='store_true',
884
1091
  help=('If set, produces separate output images for each crop, '
885
1092
  'rather than adding bounding boxes to the original image'))
886
-
1093
+
887
1094
  parser.add_argument(
888
1095
  '--augment',
889
1096
  default=False,
890
1097
  action='store_true',
891
1098
  help=('Enable image augmentation'))
892
-
1099
+
893
1100
  parser.add_argument(
894
1101
  '--box_thickness',
895
1102
  type=int,
896
1103
  default=DEFAULT_BOX_THICKNESS,
897
1104
  help=('Line width (in pixels) for box rendering (defaults to {})'.format(
898
1105
  DEFAULT_BOX_THICKNESS)))
899
-
1106
+
900
1107
  parser.add_argument(
901
1108
  '--box_expansion',
902
1109
  type=int,
903
1110
  default=DEFAULT_BOX_EXPANSION,
904
1111
  help=('Number of pixels to expand boxes by (defaults to {})'.format(
905
1112
  DEFAULT_BOX_EXPANSION)))
906
-
1113
+
907
1114
  parser.add_argument(
908
1115
  '--label_font_size',
909
1116
  type=int,
910
1117
  default=DEFAULT_LABEL_FONT_SIZE,
911
1118
  help=('Label font size (defaults to {})'.format(
912
1119
  DEFAULT_LABEL_FONT_SIZE)))
913
-
1120
+
914
1121
  parser.add_argument(
915
1122
  '--process_likely_output_images',
916
1123
  action='store_true',
917
1124
  help=('By default, we skip images that end in {}, because they probably came from this script. '\
918
1125
  .format(DETECTION_FILENAME_INSERT) + \
919
1126
  'This option disables that behavior.'))
920
-
1127
+
921
1128
  parser.add_argument(
922
1129
  '--force_model_download',
923
1130
  action='store_true',
924
1131
  help=('If a named model (e.g. "MDV5A") is supplied, force a download of that model even if the ' +\
925
1132
  'local file already exists.'))
926
1133
 
1134
+ parser.add_argument(
1135
+ '--verbose',
1136
+ action='store_true',
1137
+ help=('Enable additional debug output'))
1138
+
927
1139
  parser.add_argument(
928
1140
  '--detector_options',
929
1141
  nargs='*',
930
1142
  metavar='KEY=VALUE',
931
1143
  default='',
932
1144
  help='Detector-specific options, as a space-separated list of key-value pairs')
933
-
1145
+
934
1146
  if len(sys.argv[1:]) == 0:
935
1147
  parser.print_help()
936
1148
  parser.exit()
937
1149
 
938
1150
  args = parser.parse_args()
939
1151
  detector_options = parse_kvp_list(args.detector_options)
940
-
941
- # If the specified detector file is really the name of a known model, find
1152
+
1153
+ # If the specified detector file is really the name of a known model, find
942
1154
  # (and possibly download) that model
943
1155
  args.detector_file = try_download_known_detector(args.detector_file,
944
1156
  force_download=args.force_model_download)
945
-
1157
+
946
1158
  assert os.path.exists(args.detector_file), 'detector file {} does not exist'.format(
947
1159
  args.detector_file)
948
1160
  assert 0.0 < args.threshold <= 1.0, 'Confidence threshold needs to be between 0 and 1'
@@ -961,7 +1173,7 @@ def main():
961
1173
  else:
962
1174
  image_file_names_valid.append(fn)
963
1175
  image_file_names = image_file_names_valid
964
-
1176
+
965
1177
  print('Running detector on {} images...'.format(len(image_file_names)))
966
1178
 
967
1179
  if args.output_dir:
@@ -972,20 +1184,21 @@ def main():
972
1184
  else:
973
1185
  # but for a single image, args.image_dir is also None
974
1186
  args.output_dir = os.path.dirname(args.image_file)
975
-
1187
+
976
1188
  load_and_run_detector(model_file=args.detector_file,
977
1189
  image_file_names=image_file_names,
978
1190
  output_dir=args.output_dir,
979
1191
  render_confidence_threshold=args.threshold,
980
1192
  box_thickness=args.box_thickness,
981
- box_expansion=args.box_expansion,
1193
+ box_expansion=args.box_expansion,
982
1194
  crop_images=args.crop,
983
1195
  image_size=args.image_size,
984
1196
  label_font_size=args.label_font_size,
985
1197
  augment=args.augment,
986
1198
  # If --force_model_download was specified, we already handled it
987
1199
  force_model_download=False,
988
- detector_options=detector_options)
1200
+ detector_options=detector_options,
1201
+ verbose=args.verbose)
989
1202
 
990
1203
  # Invoke the command-line driver only when this module is executed as a script
  if __name__ == '__main__':
991
1204
  main()
@@ -998,19 +1211,19 @@ if False:
998
1211
  pass
999
1212
 
1000
1213
  #%% Test model download
1001
-
1214
+
1002
1215
  r"""
1003
1216
  cd i:\models\all_models_in_the_wild
1004
1217
  i:
1005
1218
  python -m http.server 8181
1006
1219
  """
1007
-
1220
+
1008
1221
  model_name = 'redwood'
1009
1222
  try_download_known_detector(model_name,force_download=True,verbose=True)
1010
-
1223
+
1011
1224
 
1012
1225
  #%% Load and run detector
1013
-
1226
+
1014
1227
  model_file = r'c:\temp\models\md_v4.1.0.pb'
1015
1228
  image_file_names = path_utils.find_images(r'c:\temp\demo_images\ssverymini')
1016
1229
  output_dir = r'c:\temp\demo_images\ssverymini'