megadetector 5.0.22__py3-none-any.whl → 5.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic; the original page links to more details (the link was not preserved in this copy).

Files changed (38)
  1. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +2 -3
  2. megadetector/classification/merge_classification_detection_output.py +2 -2
  3. megadetector/data_management/coco_to_labelme.py +2 -1
  4. megadetector/data_management/databases/integrity_check_json_db.py +15 -14
  5. megadetector/data_management/databases/subset_json_db.py +49 -21
  6. megadetector/data_management/mewc_to_md.py +340 -0
  7. megadetector/data_management/wi_to_md.py +41 -0
  8. megadetector/data_management/yolo_output_to_md_output.py +15 -8
  9. megadetector/detection/process_video.py +24 -7
  10. megadetector/detection/pytorch_detector.py +841 -160
  11. megadetector/detection/run_detector.py +340 -146
  12. megadetector/detection/run_detector_batch.py +306 -70
  13. megadetector/detection/run_inference_with_yolov5_val.py +61 -4
  14. megadetector/detection/tf_detector.py +6 -1
  15. megadetector/postprocessing/{combine_api_outputs.py → combine_batch_outputs.py} +10 -13
  16. megadetector/postprocessing/compare_batch_results.py +68 -6
  17. megadetector/postprocessing/md_to_labelme.py +7 -7
  18. megadetector/postprocessing/md_to_wi.py +40 -0
  19. megadetector/postprocessing/merge_detections.py +1 -1
  20. megadetector/postprocessing/postprocess_batch_results.py +10 -3
  21. megadetector/postprocessing/separate_detections_into_folders.py +32 -4
  22. megadetector/postprocessing/validate_batch_results.py +9 -4
  23. megadetector/utils/ct_utils.py +172 -57
  24. megadetector/utils/gpu_test.py +107 -0
  25. megadetector/utils/md_tests.py +363 -108
  26. megadetector/utils/path_utils.py +9 -2
  27. megadetector/utils/wi_utils.py +1794 -0
  28. megadetector/visualization/visualization_utils.py +82 -16
  29. megadetector/visualization/visualize_db.py +25 -7
  30. megadetector/visualization/visualize_detector_output.py +60 -13
  31. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/LICENSE +0 -0
  32. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/METADATA +129 -143
  33. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/RECORD +35 -33
  34. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/top_level.txt +0 -0
  35. megadetector/detection/detector_training/__init__.py +0 -0
  36. megadetector/detection/detector_training/model_main_tf2.py +0 -114
  37. megadetector/utils/torch_test.py +0 -32
  38. {megadetector-5.0.22.dist-info → megadetector-5.0.24.dist-info}/WHEEL +0 -0
@@ -38,6 +38,7 @@ from tqdm import tqdm
38
38
  from megadetector.utils import path_utils as path_utils
39
39
  from megadetector.visualization import visualization_utils as vis_utils
40
40
  from megadetector.utils.url_utils import download_url
41
+ from megadetector.utils.ct_utils import parse_kvp_list
41
42
 
42
43
  # ignoring all "PIL cannot read EXIF metainfo for the images" warnings
43
44
  warnings.filterwarnings('ignore', '(Possibly )?corrupt EXIF data', UserWarning)
@@ -48,13 +49,9 @@ warnings.filterwarnings('ignore', 'Metadata warning', UserWarning)
48
49
  # Numpy FutureWarnings from tensorflow import
49
50
  warnings.filterwarnings('ignore', category=FutureWarning)
50
51
 
51
- # Useful hack to force CPU inference
52
- # os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
53
-
54
-
55
- # An enumeration of failure reasons
56
- FAILURE_INFER = 'Failure inference'
57
- FAILURE_IMAGE_OPEN = 'Failure image access'
52
+ # String constants used for consistent reporting of processing errors
53
+ FAILURE_INFER = 'inference failure'
54
+ FAILURE_IMAGE_OPEN = 'image access failure'
58
55
 
59
56
  # Number of decimal places to round to for confidence and bbox coordinates
60
57
  CONF_DIGITS = 3
