megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1267 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
run_detector.py
|
|
4
|
+
|
|
5
|
+
Module to run an animal detection model on images. The main function in this script also renders
|
|
6
|
+
the predicted bounding boxes on images and saves the resulting images (with bounding boxes).
|
|
7
|
+
|
|
8
|
+
**This script is not a good way to process lots of images**. It does not produce a useful
|
|
9
|
+
output format, and it does not facilitate checkpointing the results so if it crashes you
|
|
10
|
+
would have to start from scratch. **If you want to run a detector on lots of images, you should
|
|
11
|
+
check out run_detector_batch.py**.
|
|
12
|
+
|
|
13
|
+
That said, this script (run_detector.py) is a good way to test our detector on a handful of images
|
|
14
|
+
and get super-satisfying, graphical results.
|
|
15
|
+
|
|
16
|
+
If you would like to *not* use the GPU, set the environment variable CUDA_VISIBLE_DEVICES to "-1".
|
|
17
|
+
|
|
18
|
+
This script will only consider detections with > 0.005 confidence at all times.
|
|
19
|
+
The threshold you provide is only for rendering the results. If you need to
|
|
20
|
+
see lower-confidence detections, you can change DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD.
|
|
21
|
+
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
#%% Constants, imports, environment
|
|
25
|
+
|
|
26
|
+
import argparse
|
|
27
|
+
import os
|
|
28
|
+
import statistics
|
|
29
|
+
import sys
|
|
30
|
+
import time
|
|
31
|
+
import json
|
|
32
|
+
import warnings
|
|
33
|
+
import tempfile
|
|
34
|
+
import zipfile
|
|
35
|
+
|
|
36
|
+
import humanfriendly
|
|
37
|
+
from tqdm import tqdm
|
|
38
|
+
|
|
39
|
+
from megadetector.utils import path_utils as path_utils
|
|
40
|
+
from megadetector.visualization import visualization_utils as vis_utils
|
|
41
|
+
from megadetector.utils.url_utils import download_url
|
|
42
|
+
from megadetector.utils.ct_utils import parse_kvp_list
|
|
43
|
+
from megadetector.utils.path_utils import compute_file_hash
|
|
44
|
+
|
|
45
|
+
# Ignore all "PIL cannot read EXIF metainfo for the images" warnings.  The message
# argument to filterwarnings() is a regex matched against the start of the warning text.
warnings.filterwarnings('ignore', '(Possibly )?corrupt EXIF data', UserWarning)

# Metadata Warning, tag 256 had too many entries: 42, expected 1
warnings.filterwarnings('ignore', 'Metadata warning', UserWarning)

# Numpy FutureWarnings from tensorflow import.  Note that with no message/module
# restriction this silences *every* FutureWarning in the process, not only those
# coming from numpy/tensorflow.
warnings.filterwarnings('ignore', category=FutureWarning)

# String constants used for consistent reporting of processing errors
FAILURE_INFER = 'inference failure'
FAILURE_IMAGE_OPEN = 'image access failure'

# Number of decimal places to round to for confidence and bbox coordinates
CONF_DIGITS = 3
COORD_DIGITS = 4

# Label mapping for MegaDetector: category ID (as a string) -> category name
DEFAULT_DETECTOR_LABEL_MAP = {
    '1': 'animal',
    '2': 'person',
    '3': 'vehicle' # available in megadetector v4+
}

# Should we allow classes that don't look anything like the MegaDetector classes?
#
# This flag needs to get set if you want to, for example, run an off-the-shelf
# YOLO model with this package.
#
# By default, we error if we see unfamiliar classes.
#
# TODO: the use of a global variable to manage this was fine when this was really
# experimental, but this is really sloppy now that we actually use this code for
# models other than MegaDetector.
USE_MODEL_NATIVE_CLASSES = False

# Detection threshold to recommend to callers when all other mechanisms for choosing
# a model-specific threshold fail
fallback_detection_threshold = 0.2
|
|
84
|
+
|
|
85
|
+
# Maps a variety of strings that might occur in filenames to canonical version numbers.
#
# Order matters here: get_detector_version_from_filename() tests each key as a
# substring of the lowercased model filename, in this dict's insertion order, and by
# default uses the first key that matches.  For example, a filename containing
# "mdv5b" also contains "mdv5"; because 'mdv5b' appears first, it resolves to v5b
# rather than to the 'mdv5' default (v5a).
#
# Note that the original v5 filenames (v5x.0.0) map to the re-released v5x.0.1 versions.
#
# NOTE(review): short keys such as 'v2', 'v3', 'v4', 'md5', and 'default' can also
# occur in filenames unrelated to these models; confirm that is acceptable.
# NOTE(review): unlike redwood/spruce/cedar/larch, 'sorrel' has no bare alias in the
# "less specific" section; confirm whether that is intentional.
model_string_to_model_version = {

    # Specific model versions that might be expressed in a variety of ways
    'mdv2':'v2.0.0',
    'mdv3':'v3.0.0',
    'mdv4':'v4.1.0',
    'mdv5a':'v5a.0.1',
    'mdv5b':'v5b.0.1',

    'v2':'v2.0.0',
    'v3':'v3.0.0',
    'v4':'v4.1.0',
    'v4.1':'v4.1.0',
    'v5a.0.0':'v5a.0.1',
    'v5b.0.0':'v5b.0.1',
    'v5a.0.1':'v5a.0.1',
    'v5b.0.1':'v5b.0.1',

    'md1000-redwood':'v1000.0.0-redwood',
    'md1000-cedar':'v1000.0.0-cedar',
    'md1000-larch':'v1000.0.0-larch',
    'md1000-sorrel':'v1000.0.0-sorrel',
    'md1000-spruce':'v1000.0.0-spruce',

    'mdv1000-redwood':'v1000.0.0-redwood',
    'mdv1000-cedar':'v1000.0.0-cedar',
    'mdv1000-larch':'v1000.0.0-larch',
    'mdv1000-sorrel':'v1000.0.0-sorrel',
    'mdv1000-spruce':'v1000.0.0-spruce',

    'v1000-redwood':'v1000.0.0-redwood',
    'v1000-cedar':'v1000.0.0-cedar',
    'v1000-larch':'v1000.0.0-larch',
    'v1000-sorrel':'v1000.0.0-sorrel',
    'v1000-spruce':'v1000.0.0-spruce',

    # Arguably less specific model versions
    'redwood':'v1000.0.0-redwood',
    'spruce':'v1000.0.0-spruce',
    'cedar':'v1000.0.0-cedar',
    'larch':'v1000.0.0-larch',

    # Opinionated defaults; these come last so that any more specific key above wins
    'mdv5':'v5a.0.1',
    'md5':'v5a.0.1',
    'mdv1000':'v1000.0.0-redwood',
    'md1000':'v1000.0.0-redwood',
    'default':'v5a.0.1',
    'megadetector':'v5a.0.1',
}
|
|
138
|
+
|
|
139
|
+
# Base URL from which the MDv1000 model weights are downloaded (the v1000 entries in
# known_models are built from it).  It can be overridden via the MD_MODEL_URL_BASE
# environment variable, e.g. to point at a local mirror served with:
#
# python -m http.server 8181
model_url_base = 'https://github.com/agentmorris/MegaDetector/releases/download/v1000.0/'
assert model_url_base.endswith('/')

# An empty or whitespace-only MD_MODEL_URL_BASE is treated as "not set"; otherwise the
# base URL would collapse to '/' and every v1000 download URL would be broken.
# Surrounding whitespace in a non-empty value is also removed.
if os.environ.get('MD_MODEL_URL_BASE', '').strip():
    model_url_base = os.environ['MD_MODEL_URL_BASE'].strip()
    print('Model URL base provided via environment variable: {}'.format(
        model_url_base
    ))
    # Model filenames are appended directly, so guarantee a trailing separator
    if not model_url_base.endswith('/'):
        model_url_base += '/'
|
|
150
|
+
|
|
151
|
+
# Maps canonical model version numbers to metadata.
#
# Fields used below:
#
# * url: download location of the weights.  Only the v1000 entries are built from
#   model_url_base, so MD_MODEL_URL_BASE affects only those.
# * typical_detection_threshold / conservative_detection_threshold: recommended
#   confidence thresholds.  NOTE(review): the spruce/larch/cedar/sorrel entries have no
#   thresholds; callers presumably fall back to fallback_detection_threshold (confirm).
# * model_type: 'tf' or 'yolov5' where present; absent for the v1000 entries.
# * image_size: inference size, only specified for the v5 models.
# * normalized_typical_inference_speed: speed relative to MDv5 (the v5 entries are 1.0);
#   estimate_md_images_per_second() multiplies this by a per-device MDv5 benchmark.
# * md5: presumably the expected MD5 checksum of the weights file (compute_file_hash is
#   imported by this module) -- confirm against the download/verification code.
known_models = {
    'v2.0.0':
    {
        'url':'https://lila.science/public/models/megadetector/megadetector_v2.pb',
        'typical_detection_threshold':0.8,
        'conservative_detection_threshold':0.3,
        'model_type':'tf',
        'normalized_typical_inference_speed':1.0/3.5
    },
    'v3.0.0':
    {
        'url':'https://lila.science/public/models/megadetector/megadetector_v3.pb',
        'typical_detection_threshold':0.8,
        'conservative_detection_threshold':0.3,
        'model_type':'tf',
        'normalized_typical_inference_speed':1.0/3.5
    },
    'v4.1.0':
    {
        'url':'https://github.com/agentmorris/MegaDetector/releases/download/v4.1/md_v4.1.0.pb',
        'typical_detection_threshold':0.8,
        'conservative_detection_threshold':0.3,
        'model_type':'tf',
        'normalized_typical_inference_speed':1.0/3.5
    },
    'v5a.0.0':
    {
        'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5a.0.0.pt',
        'typical_detection_threshold':0.2,
        'conservative_detection_threshold':0.05,
        'image_size':1280,
        'model_type':'yolov5',
        'normalized_typical_inference_speed':1.0,
        'md5':'ec1d7603ec8cf642d6e0cd008ba2be8c'
    },
    'v5b.0.0':
    {
        'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.0.pt',
        'typical_detection_threshold':0.2,
        'conservative_detection_threshold':0.05,
        'image_size':1280,
        'model_type':'yolov5',
        'normalized_typical_inference_speed':1.0,
        'md5':'bc235e73f53c5c95e66ea0d1b2cbf542'
    },
    # v5x.0.1: the re-released v5 models (filenames containing v5x.0.0 map here)
    'v5a.0.1':
    {
        'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5a.0.1.pt',
        'typical_detection_threshold':0.2,
        'conservative_detection_threshold':0.05,
        'image_size':1280,
        'model_type':'yolov5',
        'normalized_typical_inference_speed':1.0,
        'md5':'60f8e7ec1308554df258ed1f4040bc4f'
    },
    'v5b.0.1':
    {
        'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.1.pt',
        'typical_detection_threshold':0.2,
        'conservative_detection_threshold':0.05,
        'image_size':1280,
        'model_type':'yolov5',
        'normalized_typical_inference_speed':1.0,
        'md5':'f17ed6fedfac2e403606a08c89984905'
    },
    'v1000.0.0-redwood':
    {
        'url':model_url_base + 'md_v1000.0.0-redwood.pt',
        'normalized_typical_inference_speed':1.0,
        'md5':'74474b3aec9cf1a990da38b37ddf9197',
        'typical_detection_threshold':0.3
    },
    'v1000.0.0-spruce':
    {
        'url':model_url_base + 'md_v1000.0.0-spruce.pt',
        'normalized_typical_inference_speed':12.7,
        'md5':'1c9d1d2b3ba54931881471fdd508e6f2'
    },
    'v1000.0.0-larch':
    {
        'url':model_url_base + 'md_v1000.0.0-larch.pt',
        'normalized_typical_inference_speed':2.4,
        'md5':'cab94ebd190c2278e12fb70ffd548b6d'
    },
    'v1000.0.0-cedar':
    {
        'url':model_url_base + 'md_v1000.0.0-cedar.pt',
        'normalized_typical_inference_speed':2.0,
        'md5':'3d6472c9b95ba687b59ebe255f7c576b'
    },
    'v1000.0.0-sorrel':
    {
        'url':model_url_base + 'md_v1000.0.0-sorrel.pt',
        'normalized_typical_inference_speed':7.0,
        'md5':'4339a2c8af7a381f18ded7ac2a4df03e'
    }
}
|
|
249
|
+
|
|
250
|
+
# Confidence threshold used only for rendering boxes; taken from MDv5's typical threshold
DEFAULT_RENDERING_CONFIDENCE_THRESHOLD = known_models['v5a.0.0']['typical_detection_threshold']