@@ -69,6 +66,9 @@ DEFAULT_DETECTOR_LABEL_MAP = {
69
66
 
70
67
  # Should we allow classes that don't look anything like the MegaDetector classes?
71
68
  #
69
+ # This flag needs to get set if you want to, for example, run an off-the-shelf
70
+ # YOLO model with this package.
71
+ #
72
72
  # By default, we error if we see unfamiliar classes.
73
73
  #
74
74
  # TODO: the use of a global variable to manage this was fine when this was really
@@ -76,33 +76,101 @@ DEFAULT_DETECTOR_LABEL_MAP = {
76
76
  # models other than MegaDetector.
77
77
  USE_MODEL_NATIVE_CLASSES = False
78
78
 
79
- # Each version of the detector is associated with some "typical" values
80
- # that are included in output files, so that downstream applications can
81
- # use them as defaults.
82
- DETECTOR_METADATA = {
79
+ # Maps a variety of strings that might occur in filenames to canonical version numbers.
80
+ #
81
+ # Order matters here.
82
+ model_string_to_model_version = {
83
+ 'mdv2':'v2.0.0',
84
+ 'mdv3':'v3.0.0',
85
+ 'mdv4':'v4.1.0',
86
+ 'mdv5a':'v5a.0.0',
87
+ 'mdv5b':'v5b.0.0',
88
+ 'v2':'v2.0.0',
89
+ 'v3':'v3.0.0',
90
+ 'v4':'v4.1.0',
91
+ 'v4.1':'v4.1.0',
92
+ 'v5a.0.0':'v5a.0.0',
93
+ 'v5b.0.0':'v5b.0.0',
94
+ 'redwood':'v1000.0.0-redwood',
95
+ 'spruce':'v1000.0.0-spruce',
96
+ 'cedar':'v1000.0.0-cedar',
97
+ 'larch':'v1000.0.0-larch',
98
+ 'default':'v5a.0.0',
99
+ 'megadetector':'v5a.0.0'
100
+ }
101
+
102
+ model_url_base = 'http://localhost:8181/'
103
+ assert model_url_base.endswith('/')
104
+
105
+ # Maps canonical model version numbers to metadata
106
+ known_models = {
83
107
  'v2.0.0':
84
- {'megadetector_version':'v2.0.0',
85
- 'typical_detection_threshold':0.8,
86
- 'conservative_detection_threshold':0.3},
108
+ {
109
+ 'url':'https://lila.science/public/models/megadetector/megadetector_v2.pb',
110
+ 'typical_detection_threshold':0.8,
111
+ 'conservative_detection_threshold':0.3,
112
+ 'model_type':'tf',
113
+ 'normalized_typical_inference_speed':1.0/3.5
114
+ },
87
115
  'v3.0.0':
88
- {'megadetector_version':'v3.0.0',
89
- 'typical_detection_threshold':0.8,
90
- 'conservative_detection_threshold':0.3},
116
+ {
117
+ 'url':'https://lila.science/public/models/megadetector/megadetector_v3.pb',
118
+ 'typical_detection_threshold':0.8,
119
+ 'conservative_detection_threshold':0.3,
120
+ 'model_type':'tf',
121
+ 'normalized_typical_inference_speed':1.0/3.5
122
+ },
91
123
  'v4.1.0':
92
- {'megadetector_version':'v4.1.0',
93
- 'typical_detection_threshold':0.8,
94
- 'conservative_detection_threshold':0.3},
124
+ {
125
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v4.1/md_v4.1.0.pb',
126
+ 'typical_detection_threshold':0.8,
127
+ 'conservative_detection_threshold':0.3,
128
+ 'model_type':'tf',
129
+ 'normalized_typical_inference_speed':1.0/3.5
130
+ },
95
131
  'v5a.0.0':
96
- {'megadetector_version':'v5a.0.0',
97
- 'typical_detection_threshold':0.2,
98
- 'conservative_detection_threshold':0.05},
132
+ {
133
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5a.0.0.pt',
134
+ 'typical_detection_threshold':0.2,
135
+ 'conservative_detection_threshold':0.05,
136
+ 'image_size':1280,
137
+ 'model_type':'yolov5',
138
+ 'normalized_typical_inference_speed':1.0
139
+ },
99
140
  'v5b.0.0':
100
- {'megadetector_version':'v5b.0.0',
101
- 'typical_detection_threshold':0.2,
102
- 'conservative_detection_threshold':0.05}
141
+ {
142
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.0.pt',
143
+ 'typical_detection_threshold':0.2,
144
+ 'conservative_detection_threshold':0.05,
145
+ 'image_size':1280,
146
+ 'model_type':'yolov5',
147
+ 'normalized_typical_inference_speed':1.0
148
+ },
149
+
150
+ # Fake values for testing
151
+ 'v1000.0.0-redwood':
152
+ {
153
+ 'normalized_typical_inference_speed':2.0,
154
+ 'url':model_url_base + 'md_v1000.0.0-redwood.pt'
155
+ },
156
+ 'v1000.0.0-spruce':
157
+ {
158
+ 'normalized_typical_inference_speed':3.0,
159
+ 'url':model_url_base + 'md_v1000.0.0-spruce.pt'
160
+ },
161
+ 'v1000.0.0-larch':
162
+ {
163
+ 'normalized_typical_inference_speed':4.0,
164
+ 'url':model_url_base + 'md_v1000.0.0-larch.pt'
165
+ },
166
+ 'v1000.0.0-cedar':
167
+ {
168
+ 'normalized_typical_inference_speed':5.0,
169
+ 'url':model_url_base + 'md_v1000.0.0-cedar.pt'
170
+ }
103
171
  }
104
172
 
105
- DEFAULT_RENDERING_CONFIDENCE_THRESHOLD = DETECTOR_METADATA['v5b.0.0']['typical_detection_threshold']
173
+ DEFAULT_RENDERING_CONFIDENCE_THRESHOLD = known_models['v5a.0.0']['typical_detection_threshold']
106
174
  DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD = 0.005
107
175
 
108
176
  DEFAULT_BOX_THICKNESS = 4
@@ -110,29 +178,6 @@ DEFAULT_BOX_EXPANSION = 0
110
178
  DEFAULT_LABEL_FONT_SIZE = 16
111
179
  DETECTION_FILENAME_INSERT = '_detections'
112
180
 
113
- # The model filenames "MDV5A", "MDV5B", and "MDV4" are special; they will trigger an
114
- # automatic model download to the system temp folder, or they will use the paths specified in the
115
- # $MDV4, $MDV5A, or $MDV5B environment variables if they exist.
116
- downloadable_models = {
117
- 'MDV2':'https://lila.science/public/models/megadetector/megadetector_v2.pb',
118
- 'MDV3':'https://lila.science/public/models/megadetector/megadetector_v3.pb',
119
- 'MDV4':'https://github.com/agentmorris/MegaDetector/releases/download/v4.1/md_v4.1.0.pb',
120
- 'MDV5A':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5a.0.0.pt',
121
- 'MDV5B':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.0.pt'
122
- }
123
-
124
- model_string_to_model_version = {
125
- 'v2':'v2.0.0',
126
- 'v3':'v3.0.0',
127
- 'v4.1':'v4.1.0',
128
- 'v5a.0.0':'v5a.0.0',
129
- 'v5b.0.0':'v5b.0.0',
130
- 'mdv5a':'v5a.0.0',
131
- 'mdv5b':'v5b.0.0',
132
- 'mdv4':'v4.1.0',
133
- 'mdv3':'v3.0.0'
134
- }
135
-
136
181
  # Approximate inference speeds (in images per second) for MDv5 based on
137
182
  # benchmarks, only used for reporting very coarse expectations about inference time.
138
183
  device_token_to_mdv5_inference_speed = {
@@ -145,35 +190,12 @@ device_token_to_mdv5_inference_speed = {
145
190
  # is around 3.5x faster than MDv4.
146
191
  'V100':2.79*3.5,
147
192
  '2080':2.3*3.5,
148
- '2060':1.6*3.5
193
+ '2060':1.6*3.5
149
194
  }
150
195
 
151
196
 
152
197
  #%% Utility functions
153
198
 
154
- def convert_to_tf_coords(array):
155
- """
156
- Converts a bounding box from [x1, y1, width, height] to [y1, x1, y2, x2]. This
157
- is mostly not helpful, this function only exists to maintain backwards compatibility
158
- in the synchronous API, which possibly zero people in the world are using.
159
-
160
- Args:
161
- array (list): a bounding box in [x,y,w,h] format
162
-
163
- Returns:
164
- list: a bounding box in [y1,x1,y2,x2] format
165
- """
166
-
167
- x1 = array[0]
168
- y1 = array[1]
169
- width = array[2]
170
- height = array[3]
171
- x2 = x1 + width
172
- y2 = y1 + height
173
-
174
- return [y1, x1, y2, x2]
175
-
176
-
177
199
  def get_detector_metadata_from_version_string(detector_version):
178
200
  """
179
201
  Given a MegaDetector version string (e.g. "v4.1.0"), returns the metadata for
@@ -187,7 +209,7 @@ def get_detector_metadata_from_version_string(detector_version):
187
209
  dict: metadata for this model, suitable for writing to a MD output file
188
210
  """
189
211
 
190
- if detector_version not in DETECTOR_METADATA:
212
+ if detector_version not in known_models:
191
213
  print('Warning: no metadata for unknown detector version {}'.format(detector_version))
192
214
  default_detector_metadata = {
193
215
  'megadetector_version':'unknown',
@@ -196,12 +218,16 @@ def get_detector_metadata_from_version_string(detector_version):
196
218
  }
197
219
  return default_detector_metadata
198
220
  else:
199
- return DETECTOR_METADATA[detector_version]
221
+ to_return = known_models[detector_version]
222
+ to_return['megadetector_version'] = detector_version
223
+ return to_return
200
224
 
201
225
 
202
- def get_detector_version_from_filename(detector_filename,accept_first_match=True):
226
+ def get_detector_version_from_filename(detector_filename,
227
+ accept_first_match=True,
228
+ verbose=False):
203
229
  r"""
204
- Gets the version number component of the detector from the model filename.
230
+ Gets the canonical version number string of a detector from the model filename.
205
231
 
206
232
  [detector_filename] will almost always end with one of the following:
207
233
 
@@ -213,12 +239,14 @@ def get_detector_version_from_filename(detector_filename,accept_first_match=True
213
239
  * md_v5b.0.0.pt
214
240
 
215
241
  This function identifies the version number as "v2.0.0", "v3.0.0", "v4.1.0",
216
- "v4.1.0", "v5a.0.0", and "v5b.0.0", respectively.
242
+ "v4.1.0", "v5a.0.0", and "v5b.0.0", respectively. See known_models for the list
243
+ of valid version numbers.
217
244
 
218
245
  Args:
219
246
  detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
220
247
  accept_first_match (bool, optional): if multiple candidates match the filename, choose the
221
248
  first one, otherwise returns the string "multiple"
249
+ verbose (bool, optional): enable additional debug output
222
250
 
223
251
  Returns:
224
252
  str: a detector version string, e.g. "v5a.0.0", or "multiple" if I'm confused
@@ -230,24 +258,114 @@ def get_detector_version_from_filename(detector_filename,accept_first_match=True
230
258
  if s in fn:
231
259
  matches.append(s)
232
260
  if len(matches) == 0:
233
- print('Warning: could not determine MegaDetector version for model file {}'.format(detector_filename))
234
261
  return 'unknown'
235
262
  elif len(matches) > 1:
236
- print('Warning: multiple MegaDetector versions for model file {}'.format(detector_filename))
237
263
  if accept_first_match:
238
264
  return model_string_to_model_version[matches[0]]
239
265
  else:
266
+ if verbose:
267
+ print('Warning: multiple MegaDetector versions for model file {}:'.format(detector_filename))
268
+ for s in matches:
269
+ print(s)
240
270
  return 'multiple'
241
271
  else:
242
272
  return model_string_to_model_version[matches[0]]
243
273
 
244
274
 
275
+ def get_detector_version_from_model_file(detector_filename,verbose=False):
276
+ """
277
+ Gets the canonical detection version from a model file, preferably by reading it
278
+ from the file itself, otherwise based on the filename.
279
+
280
+ Args:
281
+ detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
282
+ verbose (bool, optional): enable additional debug output
283
+
284
+ Returns:
285
+ str: a canonical detector version string, e.g. "v5a.0.0", or "unknown"
286
+ """
287
+
288
+ # Try to extract a version string from the filename
289
+ version_string_based_on_filename = get_detector_version_from_filename(
290
+ detector_filename, verbose=verbose)
291
+ if version_string_based_on_filename == 'unknown':
292
+ version_string_based_on_filename = None
293
+
294
+ # Try to extract a version string from the file itself; currently this is only
295
+ # a thing for PyTorch models
296
+
297
+ version_string_based_on_model_file = None
298
+
299
+ if detector_filename.endswith('.pt') or detector_filename.endswith('.zip'):
300
+
301
+ from megadetector.detection.pytorch_detector import \
302
+ read_metadata_from_megadetector_model_file
303
+ metadata = read_metadata_from_megadetector_model_file(detector_filename,verbose=verbose)
304
+
305
+ if metadata is not None and isinstance(metadata,dict):
306
+
307
+ if 'metadata_format_version' not in metadata or \
308
+ not isinstance(metadata['metadata_format_version'],float):
309
+
310
+ print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
311
+ 'but it doesn\'t have a valid format version number')
312
+
313
+ elif 'model_version_string' not in metadata or \
314
+ not isinstance(metadata['model_version_string'],str):
315
+
316
+ print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
317
+ 'but it doesn\'t have a format model version string')
318
+
319
+ else:
320
+
321
+ version_string_based_on_model_file = metadata['model_version_string']
322
+
323
+ if version_string_based_on_model_file not in known_models:
324
+ print('Warning: unknown model version {} specified in file {}'.format(
325
+ version_string_based_on_model_file,detector_filename))
326
+
327
+ # ...if there's metadata in this file
328
+
329
+ # ...if this looks like a PyTorch file
330
+
331
+ # If we got versions strings from the filename *and* the model file...
332
+ if (version_string_based_on_filename is not None) and \
333
+ (version_string_based_on_model_file is not None):
334
+
335
+ if version_string_based_on_filename != version_string_based_on_model_file:
336
+ print('Warning: model version string in file {} is {}, but the filename implies {}'.format(
337
+ detector_filename,
338
+ version_string_based_on_model_file,
339
+ version_string_based_on_filename))
340
+
341
+ return version_string_based_on_model_file
342
+
343
+ # If we got version string from neither the filename nor the model file...
344
+ if (version_string_based_on_filename is None) and \
345
+ (version_string_based_on_model_file is None):
346
+
347
+ print('Warning: could not determine model version string for model file {}'.format(
348
+ detector_filename))
349
+ return None
350
+
351
+ elif version_string_based_on_filename is not None:
352
+
353
+ return version_string_based_on_filename
354
+
355
+ else:
356
+
357
+ assert version_string_based_on_model_file is not None
358
+ return version_string_based_on_model_file
359
+
360
+ # ...def get_detector_version_from_model_file(...)
361
+
362
+
245
363
  def estimate_md_images_per_second(model_file, device_name=None):
246
364
  r"""
247
- Estimates how fast MegaDetector will run, based on benchmarks. Defaults to querying
248
- the current device. Returns None if no data is available for the current card/model.
249
- Estimates only available for a small handful of GPUs. Uses an absurdly simple lookup
250
- approach, e.g. if the string "4090" appears in the device name, congratulations,
365
+ Estimates how fast MegaDetector will run on a particular device, based on benchmarks.
366
+ Defaults to querying the current device. Returns None if no data is available for the current
367
+ card/model. Estimates only available for a small handful of GPUs. Uses an absurdly simple
368
+ lookup approach, e.g. if the string "4090" appears in the device name, congratulations,
251
369
  you have an RTX 4090.
252
370
 
253
371
  Args:
@@ -267,15 +385,24 @@ def estimate_md_images_per_second(model_file, device_name=None):
267
385
  print('Error querying device name: {}'.format(e))
268
386
  return None
269
387
 
270
- model_file = model_file.lower().strip()
271
- if model_file in model_string_to_model_version.values():
272
- model_version = model_file
273
- else:
274
- model_version = get_detector_version_from_filename(model_file)
275
- if model_version not in model_string_to_model_version.values():
276
- print('Error determining model version for model file {}'.format(model_file))
277
- return None
388
+ # About how fast is this model compared to MDv5?
389
+ model_version = get_detector_version_from_model_file(model_file)
390
+
391
+ if model_version not in known_models.keys():
392
+ print('Could not estimate inference speed: error determining model version for model file {}'.format(
393
+ model_file))
394
+ return None
395
+
396
+ model_info = known_models[model_version]
397
+
398
+ if 'normalized_typical_inference_speed' not in model_info or \
399
+ model_info['normalized_typical_inference_speed'] is None:
400
+ print('No speed ratio available for model type {}'.format(model_version))
401
+ return None
402
+
403
+ normalized_inference_speed = model_info['normalized_typical_inference_speed']
278
404
 
405
+ # About how fast would MDv5 run on this device?
279
406
  mdv5_inference_speed = None
280
407
  for device_token in device_token_to_mdv5_inference_speed.keys():
281
408
  if device_token in device_name:
@@ -283,16 +410,11 @@ def estimate_md_images_per_second(model_file, device_name=None):
283
410
  break
284
411
 
285
412
  if mdv5_inference_speed is None:
286
- print('No speed estimate available for {}'.format(device_name))
287
-
288
- if 'v5' in model_version:
289
- return mdv5_inference_speed
290
- elif 'v2' in model_version or 'v3' in model_version or 'v4' in model_version:
291
- return mdv5_inference_speed / 3.5
292
- else:
293
- print('Could not estimate inference speed for model file {}'.format(model_file))
413
+ print('No baseline speed estimate available for device {}'.format(device_name))
294
414
  return None
295
415
 
416
+ return normalized_inference_speed * mdv5_inference_speed
417
+
296
418
 
297
419
  def get_typical_confidence_threshold_from_results(results):
298
420
  """
@@ -342,25 +464,28 @@ def is_gpu_available(model_file):
342
464
  print('TensorFlow version:', tf.__version__)
343
465
  print('tf.test.is_gpu_available:', gpu_available)
344
466
  return gpu_available
345
- elif model_file.endswith('.pt'):
346
- import torch
347
- gpu_available = torch.cuda.is_available()
348
- print('PyTorch reports {} available CUDA devices'.format(torch.cuda.device_count()))
349
- if not gpu_available:
350
- try:
351
- # mps backend only available in torch >= 1.12.0
352
- if torch.backends.mps.is_built and torch.backends.mps.is_available():
353
- gpu_available = True
354
- print('PyTorch reports Metal Performance Shaders are available')
355
- except AttributeError:
356
- pass
357
- return gpu_available
358
- else:
359
- raise ValueError('Model {} does not have a recognized extension and is not a known model name'.\
360
- format(model_file))
361
-
362
-
363
- def load_detector(model_file, force_cpu=False, force_model_download=False):
467
+ if not model_file.endswith('.pt'):
468
+ print('Warning: could not determine environment from model file name, assuming PyTorch')
469
+
470
+ import torch
471
+ gpu_available = torch.cuda.is_available()
472
+ print('PyTorch reports {} available CUDA devices'.format(torch.cuda.device_count()))
473
+ if not gpu_available:
474
+ try:
475
+ # mps backend only available in torch >= 1.12.0
476
+ if torch.backends.mps.is_built and torch.backends.mps.is_available():
477
+ gpu_available = True
478
+ print('PyTorch reports Metal Performance Shaders are available')
479
+ except AttributeError:
480
+ pass
481
+ return gpu_available
482
+
483
+
484
+ def load_detector(model_file,
485
+ force_cpu=False,
486
+ force_model_download=False,
487
+ detector_options=None,
488
+ verbose=False):
364
489
  r"""
365
490
  Loads a TF or PT detector, depending on the extension of model_file.
366
491
 
@@ -372,6 +497,9 @@ def load_detector(model_file, force_cpu=False, force_model_download=False):
372
497
  force_model_download (bool, optional): force downloading the model file if
373
498
  a named model (e.g. "MDV5A") is supplied, even if the local file already
374
499
  exists
500
+ detector_options (dict, optional): key/value pairs that are interpreted differently
501
+ by different detectors
502
+ verbose (bool, optional): enable additional debug output
375
503
 
376
504
  Returns:
377
505
  object: loaded detector object
@@ -381,25 +509,48 @@ def load_detector(model_file, force_cpu=False, force_model_download=False):
381
509
  model_file = try_download_known_detector(model_file,
382
510
  force_download=force_model_download)
383
511
 
384
- print('GPU available: {}'.format(is_gpu_available(model_file)))
512
+ if verbose:
513
+ print('GPU available: {}'.format(is_gpu_available(model_file)))
385
514
 
386
515
  start_time = time.time()
516
+
387
517
  if model_file.endswith('.pb'):
518
+
388
519
  from megadetector.detection.tf_detector import TFDetector
389
520
  if force_cpu:
390
521
  raise ValueError('force_cpu is not currently supported for TF detectors, ' + \
391
522
  'use CUDA_VISIBLE_DEVICES=-1 instead')
392
- detector = TFDetector(model_file)
523
+ detector = TFDetector(model_file, detector_options)
524
+
393
525
  elif model_file.endswith('.pt'):
526
+
394
527
  from megadetector.detection.pytorch_detector import PTDetector
395
- detector = PTDetector(model_file, force_cpu, USE_MODEL_NATIVE_CLASSES)
528
+
529
+ # Prepare options specific to the PTDetector class
530
+ if detector_options is None:
531
+ detector_options = {}
532
+ if 'force_cpu' in detector_options:
533
+ if force_cpu != detector_options['force_cpu']:
534
+ print('Warning: over-riding force_cpu parameter ({}) based on detector_options ({})'.format(
535
+ force_cpu,detector_options['force_cpu']))
536
+ else:
537
+ detector_options['force_cpu'] = force_cpu
538
+ detector_options['use_model_native_classes'] = USE_MODEL_NATIVE_CLASSES
539
+ detector = PTDetector(model_file, detector_options, verbose=verbose)
540
+
396
541
  else:
542
+
397
543
  raise ValueError('Unrecognized model format: {}'.format(model_file))
544
+
398
545
  elapsed = time.time() - start_time
399
- print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
546
+
547
+ if verbose:
548
+ print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
400
549
 
401
550
  return detector
402
551
 
552
+ # ...def load_detector(...)
553
+
403
554
 
404
555
  #%% Main function
405
556
 
@@ -413,8 +564,8 @@ def load_and_run_detector(model_file,
413
564
  image_size=None,
414
565
  label_font_size=DEFAULT_LABEL_FONT_SIZE,
415
566
  augment=False,
416
- force_model_download=False
417
- ):
567
+ force_model_download=False,
568
+ detector_options=None):
418
569
  r"""
419
570
  Loads and runs a detector on target images, and visualizes the results.
420
571
 
@@ -438,6 +589,8 @@ def load_and_run_detector(model_file,
438
589
  force_model_download (bool, optional): force downloading the model file if
439
590
  a named model (e.g. "MDV5A") is supplied, even if the local file already
440
591
  exists
592
+ detector_options (dict, optional): key/value pairs that are interpreted differently
593
+ by different detectors
441
594
  """
442
595
 
443
596
  if len(image_file_names) == 0:
@@ -447,7 +600,7 @@ def load_and_run_detector(model_file,
447
600
  # Possibly automatically download the model
448
601
  model_file = try_download_known_detector(model_file, force_download=force_model_download)
449
602
 
450
- detector = load_detector(model_file)
603
+ detector = load_detector(model_file, detector_options=detector_options)
451
604
 
452
605
  detection_results = []
453
606
  time_load = []
@@ -516,7 +669,7 @@ def load_and_run_detector(model_file,
516
669
  time_load.append(elapsed)
517
670
 
518
671
  except Exception as e:
519
- print('Image {} cannot be loaded. Exception: {}'.format(im_file, e))
672
+ print('Image {} cannot be loaded, error: {}'.format(im_file, str(e)))
520
673
  result = {
521
674
  'file': im_file,
522
675
  'failure': FAILURE_IMAGE_OPEN
@@ -539,7 +692,7 @@ def load_and_run_detector(model_file,
539
692
  time_infer.append(elapsed)
540
693
 
541
694
  except Exception as e:
542
- print('An error occurred while running the detector on image {}. Exception: {}'.format(im_file, e))
695
+ print('An error occurred while running the detector on image {}: {}'.format(im_file, str(e)))
543
696
  continue
544
697
 
545
698
  try:
@@ -560,7 +713,8 @@ def load_and_run_detector(model_file,
560
713
  label_map=DEFAULT_DETECTOR_LABEL_MAP,
561
714
  confidence_threshold=render_confidence_threshold,
562
715
  thickness=box_thickness, expansion=box_expansion,
563
- label_font_size=label_font_size)
716
+ label_font_size=label_font_size,
717
+ box_sort_order='confidence')
564
718
  output_full_path = input_file_to_detection_file(im_file)
565
719
  image.save(output_full_path)
566
720
 
@@ -576,8 +730,8 @@ def load_and_run_detector(model_file,
576
730
  std_dev_time_load = humanfriendly.format_timespan(statistics.stdev(time_load))
577
731
  std_dev_time_infer = humanfriendly.format_timespan(statistics.stdev(time_infer))
578
732
  else:
579
- std_dev_time_load = 'not available'
580
- std_dev_time_infer = 'not available'
733
+ std_dev_time_load = 'not available (<=1 image processed)'
734
+ std_dev_time_infer = 'not available (<=1 image processed)'
581
735
  print('On average, for each image,')
582
736
  print('- loading took {}, std dev is {}'.format(humanfriendly.format_timespan(ave_time_load),
583
737
  std_dev_time_load))
@@ -587,12 +741,13 @@ def load_and_run_detector(model_file,
587
741
  # ...def load_and_run_detector()
588
742
 
589
743
 
590
- def download_model(model_name,force_download=False):
744
+ def _download_model(model_name,force_download=False):
591
745
  """
592
746
  Downloads one of the known models to local temp space if it hasn't already been downloaded.
593
747
 
594
748
  Args:
595
- model_name (str): a known model string, e.g. "MDV5A"
749
+ model_name (str): a known model string, e.g. "MDV5A". Returns None if this string is not
750
+ a known model name.
596
751
  force_download (bool, optional): whether to download the model even if the local target
597
752
  file already exists
598
753
  """
@@ -609,10 +764,10 @@ def download_model(model_name,force_download=False):
609
764
  os.chmod(model_tempdir,0o777)
610
765
  except Exception:
611
766
  pass
612
- if model_name.upper() not in downloadable_models:
767
+ if model_name.lower() not in known_models:
613
768
  print('Unrecognized downloadable model {}'.format(model_name))
614
769
  return None
615
- url = downloadable_models[model_name.upper()]
770
+ url = known_models[model_name.lower()]['url']
616
771
  destination_filename = os.path.join(model_tempdir,url.split('/')[-1])
617
772
  local_file = download_url(url, destination_filename=destination_filename, progress_updater=None,
618
773
  force_download=force_download, verbose=True)
@@ -620,7 +775,7 @@ def download_model(model_name,force_download=False):
620
775
  return local_file
621
776
 
622
777
 
623
- def try_download_known_detector(detector_file,force_download=False):
778
+ def try_download_known_detector(detector_file,force_download=False,verbose=False):
624
779
  """
625
780
  Checks whether detector_file is really the name of a known model, in which case we will
626
781
  either read the actual filename from the corresponding environment variable or download
@@ -631,22 +786,37 @@ def try_download_known_detector(detector_file,force_download=False):
631
786
  case this function is a no-op)
632
787
  force_download (bool, optional): whether to download the model even if the local target
633
788
  file already exists
789
+ verbose (bool, optional): enable additional debug output
634
790
 
635
791
  Returns:
636
792
  str: the local filename to which the model was downloaded, or the same string that
637
793
  was passed in, if it's not recognized as a well-known model name
638
794
  """
639
795
 
640
- if detector_file.upper() in downloadable_models:
796
+ model_string = detector_file.lower()
797
+
798
+ # If this is a short model string (e.g. "MDV5A"), convert to a canonical version
799
+ # string (e.g. "v5a.0.0")
800
+ if model_string in model_string_to_model_version:
801
+
802
+ if verbose:
803
+ print('Converting short string {} to canonical version string {}'.format(
804
+ model_string,
805
+ model_string_to_model_version[model_string]))
806
+ model_string = model_string_to_model_version[model_string]
807
+
808
+ if model_string in known_models:
809
+
641
810
  if detector_file in os.environ:
642
811
  fn = os.environ[detector_file]
643
812
  print('Reading MD location from environment variable {}: {}'.format(
644
813
  detector_file,fn))
645
814
  detector_file = fn
646
815
  else:
647
- print('Downloading model {}'.format(detector_file))
648
- detector_file = download_model(detector_file,force_download=force_download)
816
+ detector_file = _download_model(model_string,force_download=force_download)
817
+
649
818
  return detector_file
819
+
650
820
 
651
821
 
652
822
 
@@ -745,13 +915,21 @@ def main():
745
915
  action='store_true',
746
916
  help=('If a named model (e.g. "MDV5A") is supplied, force a download of that model even if the ' +\
747
917
  'local file already exists.'))
918
+
919
+ parser.add_argument(
920
+ '--detector_options',
921
+ nargs='*',
922
+ metavar='KEY=VALUE',
923
+ default='',
924
+ help='Detector-specific options, as a space-separated list of key-value pairs')
748
925
 
749
926
  if len(sys.argv[1:]) == 0:
750
927
  parser.print_help()
751
928
  parser.exit()
752
929
 
753
930
  args = parser.parse_args()
754
-
931
+ detector_options = parse_kvp_list(args.detector_options)
932
+
755
933
  # If the specified detector file is really the name of a known model, find
756
934
  # (and possibly download) that model
757
935
  args.detector_file = try_download_known_detector(args.detector_file,
@@ -786,7 +964,7 @@ def main():
786
964
  else:
787
965
  # but for a single image, args.image_dir is also None
788
966
  args.output_dir = os.path.dirname(args.image_file)
789
-
967
+
790
968
  load_and_run_detector(model_file=args.detector_file,
791
969
  image_file_names=image_file_names,
792
970
  output_dir=args.output_dir,
@@ -797,18 +975,34 @@ def main():
797
975
  image_size=args.image_size,
798
976
  label_font_size=args.label_font_size,
799
977
  augment=args.augment,
800
- # Don't download the model *again*
801
- force_model_download=False)
978
+ # If --force_model_download was specified, we already handled it
979
+ force_model_download=False,
980
+ detector_options=detector_options)
802
981
 
803
982
  if __name__ == '__main__':
804
983
  main()
805
984
 
806
985
 
807
- #%% Interactive driver
986
+ #%% Interactive driver(s)
808
987
 
809
988
  if False:
810
989
 
811
- #%%
990
+ pass
991
+
992
+ #%% Test model download
993
+
994
+ r"""
995
+ cd i:\models\all_models_in_the_wild
996
+ i:
997
+ python -m http.server 8181
998
+ """
999
+
1000
+ model_name = 'redwood'
1001
+ try_download_known_detector(model_name,force_download=True,verbose=True)
1002
+
1003
+
1004
+ #%% Load and run detector
1005
+
812
1006
  model_file = r'c:\temp\models\md_v4.1.0.pb'
813
1007
  image_file_names = path_utils.find_images(r'c:\temp\demo_images\ssverymini')
814
1008
  output_dir = r'c:\temp\demo_images\ssverymini'