# Per the module docstring, detections at or below this confidence are never
# considered, regardless of the rendering threshold
DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD = 0.005

# Rendering defaults
DEFAULT_BOX_THICKNESS = 4
DEFAULT_BOX_EXPANSION = 0
DEFAULT_LABEL_FONT_SIZE = 16

# Presumably inserted into the filenames of rendered output images -- confirm at use site
DETECTION_FILENAME_INSERT = '_detections'

# Approximate inference speeds (in images per second) for MDv5 based on
# benchmarks, only used for reporting very coarse expectations about inference time.
#
# estimate_md_images_per_second() looks for each key as a (case-sensitive) substring
# of the device name and uses the first key, in dict order, that matches.
device_token_to_mdv5_inference_speed = {
    '4090':17.6,
    '3090':11.4,
    '3080':9.5,
    '3050':4.2,
    'P2000':2.1,
    # These are written this way because they're MDv4 benchmarks, and MDv5
    # is around 3.5x faster than MDv4.
    'V100':2.79*3.5,
    '2080':2.3*3.5,
    '2060':1.6*3.5
}
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
#%% Utility functions
|
|
275
|
+
|
|
276
|
+
def get_detector_metadata_from_version_string(detector_version):
    """
    Given a MegaDetector version string (e.g. "v4.1.0"), returns the metadata for
    the model. Used for writing standard defaults to batch output files.

    Args:
        detector_version (str): a detection version string, e.g. "v4.1.0", which you
            can extract from a filename using get_detector_version_from_filename()

    Returns:
        dict: metadata for this model, suitable for writing to a MD output file.  For a
        known version this is a copy of the known_models entry plus a
        'megadetector_version' key; for an unknown version it is a generic default with
        'megadetector_version' set to 'unknown'.  A new dict is returned on every call,
        so modifying it does not affect known_models.
    """

    if detector_version not in known_models:
        print('Warning: no metadata for unknown detector version {}'.format(detector_version))
        default_detector_metadata = {
            'megadetector_version':'unknown',
            'typical_detection_threshold':0.2,
            'conservative_detection_threshold':0.1
        }
        return default_detector_metadata

    # Copy the entry rather than returning (and writing into) the shared module-level
    # table; otherwise a caller that edits the result would silently alter known_models
    # for every later caller.  The entries hold only scalars, so a shallow copy suffices.
    to_return = dict(known_models[detector_version])
    to_return['megadetector_version'] = detector_version
    return to_return
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def get_detector_version_from_filename(detector_filename,
                                       accept_first_match=True,
                                       verbose=False):
    r"""
    Gets the canonical version number string of a detector from the model filename.

    [detector_filename] will almost always end with one of the following:

    * megadetector_v2.pb
    * megadetector_v3.pb
    * megadetector_v4.1 (not produced by run_detector_batch.py, only found in output files from
      the deprecated Azure Batch API)
    * md_v4.1.0.pb
    * md_v5a.0.0.pt
    * md_v5b.0.0.pt

    With the default accept_first_match=True, this function identifies the version number
    as "v2.0.0", "v3.0.0", "v4.1.0", "v4.1.0", "v5a.0.1", and "v5b.0.1", respectively (the
    v5 models were re-released as v5x.0.1, so v5x.0.0 filenames map to the re-released
    versions).  See model_string_to_model_version for the recognized substrings and
    known_models for the list of valid version numbers.

    Matching is a substring search against the lowercased basename.  Several substrings
    may match one filename (e.g. "v4" and "v4.1"); they are considered in the order they
    appear in model_string_to_model_version.

    Args:
        detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
        accept_first_match (bool, optional): if the matching substrings imply more than one
            canonical version, return the version implied by the first match; otherwise
            return the string "multiple".  Matches that all imply the same canonical
            version are not considered ambiguous.
        verbose (bool, optional): enable additional debug output

    Returns:
        str: a canonical detector version string (e.g. "v5a.0.1"), "unknown" if no
        recognized substring occurs in the filename, or "multiple" if the filename is
        ambiguous and accept_first_match is False
    """

    fn = os.path.basename(detector_filename).lower()

    # Recognized substrings present in the filename, in mapping order (order matters)
    matches = [s for s in model_string_to_model_version if s in fn]

    if len(matches) == 0:
        return 'unknown'

    # Distinct canonical versions implied by those matches, preserving first-seen order;
    # element 0 is always the version implied by the first match.
    candidate_versions = list(dict.fromkeys(
        model_string_to_model_version[s] for s in matches))

    if len(candidate_versions) == 1 or accept_first_match:
        return candidate_versions[0]

    if verbose:
        print('Warning: multiple MegaDetector versions for model file {}:'.format(detector_filename))
        for s in matches:
            print(s)
    return 'multiple'
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def get_detector_version_from_model_file(detector_filename,verbose=False):
    """
    Gets the canonical detection version from a model file, preferably by reading it
    from the file itself, otherwise based on the filename.

    Embedded metadata is only read from files whose names end in .pt or .zip.  If both
    the embedded metadata and the filename yield a version and they disagree, a warning
    is printed (except for the known v5x.0.0 -> v5x.0.1 re-release case) and the
    embedded version is returned.

    Args:
        detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
        verbose (bool, optional): enable additional debug output

    Returns:
        str: a canonical detector version string, e.g. "v5a.0.0", or None if no version
        could be determined from either the file contents or the filename
    """

    # Try to extract a version string from the filename
    version_string_based_on_filename = get_detector_version_from_filename(
        detector_filename, verbose=verbose)
    if version_string_based_on_filename == 'unknown':
        version_string_based_on_filename = None

    # Try to extract a version string from the file itself; currently this is only
    # a thing for PyTorch models
    version_string_based_on_model_file = None

    if detector_filename.endswith(('.pt','.zip')):

        # Imported here rather than at module level, presumably so the PyTorch detector
        # module is only loaded when a model file actually needs to be read
        from megadetector.detection.pytorch_detector import \
            read_metadata_from_megadetector_model_file
        metadata = read_metadata_from_megadetector_model_file(detector_filename,verbose=verbose)

        # isinstance() is False for None, so this also covers "no metadata found"
        if isinstance(metadata,dict):

            if 'metadata_format_version' not in metadata or \
                not isinstance(metadata['metadata_format_version'],float):

                print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
                      'but it doesn\'t have a valid format version number')

            elif 'model_version_string' not in metadata or \
                not isinstance(metadata['model_version_string'],str):

                print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
                      'but it doesn\'t have a valid model version string')

            else:

                version_string_based_on_model_file = metadata['model_version_string']

                # Unknown versions are still returned, just flagged
                if version_string_based_on_model_file not in known_models:
                    print('Warning: unknown model version:\n\n{}\n\n...specified in file:\n\n{}'.format(
                        version_string_based_on_model_file,os.path.basename(detector_filename)))

            # ...if there's metadata in this file

    # ...if this looks like a PyTorch file

    # If we got versions strings from the filename *and* the model file, the model
    # file wins
    if (version_string_based_on_filename is not None) and \
       (version_string_based_on_model_file is not None):

        if version_string_based_on_filename != version_string_based_on_model_file:
            # This is a one-off special case where models were re-released with different
            # filenames: v5x.0.0 filenames map to v5x.0.1, while the embedded metadata may
            # still say v5x.0.0
            if (version_string_based_on_filename in ('v5a.0.1','v5b.0.1')) and \
               (version_string_based_on_model_file in ('v5a.0.0','v5b.0.0')):
                pass
            else:
                print(
                    'Warning: model version string in file:' + \
                    '\n\n{}\n\n...is:\n\n{}\n\n...but the filename implies:\n\n{}'.format(
                        os.path.basename(detector_filename),
                        version_string_based_on_model_file,
                        version_string_based_on_filename))

        return version_string_based_on_model_file

    # If we got version string from neither the filename nor the model file...
    if (version_string_based_on_filename is None) and \
       (version_string_based_on_model_file is None):

        print('Warning: could not determine model version string for model file {}'.format(
            detector_filename))
        return None

    elif version_string_based_on_filename is not None:

        return version_string_based_on_filename

    else:

        assert version_string_based_on_model_file is not None
        return version_string_based_on_model_file

# ...def get_detector_version_from_model_file(...)
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def estimate_md_images_per_second(model_file, device_name=None):
    r"""
    Estimates how fast MegaDetector will run on a particular device, based on benchmarks.

    If no device name is given, the current CUDA device is queried.  The estimate is
    (MDv5 images/second on this device) * (this model's speed relative to MDv5, taken from
    known_models).  Estimates are only available for a small handful of GPUs, and the
    device lookup is absurdly simple: e.g. if the string "4090" appears in the device
    name, congratulations, you have an RTX 4090.

    Args:
        model_file (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
        device_name (str, optional): device name, e.g. blah-blah-4090-blah-blah

    Returns:
        float: the approximate number of images this model version can process on this
        device per second, or None if the device name can't be queried, the model version
        is unknown, or no speed ratio / device benchmark is available
    """

    # Which device are we on?  (Only query CUDA if the caller didn't tell us.)
    if device_name is None:
        try:
            import torch
            device_name = torch.cuda.get_device_name()
        except Exception as e:
            print('Error querying device name: {}'.format(e))
            return None

    # How fast is this model relative to MDv5?
    model_version = get_detector_version_from_model_file(model_file)
    if model_version not in known_models:
        print('Could not estimate inference speed: error determining model version for model file {}'.format(
            model_file))
        return None

    speed_ratio = known_models[model_version].get('normalized_typical_inference_speed')
    if speed_ratio is None:
        print('No speed ratio available for model type {}'.format(model_version))
        return None

    # How fast would MDv5 run here?  The first table token (in dict order) that
    # appears in the device name determines the baseline.
    baseline_speed = next(
        (speed for token, speed in device_token_to_mdv5_inference_speed.items()
         if token in device_name),
        None)
    if baseline_speed is None:
        print('No baseline speed estimate available for device {}'.format(device_name))
        return None

    return speed_ratio * baseline_speed
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
def get_typical_confidence_threshold_from_results(results):
    """
    Given the .json data loaded from a MD results file, returns a typical confidence
    threshold based on the detector version.

    The threshold is resolved in this order:

    1. results['info']['detector_metadata']['typical_detection_threshold'], if present
    2. if the results don't record which detector was used, the MDv5 default
    3. otherwise, the threshold associated with the detector version inferred from
       results['info']['detector']
    4. if none of the above yields a value, fallback_detection_threshold

    Args:
        results (dict or str): a dict of MD results, as it would be loaded from a MD results .json
            file, or a .json filename

    Returns:
        float: a sensible default threshold for this model
    """

    # Load results if necessary
    if isinstance(results, str):
        with open(results, 'r', encoding='utf-8') as f:
            results = json.load(f)

    # Tolerate a missing "info" section or null fields, rather than failing with
    # KeyError/TypeError; those cases fall through to the version-based defaults.
    info = results.get('info') or {}
    file_metadata = info.get('detector_metadata') or {}

    default_threshold = None

    # Best case: the .json file tells us the default threshold
    if 'typical_detection_threshold' in file_metadata:
        default_threshold = file_metadata['typical_detection_threshold']

    # Worst case: we don't even know what detector this is
    elif info.get('detector') is None:
        print('Warning: detector version not available in results file, using MDv5 defaults')
        detector_metadata = get_detector_metadata_from_version_string('v5a.0.0')
        default_threshold = detector_metadata['typical_detection_threshold']

    # We know what detector this is, but it doesn't have a default threshold
    # in the .json file
    else:
        print('Warning: detector metadata not available in results file, inferring from MD version')
        try:
            detector_filename = info['detector']
            detector_version = get_detector_version_from_filename(detector_filename)
            detector_metadata = get_detector_metadata_from_version_string(detector_version)
            if 'typical_detection_threshold' in detector_metadata:
                default_threshold = detector_metadata['typical_detection_threshold']
        except Exception as e:
            # Best-effort inference: fall through to the fallback threshold below,
            # but report why inference failed instead of swallowing it silently
            print('Warning: could not infer detector metadata for detector {}: {}'.format(
                info.get('detector'), e))

    if default_threshold is None:
        print('Could not determine threshold, using fallback threshold of {}'.format(
            fallback_detection_threshold))
        default_threshold = fallback_detection_threshold

    return default_threshold
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def is_gpu_available(model_file):
    r"""
    Determines whether a GPU is available, importing PyTorch or TF depending on the extension
    of model_file. Does not actually load model_file, just uses that to determine how to check
    for GPU availability (PT vs. TF).

    For PyTorch, a CUDA device counts as a GPU. If no CUDA device is available, Apple
    Metal Performance Shaders (MPS) also count, when the installed torch supports them.

    Args:
        model_file (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt

    Returns:
        bool: whether a GPU is available
    """

    if model_file.endswith('.pb'):
        import tensorflow.compat.v1 as tf # type: ignore
        gpu_available = tf.test.is_gpu_available()
        print('TensorFlow version:', tf.__version__)
        print('tf.test.is_gpu_available:', gpu_available)
        return gpu_available

    if not model_file.endswith('.pt'):
        print('Warning: could not determine environment from model file name, assuming PyTorch')

    import torch
    gpu_available = torch.cuda.is_available()
    print('PyTorch reports {} available CUDA devices'.format(torch.cuda.device_count()))
    if not gpu_available:
        try:
            # mps backend only available in torch >= 1.12.0; on older versions these
            # attribute lookups raise AttributeError. is_built is a function and must be
            # called; testing the function object itself would always be truthy.
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                gpu_available = True
                print('PyTorch reports Metal Performance Shaders are available')
        except AttributeError:
            pass
    return gpu_available
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def load_detector(model_file,
                  force_cpu=False,
                  force_model_download=False,
                  detector_options=None,
                  verbose=False):
    r"""
    Loads a TF or PT detector, depending on the extension of model_file.

    Args:
        model_file (str): model filename (e.g. c:/x/z/md_v5a.0.0.pt) or known model
            name (e.g. "MDV5A")
        force_cpu (bool, optional): force the model to run on the CPU even if a GPU
            is available
        force_model_download (bool, optional): force downloading the model file if
            a named model (e.g. "MDV5A") is supplied, even if the local file already
            exists
        detector_options (dict, optional): key/value pairs that are interpreted differently
            by different detectors; the caller's dict is not modified
        verbose (bool, optional): enable additional debug output

    Returns:
        object: loaded detector object

    Raises:
        ValueError: if a named model could not be downloaded/validated, if force_cpu is
            requested for a TF model, or if the model format is not recognized
    """

    requested_model = model_file

    # Possibly automatically download the model
    model_file = try_download_known_detector(model_file,
                                             force_download=force_model_download)

    # try_download_known_detector returns None when a known model failed to download
    # or validate; fail clearly here rather than with an AttributeError on .endswith()
    if model_file is None:
        raise ValueError('Could not download or validate model {}'.format(requested_model))

    print('GPU available: {}'.format(is_gpu_available(model_file)))

    start_time = time.time()

    if model_file.endswith('.pb'):

        from megadetector.detection.tf_detector import TFDetector
        if force_cpu:
            raise ValueError('force_cpu is not currently supported for TF detectors, ' + \
                             'use CUDA_VISIBLE_DEVICES=-1 instead')
        detector = TFDetector(model_file, detector_options)

    elif model_file.endswith('.pt'):

        from megadetector.detection.pytorch_detector import PTDetector

        # Prepare options specific to the PTDetector class. Work on a shallow copy so
        # the keys added below don't leak back into the caller's dict.
        detector_options = {} if detector_options is None else dict(detector_options)
        if 'force_cpu' in detector_options:
            # An explicit detector_options['force_cpu'] takes precedence over force_cpu
            if force_cpu != detector_options['force_cpu']:
                print('Warning: over-riding force_cpu parameter ({}) based on detector_options ({})'.format(
                    force_cpu,detector_options['force_cpu']))
        else:
            detector_options['force_cpu'] = force_cpu
        detector_options['use_model_native_classes'] = USE_MODEL_NATIVE_CLASSES
        detector = PTDetector(model_file, detector_options, verbose=verbose)

    else:

        raise ValueError('Unrecognized model format: {}'.format(model_file))

    elapsed = time.time() - start_time

    if verbose:
        print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))

    return detector

# ...def load_detector(...)
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
#%% Main function
|
|
661
|
+
|
|
662
|
+
def load_and_run_detector(model_file,
                          image_file_names,
                          output_dir,
                          render_confidence_threshold=DEFAULT_RENDERING_CONFIDENCE_THRESHOLD,
                          crop_images=False,
                          box_thickness=DEFAULT_BOX_THICKNESS,
                          box_expansion=DEFAULT_BOX_EXPANSION,
                          image_size=None,
                          label_font_size=DEFAULT_LABEL_FONT_SIZE,
                          augment=False,
                          force_model_download=False,
                          detector_options=None,
                          verbose=False):
    r"""
    Loads and runs a detector on target images, and visualizes the results.

    Args:
        model_file (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt, or a known model
            string, e.g. "MDV5A"
        image_file_names (list): list of absolute paths to process
        output_dir (str): folder to write visualized images to
        render_confidence_threshold (float, optional): only render boxes for detections
            above this threshold
        crop_images (bool, optional): whether to crop detected objects to individual images
            (default is to render images with boxes, rather than cropping)
        box_thickness (float, optional): thickness in pixels for box rendering
        box_expansion (float, optional): box expansion in pixels
        image_size (tuple, optional): image size to use for inference, only mess with this
            if (a) you're using a model other than MegaDetector or (b) you know what you're
            doing
        label_font_size (float, optional): font size to use for displaying class names
            and confidence values in the rendered images
        augment (bool, optional): enable (implementation-specific) image augmentation
        force_model_download (bool, optional): force downloading the model file if
            a named model (e.g. "MDV5A") is supplied, even if the local file already
            exists
        detector_options (dict, optional): key/value pairs that are interpreted differently
            by different detectors
        verbose (bool, optional): enable additional debug output

    Returns:
        None. Rendered (or cropped) images are written to output_dir, and per-image
        load/inference timing statistics are printed.
    """

    if len(image_file_names) == 0:
        print('Warning: no files available')
        return

    # Possibly automatically download the model
    model_file = try_download_known_detector(model_file,
                                             force_download=force_model_download,
                                             verbose=verbose)

    detector = load_detector(model_file,
                             detector_options=detector_options,
                             verbose=verbose)

    detection_results = []
    time_load = []
    time_infer = []

    # Dictionary mapping output file names to a collision-avoidance count.
    #
    # Since we'll be writing a bunch of files to the same folder, we rename
    # as necessary to avoid collisions.
    output_filename_collision_counts = {}

    def input_file_to_detection_file(fn, crop_index=-1):
        """
        Creates unique file names for output files.

        This function does 3 things:
        1) If the --crop flag is used, then each input image may produce several output
            crops. For example, if foo.jpg has 3 detections, then this function should
            get called 3 times, with crop_index taking on 0, 1, then 2. Each time, this
            function appends crop_index to the filename, resulting in
                foo_crop00_detections.jpg
                foo_crop01_detections.jpg
                foo_crop02_detections.jpg

        2) If the --recursive flag is used, then the same file (base)name may appear
            multiple times. However, we output into a single flat folder. To avoid
            filename collisions, we prepend an integer prefix to duplicate filenames:
                foo_crop00_detections.jpg
                0000_foo_crop00_detections.jpg
                0001_foo_crop00_detections.jpg

        3) Prepends the output directory:
                out_dir/foo_crop00_detections.jpg

        Args:
            fn: str, filename
            crop_index: int, crop number

        Returns: output file path
        """

        fn = os.path.basename(fn).lower()
        # Outputs are always written as .jpg, regardless of the input extension
        name, _ = os.path.splitext(fn)
        if crop_index >= 0:
            name += '_crop{:0>2d}'.format(crop_index)
        fn = '{}{}{}'.format(name, DETECTION_FILENAME_INSERT, '.jpg')
        if fn in output_filename_collision_counts:
            n_collisions = output_filename_collision_counts[fn]
            fn_original = fn
            fn = '{:0>4d}'.format(n_collisions) + '_' + fn
            output_filename_collision_counts[fn_original] += 1
        else:
            output_filename_collision_counts[fn] = 0
        fn = os.path.join(output_dir, fn)
        return fn

    # ...def input_file_to_detection_file()

    for im_file in tqdm(image_file_names):

        try:
            start_time = time.time()

            image = vis_utils.load_image(im_file)

            elapsed = time.time() - start_time
            time_load.append(elapsed)

        except Exception as e:
            print('Image {} cannot be loaded, error: {}'.format(im_file, str(e)))
            result = {
                'file': im_file,
                'failure': FAILURE_IMAGE_OPEN
            }
            detection_results.append(result)
            continue

        try:
            start_time = time.time()

            result = detector.generate_detections_one_image(
                image,
                im_file,
                detection_threshold=DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
                image_size=image_size,
                augment=augment)
            detection_results.append(result)

            elapsed = time.time() - start_time
            time_infer.append(elapsed)

        except Exception as e:
            print('An error occurred while running the detector on image {}: {}'.format(
                im_file, str(e)))
            continue

        try:
            if crop_images:

                images_cropped = vis_utils.crop_image(result['detections'], image,
                                    confidence_threshold=render_confidence_threshold,
                                    expansion=box_expansion)

                for i_crop, cropped_image in enumerate(images_cropped):
                    output_full_path = input_file_to_detection_file(im_file, i_crop)
                    cropped_image.save(output_full_path)

            else:

                # Image is modified in place
                vis_utils.render_detection_bounding_boxes(result['detections'], image,
                            label_map=DEFAULT_DETECTOR_LABEL_MAP,
                            confidence_threshold=render_confidence_threshold,
                            thickness=box_thickness, expansion=box_expansion,
                            label_font_size=label_font_size,
                            box_sort_order='confidence')
                output_full_path = input_file_to_detection_file(im_file)
                image.save(output_full_path)

        except Exception as e:
            print('Visualizing results on the image {} failed. Exception: {}'.format(im_file, e))
            continue

    # ...for each image

    # statistics.mean() raises StatisticsError on an empty list, which happens when
    # every image failed to load or every inference call failed
    if len(time_load) == 0 or len(time_infer) == 0:
        print('No images were processed successfully; timing statistics are not available')
        return

    ave_time_load = statistics.mean(time_load)
    ave_time_infer = statistics.mean(time_infer)
    if len(time_load) > 1 and len(time_infer) > 1:
        std_dev_time_load = humanfriendly.format_timespan(statistics.stdev(time_load))
        std_dev_time_infer = humanfriendly.format_timespan(statistics.stdev(time_infer))
    else:
        std_dev_time_load = 'not available (<=1 image processed)'
        std_dev_time_infer = 'not available (<=1 image processed)'
    print('On average, for each image,')
    print('- loading took {}, std dev is {}'.format(humanfriendly.format_timespan(ave_time_load),
                                                    std_dev_time_load))
    print('- inference took {}, std dev is {}'.format(humanfriendly.format_timespan(ave_time_infer),
                                                      std_dev_time_infer))

# ...def load_and_run_detector()
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
def _validate_zip_file(file_path, file_description='file'):
|
|
858
|
+
"""
|
|
859
|
+
Validates that a .pt file is a valid zip file.
|
|
860
|
+
|
|
861
|
+
Args:
|
|
862
|
+
file_path (str): path to the file to validate
|
|
863
|
+
file_description (str): descriptive string for error messages
|
|
864
|
+
|
|
865
|
+
Returns:
|
|
866
|
+
bool: True if valid, False otherwise
|
|
867
|
+
"""
|
|
868
|
+
try:
|
|
869
|
+
with zipfile.ZipFile(file_path, 'r') as zipf:
|
|
870
|
+
corrupt_file = zipf.testzip()
|
|
871
|
+
if corrupt_file is not None:
|
|
872
|
+
print('{} {} contains at least one corrupt file: {}'.format(
|
|
873
|
+
file_description.capitalize(), file_path, corrupt_file))
|
|
874
|
+
return False
|
|
875
|
+
return True
|
|
876
|
+
except (zipfile.BadZipFile, zipfile.LargeZipFile) as e:
|
|
877
|
+
print('{} {} appears to be corrupted (bad zip): {}'.format(
|
|
878
|
+
file_description.capitalize(), file_path, str(e)))
|
|
879
|
+
return False
|
|
880
|
+
except Exception as e:
|
|
881
|
+
print('Error validating {}: {}'.format(file_description, str(e)))
|
|
882
|
+
return False
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
def _validate_md5_hash(file_path, expected_hash, file_description='file'):
|
|
886
|
+
"""
|
|
887
|
+
Validates that a file has the expected MD5 hash.
|
|
888
|
+
|
|
889
|
+
Args:
|
|
890
|
+
file_path (str): path to the file to validate
|
|
891
|
+
expected_hash (str): expected MD5 hash
|
|
892
|
+
file_description (str): descriptive string for error messages
|
|
893
|
+
|
|
894
|
+
Returns:
|
|
895
|
+
bool: True if hash matches, False otherwise
|
|
896
|
+
"""
|
|
897
|
+
try:
|
|
898
|
+
actual_hash = compute_file_hash(file_path, algorithm='md5').lower()
|
|
899
|
+
expected_hash = expected_hash.lower()
|
|
900
|
+
if actual_hash != expected_hash:
|
|
901
|
+
print('{} {} has incorrect hash. Expected: {}, Actual: {}'.format(
|
|
902
|
+
file_description.capitalize(), file_path, expected_hash, actual_hash))
|
|
903
|
+
return False
|
|
904
|
+
return True
|
|
905
|
+
except Exception as e:
|
|
906
|
+
print('Error computing hash for {}: {}'.format(file_description, str(e)))
|
|
907
|
+
return False
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
def _download_model(model_name,force_download=False):
    """
    Downloads one of the known models to local temp space if it hasn't already been downloaded.

    If a local .pt file already exists (and force_download is False), it is validated:
    it must be a valid zip archive and, when known_models provides an MD5 hash, the hash
    must match. If validation fails, the file is deleted and re-downloaded. A newly
    downloaded .pt file goes through the same checks. .pb files are not validated.

    Args:
        model_name (str): a key in known_models (matched case-insensitively), e.g. a
            canonical version string such as "v5a.0.0"
        force_download (bool, optional): whether to download the model even if the local target
            file already exists

    Returns:
        str: the local path of the model file, or None if model_name is not a known
        model or the downloaded file failed validation (in which case it is deleted)

    Raises:
        Exception: re-raises whatever download_url raises if the download itself fails
    """

    model_tempdir = os.path.join(tempfile.gettempdir(), 'megadetector_models')
    os.makedirs(model_tempdir,exist_ok=True)

    # This is a lazy fix to an issue... if multiple users run this script, the
    # "megadetector_models" folder is owned by the first person who creates it, and others
    # can't write to it. I could create uniquely-named folders, but I philosophically prefer
    # to put all the individual UUID-named folders within a larger folder, so as to be a
    # good tempdir citizen. So, the lazy fix is to make this world-writable.
    #
    # NOTE(review): without the sticky bit (0o1777, as on /tmp), any local user can
    # replace another user's model file in this folder. Model files are presumably
    # loaded with torch.load (pickle-based), so on shared machines this is a potential
    # code-execution vector for models that have no MD5 hash in known_models.
    try:
        os.chmod(model_tempdir,0o777)
    except Exception:
        pass
    if model_name.lower() not in known_models:
        print('Unrecognized downloadable model {}'.format(model_name))
        return None

    model_info = known_models[model_name.lower()]
    url = model_info['url']
    destination_filename = os.path.join(model_tempdir,url.split('/')[-1])

    # The MD5 hash to validate against, or None if this model doesn't have a
    # (non-blank) one. Stripped so padded values still compare equal.
    expected_md5 = None
    if ('md5' in model_info) and \
       (model_info['md5'] is not None) and \
       (len(model_info['md5'].strip()) > 0):
        expected_md5 = model_info['md5'].strip()

    def _is_valid_model_file(filename, file_description):
        """Returns True if filename is a valid zip and (if available) matches expected_md5."""
        # .pt files are zip files in disguise
        if not _validate_zip_file(filename, file_description):
            return False
        if expected_md5 is not None and \
           not _validate_md5_hash(filename, expected_md5, file_description):
            return False
        return True

    # Check whether the file already exists, in which case we want to validate it
    if os.path.exists(destination_filename) and not force_download:

        # Only validate .pt files, not .pb files
        if destination_filename.endswith('.pt'):

            if _is_valid_model_file(destination_filename, 'existing model file'):
                print('Model {} already exists and is valid at {}'.format(
                    model_name, destination_filename))
                return destination_filename

            # If validation failed, delete the corrupted file and re-download
            print('Deleting corrupted model file and re-downloading: {}'.format(
                destination_filename))
            try:
                os.remove(destination_filename)
            except Exception as e:
                print('Warning: failed to delete corrupted file {}: {}'.format(
                    destination_filename, str(e)))
            # Whether or not the delete succeeded, make sure the download isn't skipped
            # because a file already exists at the destination
            force_download = True

    # Download the model
    try:
        local_file = download_url(url,
                                  destination_filename=destination_filename,
                                  progress_updater=None,
                                  force_download=force_download,
                                  verbose=True)
    except Exception as e:
        print('Error downloading model {} from {}: {}'.format(model_name, url, str(e)))
        raise

    # Validate the downloaded file if it's a .pt file
    if local_file and local_file.endswith('.pt'):

        if not _is_valid_model_file(local_file, "downloaded model file"):
            # Clean up the corrupted download
            try:
                os.remove(local_file)
            except Exception:
                pass
            return None

    print('Model {} available at {}'.format(model_name,local_file))
    return local_file

# ...def _download_model(...)
|
|
1022
|
+
|
|
1023
|
+
def try_download_known_detector(detector_file,force_download=False,verbose=False):
    """
    Resolves a model name to a local model file. If detector_file names a known model,
    the path comes from an environment variable of the same name when one is set;
    otherwise the model is downloaded (if necessary) to local temp space. Any other
    string is returned unchanged.

    Args:
        detector_file (str): a known model string (e.g. "MDV5A"), or any other string (in which
            case this function is a no-op)
        force_download (bool, optional): whether to download the model even if the local target
            file already exists
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the local filename to which the model was downloaded, or the same string that
        was passed in, if it's not recognized as a well-known model name
    """

    normalized_name = detector_file.lower()

    # Short aliases (e.g. "MDV5A") map to canonical version strings (e.g. "v5a.0.0")
    if normalized_name in model_string_to_model_version:
        canonical_name = model_string_to_model_version[normalized_name]
        if verbose:
            print('Converting short string {} to canonical version string {}'.format(
                normalized_name,
                canonical_name))
        normalized_name = canonical_name

    # Not a known model: treat the input as an ordinary filename
    if normalized_name not in known_models:
        return detector_file

    # An environment variable named exactly like the input (original case) takes
    # precedence over downloading
    if detector_file in os.environ:
        env_path = os.environ[detector_file]
        print('Reading MD location from environment variable {}: {}'.format(
            detector_file,env_path))
        return env_path

    return _download_model(normalized_name,force_download=force_download)
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
|
|
1067
|
+
|
|
1068
|
+
#%% Command-line driver
|
|
1069
|
+
|
|
1070
|
+
def main(): # noqa
    """
    Command-line entry point: parses arguments, resolves (and possibly downloads) the
    detector, collects the input images, and calls load_and_run_detector().
    """

    parser = argparse.ArgumentParser(
        description='Module to run an animal detection model on images')

    parser.add_argument(
        'detector_file',
        help='Path to detector model file (.pb or .pt). Can also be MDV4, MDV5A, or MDV5B to request automatic download.')

    # Must specify either an image file or a directory
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        '--image_file',
        type=str,
        default=None,
        help='Single file to process, mutually exclusive with --image_dir')
    group.add_argument(
        '--image_dir',
        type=str,
        default=None,
        help='Directory to search for images, with optional recursion by adding --recursive')

    parser.add_argument(
        '--recursive',
        action='store_true',
        help='Recurse into directories, only meaningful if using --image_dir')

    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Directory for output images (defaults to same as input)')

    parser.add_argument(
        '--image_size',
        type=int,
        default=None,
        help=('Force image resizing to a (square) integer size (not recommended to change this)'))

    parser.add_argument(
        '--threshold',
        type=float,
        default=DEFAULT_RENDERING_CONFIDENCE_THRESHOLD,
        help=('Confidence threshold between 0 and 1.0; only render' +
              ' boxes above this confidence (defaults to {})'.format(
              DEFAULT_RENDERING_CONFIDENCE_THRESHOLD)))

    parser.add_argument(
        '--crop',
        default=False,
        action='store_true',
        help=('If set, produces separate output images for each crop, '
              'rather than adding bounding boxes to the original image'))

    parser.add_argument(
        '--augment',
        default=False,
        action='store_true',
        help=('Enable image augmentation'))

    parser.add_argument(
        '--box_thickness',
        type=int,
        default=DEFAULT_BOX_THICKNESS,
        help=('Line width (in pixels) for box rendering (defaults to {})'.format(
              DEFAULT_BOX_THICKNESS)))

    parser.add_argument(
        '--box_expansion',
        type=int,
        default=DEFAULT_BOX_EXPANSION,
        help=('Number of pixels to expand boxes by (defaults to {})'.format(
              DEFAULT_BOX_EXPANSION)))

    parser.add_argument(
        '--label_font_size',
        type=int,
        default=DEFAULT_LABEL_FONT_SIZE,
        help=('Label font size (defaults to {})'.format(
              DEFAULT_LABEL_FONT_SIZE)))

    parser.add_argument(
        '--process_likely_output_images',
        action='store_true',
        help=('By default, we skip images that end in {}, because they probably came from this script. '\
              .format(DETECTION_FILENAME_INSERT) + \
              'This option disables that behavior.'))

    parser.add_argument(
        '--force_model_download',
        action='store_true',
        help=('If a named model (e.g. "MDV5A") is supplied, force a download of that model even if the ' +\
              'local file already exists.'))

    parser.add_argument(
        '--verbose',
        action='store_true',
        help=('Enable additional debug output'))

    parser.add_argument(
        '--detector_options',
        nargs='*',
        metavar='KEY=VALUE',
        default='',
        help='Detector-specific options, as a space-separated list of key-value pairs')

    if len(sys.argv[1:]) == 0:
        parser.print_help()
        parser.exit()

    args = parser.parse_args()

    # Validate with parser.error() rather than assert: asserts are stripped under
    # "python -O", and parser.error() prints usage and exits with a non-zero status.
    # The threshold is checked first so a bad value fails before any model download.
    if not (0.0 < args.threshold <= 1.0):
        parser.error('Confidence threshold needs to be between 0 and 1')

    detector_options = parse_kvp_list(args.detector_options)

    # If the specified detector file is really the name of a known model, find
    # (and possibly download) that model
    requested_detector = args.detector_file
    args.detector_file = try_download_known_detector(args.detector_file,
                                                     force_download=args.force_model_download)

    # None means a known model could not be downloaded/validated
    if args.detector_file is None:
        parser.error('could not download or validate detector {}'.format(requested_detector))
    if not os.path.exists(args.detector_file):
        parser.error('detector file {} does not exist'.format(args.detector_file))

    if args.image_file:
        image_file_names = [args.image_file]
    else:
        image_file_names = path_utils.find_images(args.image_dir, args.recursive)

    # Optionally skip images that were probably generated by this script
    if not args.process_likely_output_images:
        image_file_names_valid = []
        for fn in image_file_names:
            if os.path.splitext(fn)[0].endswith(DETECTION_FILENAME_INSERT):
                print('Skipping likely output image {}'.format(fn))
            else:
                image_file_names_valid.append(fn)
        image_file_names = image_file_names_valid

    print('Running detector on {} images...'.format(len(image_file_names)))

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    else:
        if args.image_dir:
            args.output_dir = args.image_dir
        else:
            # but for a single image, args.image_dir is also None
            args.output_dir = os.path.dirname(args.image_file)

    load_and_run_detector(model_file=args.detector_file,
                          image_file_names=image_file_names,
                          output_dir=args.output_dir,
                          render_confidence_threshold=args.threshold,
                          box_thickness=args.box_thickness,
                          box_expansion=args.box_expansion,
                          crop_images=args.crop,
                          image_size=args.image_size,
                          label_font_size=args.label_font_size,
                          augment=args.augment,
                          # If --force_model_download was specified, we already handled it
                          force_model_download=False,
                          detector_options=detector_options,
                          verbose=args.verbose)

if __name__ == '__main__':
    main()
|
|
1235
|
+
|
|
1236
|
+
|
|
1237
|
+
#%% Interactive driver(s)

# Scratch code meant to be run cell-by-cell ("#%%" cells) in an IDE. The
# "if False:" guard keeps it from running on import or when this module is
# executed as a script.
if False:

    pass

    #%% Test model download

    # Shell commands (Windows syntax) that start a local HTTP server in a model
    # folder; presumably used to test the download path against local files.
    # This string is not executed.
    r"""
    cd i:\models\all_models_in_the_wild
    i:
    python -m http.server 8181
    """

    model_name = 'redwood'
    try_download_known_detector(model_name,force_download=True,verbose=True)


    #%% Load and run detector

    # Example invocation on a local folder of demo images (paths are machine-specific)
    model_file = r'c:\temp\models\md_v4.1.0.pb'
    image_file_names = path_utils.find_images(r'c:\temp\demo_images\ssverymini')
    output_dir = r'c:\temp\demo_images\ssverymini'
    render_confidence_threshold = 0.8
    crop_images = True

    load_and_run_detector(model_file=model_file,
                          image_file_names=image_file_names,
                          output_dir=output_dir,
                          render_confidence_threshold=render_confidence_threshold,
                          crop_images=crop_images)
|