megadetector-5.0.23-py3-none-any.whl → megadetector-5.0.25-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic.

Files changed (42)
  1. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +2 -3
  2. megadetector/classification/merge_classification_detection_output.py +2 -2
  3. megadetector/data_management/coco_to_labelme.py +2 -1
  4. megadetector/data_management/databases/integrity_check_json_db.py +15 -14
  5. megadetector/data_management/databases/subset_json_db.py +49 -21
  6. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +73 -69
  7. megadetector/data_management/lila/add_locations_to_nacti.py +114 -110
  8. megadetector/data_management/mewc_to_md.py +340 -0
  9. megadetector/data_management/speciesnet_to_md.py +41 -0
  10. megadetector/data_management/yolo_output_to_md_output.py +15 -8
  11. megadetector/detection/process_video.py +24 -7
  12. megadetector/detection/pytorch_detector.py +841 -160
  13. megadetector/detection/run_detector.py +341 -146
  14. megadetector/detection/run_detector_batch.py +307 -70
  15. megadetector/detection/run_inference_with_yolov5_val.py +61 -4
  16. megadetector/detection/tf_detector.py +6 -1
  17. megadetector/postprocessing/{combine_api_outputs.py → combine_batch_outputs.py} +10 -13
  18. megadetector/postprocessing/compare_batch_results.py +236 -7
  19. megadetector/postprocessing/create_crop_folder.py +358 -0
  20. megadetector/postprocessing/md_to_labelme.py +7 -7
  21. megadetector/postprocessing/md_to_wi.py +40 -0
  22. megadetector/postprocessing/merge_detections.py +1 -1
  23. megadetector/postprocessing/postprocess_batch_results.py +12 -5
  24. megadetector/postprocessing/separate_detections_into_folders.py +32 -4
  25. megadetector/postprocessing/validate_batch_results.py +9 -4
  26. megadetector/utils/ct_utils.py +236 -45
  27. megadetector/utils/directory_listing.py +3 -3
  28. megadetector/utils/gpu_test.py +125 -0
  29. megadetector/utils/md_tests.py +455 -116
  30. megadetector/utils/path_utils.py +43 -2
  31. megadetector/utils/wi_utils.py +2691 -0
  32. megadetector/visualization/visualization_utils.py +95 -18
  33. megadetector/visualization/visualize_db.py +25 -7
  34. megadetector/visualization/visualize_detector_output.py +60 -13
  35. {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/METADATA +11 -23
  36. {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/RECORD +39 -36
  37. {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/WHEEL +1 -1
  38. megadetector/detection/detector_training/__init__.py +0 -0
  39. megadetector/detection/detector_training/model_main_tf2.py +0 -114
  40. megadetector/utils/torch_test.py +0 -32
  41. {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/LICENSE +0 -0
  42. {megadetector-5.0.23.dist-info → megadetector-5.0.25.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,7 @@ from tqdm import tqdm
38
38
  from megadetector.utils import path_utils as path_utils
39
39
  from megadetector.visualization import visualization_utils as vis_utils
40
40
  from megadetector.utils.url_utils import download_url
41
+ from megadetector.utils.ct_utils import parse_kvp_list
41
42
 
42
43
  # ignoring all "PIL cannot read EXIF metainfo for the images" warnings
43
44
  warnings.filterwarnings('ignore', '(Possibly )?corrupt EXIF data', UserWarning)
@@ -48,13 +49,9 @@ warnings.filterwarnings('ignore', 'Metadata warning', UserWarning)
48
49
  # Numpy FutureWarnings from tensorflow import
49
50
  warnings.filterwarnings('ignore', category=FutureWarning)
50
51
 
51
- # Useful hack to force CPU inference
52
- # os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
53
-
54
-
55
- # An enumeration of failure reasons
56
- FAILURE_INFER = 'Failure inference'
57
- FAILURE_IMAGE_OPEN = 'Failure image access'
52
+ # String constants used for consistent reporting of processing errors
53
+ FAILURE_INFER = 'inference failure'
54
+ FAILURE_IMAGE_OPEN = 'image access failure'
58
55
 
59
56
  # Number of decimal places to round to for confidence and bbox coordinates
60
57
  CONF_DIGITS = 3
@@ -69,6 +66,9 @@ DEFAULT_DETECTOR_LABEL_MAP = {
69
66
 
70
67
  # Should we allow classes that don't look anything like the MegaDetector classes?
71
68
  #
69
+ # This flag needs to get set if you want to, for example, run an off-the-shelf
70
+ # YOLO model with this package.
71
+ #
72
72
  # By default, we error if we see unfamiliar classes.
73
73
  #
74
74
  # TODO: the use of a global variable to manage this was fine when this was really
@@ -76,33 +76,102 @@ DEFAULT_DETECTOR_LABEL_MAP = {
76
76
  # models other than MegaDetector.
77
77
  USE_MODEL_NATIVE_CLASSES = False
78
78
 
79
- # Each version of the detector is associated with some "typical" values
80
- # that are included in output files, so that downstream applications can
81
- # use them as defaults.
82
- DETECTOR_METADATA = {
79
+ # Maps a variety of strings that might occur in filenames to canonical version numbers.
80
+ #
81
+ # Order matters here.
82
+ model_string_to_model_version = {
83
+ 'mdv2':'v2.0.0',
84
+ 'mdv3':'v3.0.0',
85
+ 'mdv4':'v4.1.0',
86
+ 'mdv5a':'v5a.0.0',
87
+ 'mdv5b':'v5b.0.0',
88
+ 'v2':'v2.0.0',
89
+ 'v3':'v3.0.0',
90
+ 'v4':'v4.1.0',
91
+ 'v4.1':'v4.1.0',
92
+ 'v5a.0.0':'v5a.0.0',
93
+ 'v5b.0.0':'v5b.0.0',
94
+ 'redwood':'v1000.0.0-redwood',
95
+ 'spruce':'v1000.0.0-spruce',
96
+ 'cedar':'v1000.0.0-cedar',
97
+ 'larch':'v1000.0.0-larch',
98
+ 'default':'v5a.0.0',
99
+ 'default-model':'v5a.0.0',
100
+ 'megadetector':'v5a.0.0'
101
+ }
102
+
103
+ model_url_base = 'http://localhost:8181/'
104
+ assert model_url_base.endswith('/')
105
+
106
+ # Maps canonical model version numbers to metadata
107
+ known_models = {
83
108
  'v2.0.0':
84
- {'megadetector_version':'v2.0.0',
85
- 'typical_detection_threshold':0.8,
86
- 'conservative_detection_threshold':0.3},
109
+ {
110
+ 'url':'https://lila.science/public/models/megadetector/megadetector_v2.pb',
111
+ 'typical_detection_threshold':0.8,
112
+ 'conservative_detection_threshold':0.3,
113
+ 'model_type':'tf',
114
+ 'normalized_typical_inference_speed':1.0/3.5
115
+ },
87
116
  'v3.0.0':
88
- {'megadetector_version':'v3.0.0',
89
- 'typical_detection_threshold':0.8,
90
- 'conservative_detection_threshold':0.3},
117
+ {
118
+ 'url':'https://lila.science/public/models/megadetector/megadetector_v3.pb',
119
+ 'typical_detection_threshold':0.8,
120
+ 'conservative_detection_threshold':0.3,
121
+ 'model_type':'tf',
122
+ 'normalized_typical_inference_speed':1.0/3.5
123
+ },
91
124
  'v4.1.0':
92
- {'megadetector_version':'v4.1.0',
93
- 'typical_detection_threshold':0.8,
94
- 'conservative_detection_threshold':0.3},
125
+ {
126
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v4.1/md_v4.1.0.pb',
127
+ 'typical_detection_threshold':0.8,
128
+ 'conservative_detection_threshold':0.3,
129
+ 'model_type':'tf',
130
+ 'normalized_typical_inference_speed':1.0/3.5
131
+ },
95
132
  'v5a.0.0':
96
- {'megadetector_version':'v5a.0.0',
97
- 'typical_detection_threshold':0.2,
98
- 'conservative_detection_threshold':0.05},
133
+ {
134
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5a.0.0.pt',
135
+ 'typical_detection_threshold':0.2,
136
+ 'conservative_detection_threshold':0.05,
137
+ 'image_size':1280,
138
+ 'model_type':'yolov5',
139
+ 'normalized_typical_inference_speed':1.0
140
+ },
99
141
  'v5b.0.0':
100
- {'megadetector_version':'v5b.0.0',
101
- 'typical_detection_threshold':0.2,
102
- 'conservative_detection_threshold':0.05}
142
+ {
143
+ 'url':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.0.pt',
144
+ 'typical_detection_threshold':0.2,
145
+ 'conservative_detection_threshold':0.05,
146
+ 'image_size':1280,
147
+ 'model_type':'yolov5',
148
+ 'normalized_typical_inference_speed':1.0
149
+ },
150
+
151
+ # Fake values for testing
152
+ 'v1000.0.0-redwood':
153
+ {
154
+ 'normalized_typical_inference_speed':2.0,
155
+ 'url':model_url_base + 'md_v1000.0.0-redwood.pt'
156
+ },
157
+ 'v1000.0.0-spruce':
158
+ {
159
+ 'normalized_typical_inference_speed':3.0,
160
+ 'url':model_url_base + 'md_v1000.0.0-spruce.pt'
161
+ },
162
+ 'v1000.0.0-larch':
163
+ {
164
+ 'normalized_typical_inference_speed':4.0,
165
+ 'url':model_url_base + 'md_v1000.0.0-larch.pt'
166
+ },
167
+ 'v1000.0.0-cedar':
168
+ {
169
+ 'normalized_typical_inference_speed':5.0,
170
+ 'url':model_url_base + 'md_v1000.0.0-cedar.pt'
171
+ }
103
172
  }
104
173
 
105
- DEFAULT_RENDERING_CONFIDENCE_THRESHOLD = DETECTOR_METADATA['v5b.0.0']['typical_detection_threshold']
174
+ DEFAULT_RENDERING_CONFIDENCE_THRESHOLD = known_models['v5a.0.0']['typical_detection_threshold']
106
175
  DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD = 0.005
107
176
 
108
177
  DEFAULT_BOX_THICKNESS = 4
@@ -110,29 +179,6 @@ DEFAULT_BOX_EXPANSION = 0
110
179
  DEFAULT_LABEL_FONT_SIZE = 16
111
180
  DETECTION_FILENAME_INSERT = '_detections'
112
181
 
113
- # The model filenames "MDV5A", "MDV5B", and "MDV4" are special; they will trigger an
114
- # automatic model download to the system temp folder, or they will use the paths specified in the
115
- # $MDV4, $MDV5A, or $MDV5B environment variables if they exist.
116
- downloadable_models = {
117
- 'MDV2':'https://lila.science/public/models/megadetector/megadetector_v2.pb',
118
- 'MDV3':'https://lila.science/public/models/megadetector/megadetector_v3.pb',
119
- 'MDV4':'https://github.com/agentmorris/MegaDetector/releases/download/v4.1/md_v4.1.0.pb',
120
- 'MDV5A':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5a.0.0.pt',
121
- 'MDV5B':'https://github.com/agentmorris/MegaDetector/releases/download/v5.0/md_v5b.0.0.pt'
122
- }
123
-
124
- model_string_to_model_version = {
125
- 'v2':'v2.0.0',
126
- 'v3':'v3.0.0',
127
- 'v4.1':'v4.1.0',
128
- 'v5a.0.0':'v5a.0.0',
129
- 'v5b.0.0':'v5b.0.0',
130
- 'mdv5a':'v5a.0.0',
131
- 'mdv5b':'v5b.0.0',
132
- 'mdv4':'v4.1.0',
133
- 'mdv3':'v3.0.0'
134
- }
135
-
136
182
  # Approximate inference speeds (in images per second) for MDv5 based on
137
183
  # benchmarks, only used for reporting very coarse expectations about inference time.
138
184
  device_token_to_mdv5_inference_speed = {
@@ -145,35 +191,12 @@ device_token_to_mdv5_inference_speed = {
145
191
  # is around 3.5x faster than MDv4.
146
192
  'V100':2.79*3.5,
147
193
  '2080':2.3*3.5,
148
- '2060':1.6*3.5
194
+ '2060':1.6*3.5
149
195
  }
150
196
 
151
197
 
152
198
  #%% Utility functions
153
199
 
154
- def convert_to_tf_coords(array):
155
- """
156
- Converts a bounding box from [x1, y1, width, height] to [y1, x1, y2, x2]. This
157
- is mostly not helpful, this function only exists to maintain backwards compatibility
158
- in the synchronous API, which possibly zero people in the world are using.
159
-
160
- Args:
161
- array (list): a bounding box in [x,y,w,h] format
162
-
163
- Returns:
164
- list: a bounding box in [y1,x1,y2,x2] format
165
- """
166
-
167
- x1 = array[0]
168
- y1 = array[1]
169
- width = array[2]
170
- height = array[3]
171
- x2 = x1 + width
172
- y2 = y1 + height
173
-
174
- return [y1, x1, y2, x2]
175
-
176
-
177
200
  def get_detector_metadata_from_version_string(detector_version):
178
201
  """
179
202
  Given a MegaDetector version string (e.g. "v4.1.0"), returns the metadata for
@@ -187,7 +210,7 @@ def get_detector_metadata_from_version_string(detector_version):
187
210
  dict: metadata for this model, suitable for writing to a MD output file
188
211
  """
189
212
 
190
- if detector_version not in DETECTOR_METADATA:
213
+ if detector_version not in known_models:
191
214
  print('Warning: no metadata for unknown detector version {}'.format(detector_version))
192
215
  default_detector_metadata = {
193
216
  'megadetector_version':'unknown',
@@ -196,12 +219,16 @@ def get_detector_metadata_from_version_string(detector_version):
196
219
  }
197
220
  return default_detector_metadata
198
221
  else:
199
- return DETECTOR_METADATA[detector_version]
222
+ to_return = known_models[detector_version]
223
+ to_return['megadetector_version'] = detector_version
224
+ return to_return
200
225
 
201
226
 
202
- def get_detector_version_from_filename(detector_filename,accept_first_match=True):
227
+ def get_detector_version_from_filename(detector_filename,
228
+ accept_first_match=True,
229
+ verbose=False):
203
230
  r"""
204
- Gets the version number component of the detector from the model filename.
231
+ Gets the canonical version number string of a detector from the model filename.
205
232
 
206
233
  [detector_filename] will almost always end with one of the following:
207
234
 
@@ -213,12 +240,14 @@ def get_detector_version_from_filename(detector_filename,accept_first_match=True
213
240
  * md_v5b.0.0.pt
214
241
 
215
242
  This function identifies the version number as "v2.0.0", "v3.0.0", "v4.1.0",
216
- "v4.1.0", "v5a.0.0", and "v5b.0.0", respectively.
243
+ "v4.1.0", "v5a.0.0", and "v5b.0.0", respectively. See known_models for the list
244
+ of valid version numbers.
217
245
 
218
246
  Args:
219
247
  detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
220
248
  accept_first_match (bool, optional): if multiple candidates match the filename, choose the
221
249
  first one, otherwise returns the string "multiple"
250
+ verbose (bool, optional): enable additional debug output
222
251
 
223
252
  Returns:
224
253
  str: a detector version string, e.g. "v5a.0.0", or "multiple" if I'm confused
@@ -230,24 +259,114 @@ def get_detector_version_from_filename(detector_filename,accept_first_match=True
230
259
  if s in fn:
231
260
  matches.append(s)
232
261
  if len(matches) == 0:
233
- print('Warning: could not determine MegaDetector version for model file {}'.format(detector_filename))
234
262
  return 'unknown'
235
263
  elif len(matches) > 1:
236
- print('Warning: multiple MegaDetector versions for model file {}'.format(detector_filename))
237
264
  if accept_first_match:
238
265
  return model_string_to_model_version[matches[0]]
239
266
  else:
267
+ if verbose:
268
+ print('Warning: multiple MegaDetector versions for model file {}:'.format(detector_filename))
269
+ for s in matches:
270
+ print(s)
240
271
  return 'multiple'
241
272
  else:
242
273
  return model_string_to_model_version[matches[0]]
243
274
 
244
275
 
276
+ def get_detector_version_from_model_file(detector_filename,verbose=False):
277
+ """
278
+ Gets the canonical detection version from a model file, preferably by reading it
279
+ from the file itself, otherwise based on the filename.
280
+
281
+ Args:
282
+ detector_filename (str): model filename, e.g. c:/x/z/md_v5a.0.0.pt
283
+ verbose (bool, optional): enable additional debug output
284
+
285
+ Returns:
286
+ str: a canonical detector version string, e.g. "v5a.0.0", or "unknown"
287
+ """
288
+
289
+ # Try to extract a version string from the filename
290
+ version_string_based_on_filename = get_detector_version_from_filename(
291
+ detector_filename, verbose=verbose)
292
+ if version_string_based_on_filename == 'unknown':
293
+ version_string_based_on_filename = None
294
+
295
+ # Try to extract a version string from the file itself; currently this is only
296
+ # a thing for PyTorch models
297
+
298
+ version_string_based_on_model_file = None
299
+
300
+ if detector_filename.endswith('.pt') or detector_filename.endswith('.zip'):
301
+
302
+ from megadetector.detection.pytorch_detector import \
303
+ read_metadata_from_megadetector_model_file
304
+ metadata = read_metadata_from_megadetector_model_file(detector_filename,verbose=verbose)
305
+
306
+ if metadata is not None and isinstance(metadata,dict):
307
+
308
+ if 'metadata_format_version' not in metadata or \
309
+ not isinstance(metadata['metadata_format_version'],float):
310
+
311
+ print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
312
+ 'but it doesn\'t have a valid format version number')
313
+
314
+ elif 'model_version_string' not in metadata or \
315
+ not isinstance(metadata['model_version_string'],str):
316
+
317
+ print(f'Warning: I found a metadata file in detector file {detector_filename}, '+\
318
+ 'but it doesn\'t have a format model version string')
319
+
320
+ else:
321
+
322
+ version_string_based_on_model_file = metadata['model_version_string']
323
+
324
+ if version_string_based_on_model_file not in known_models:
325
+ print('Warning: unknown model version {} specified in file {}'.format(
326
+ version_string_based_on_model_file,detector_filename))
327
+
328
+ # ...if there's metadata in this file
329
+
330
+ # ...if this looks like a PyTorch file
331
+
332
+ # If we got versions strings from the filename *and* the model file...
333
+ if (version_string_based_on_filename is not None) and \
334
+ (version_string_based_on_model_file is not None):
335
+
336
+ if version_string_based_on_filename != version_string_based_on_model_file:
337
+ print('Warning: model version string in file {} is {}, but the filename implies {}'.format(
338
+ detector_filename,
339
+ version_string_based_on_model_file,
340
+ version_string_based_on_filename))
341
+
342
+ return version_string_based_on_model_file
343
+
344
+ # If we got version string from neither the filename nor the model file...
345
+ if (version_string_based_on_filename is None) and \
346
+ (version_string_based_on_model_file is None):
347
+
348
+ print('Warning: could not determine model version string for model file {}'.format(
349
+ detector_filename))
350
+ return None
351
+
352
+ elif version_string_based_on_filename is not None:
353
+
354
+ return version_string_based_on_filename
355
+
356
+ else:
357
+
358
+ assert version_string_based_on_model_file is not None
359
+ return version_string_based_on_model_file
360
+
361
+ # ...def get_detector_version_from_model_file(...)
362
+
363
+
245
364
  def estimate_md_images_per_second(model_file, device_name=None):
246
365
  r"""
247
- Estimates how fast MegaDetector will run, based on benchmarks. Defaults to querying
248
- the current device. Returns None if no data is available for the current card/model.
249
- Estimates only available for a small handful of GPUs. Uses an absurdly simple lookup
250
- approach, e.g. if the string "4090" appears in the device name, congratulations,
366
+ Estimates how fast MegaDetector will run on a particular device, based on benchmarks.
367
+ Defaults to querying the current device. Returns None if no data is available for the current
368
+ card/model. Estimates only available for a small handful of GPUs. Uses an absurdly simple
369
+ lookup approach, e.g. if the string "4090" appears in the device name, congratulations,
251
370
  you have an RTX 4090.
252
371
 
253
372
  Args:
@@ -267,15 +386,24 @@ def estimate_md_images_per_second(model_file, device_name=None):
267
386
  print('Error querying device name: {}'.format(e))
268
387
  return None
269
388
 
270
- model_file = model_file.lower().strip()
271
- if model_file in model_string_to_model_version.values():
272
- model_version = model_file
273
- else:
274
- model_version = get_detector_version_from_filename(model_file)
275
- if model_version not in model_string_to_model_version.values():
276
- print('Error determining model version for model file {}'.format(model_file))
277
- return None
389
+ # About how fast is this model compared to MDv5?
390
+ model_version = get_detector_version_from_model_file(model_file)
391
+
392
+ if model_version not in known_models.keys():
393
+ print('Could not estimate inference speed: error determining model version for model file {}'.format(
394
+ model_file))
395
+ return None
396
+
397
+ model_info = known_models[model_version]
398
+
399
+ if 'normalized_typical_inference_speed' not in model_info or \
400
+ model_info['normalized_typical_inference_speed'] is None:
401
+ print('No speed ratio available for model type {}'.format(model_version))
402
+ return None
403
+
404
+ normalized_inference_speed = model_info['normalized_typical_inference_speed']
278
405
 
406
+ # About how fast would MDv5 run on this device?
279
407
  mdv5_inference_speed = None
280
408
  for device_token in device_token_to_mdv5_inference_speed.keys():
281
409
  if device_token in device_name:
@@ -283,16 +411,11 @@ def estimate_md_images_per_second(model_file, device_name=None):
283
411
  break
284
412
 
285
413
  if mdv5_inference_speed is None:
286
- print('No speed estimate available for {}'.format(device_name))
287
-
288
- if 'v5' in model_version:
289
- return mdv5_inference_speed
290
- elif 'v2' in model_version or 'v3' in model_version or 'v4' in model_version:
291
- return mdv5_inference_speed / 3.5
292
- else:
293
- print('Could not estimate inference speed for model file {}'.format(model_file))
414
+ print('No baseline speed estimate available for device {}'.format(device_name))
294
415
  return None
295
416
 
417
+ return normalized_inference_speed * mdv5_inference_speed
418
+
296
419
 
297
420
  def get_typical_confidence_threshold_from_results(results):
298
421
  """
@@ -342,25 +465,28 @@ def is_gpu_available(model_file):
342
465
  print('TensorFlow version:', tf.__version__)
343
466
  print('tf.test.is_gpu_available:', gpu_available)
344
467
  return gpu_available
345
- elif model_file.endswith('.pt'):
346
- import torch
347
- gpu_available = torch.cuda.is_available()
348
- print('PyTorch reports {} available CUDA devices'.format(torch.cuda.device_count()))
349
- if not gpu_available:
350
- try:
351
- # mps backend only available in torch >= 1.12.0
352
- if torch.backends.mps.is_built and torch.backends.mps.is_available():
353
- gpu_available = True
354
- print('PyTorch reports Metal Performance Shaders are available')
355
- except AttributeError:
356
- pass
357
- return gpu_available
358
- else:
359
- raise ValueError('Model {} does not have a recognized extension and is not a known model name'.\
360
- format(model_file))
361
-
362
-
363
- def load_detector(model_file, force_cpu=False, force_model_download=False):
468
+ if not model_file.endswith('.pt'):
469
+ print('Warning: could not determine environment from model file name, assuming PyTorch')
470
+
471
+ import torch
472
+ gpu_available = torch.cuda.is_available()
473
+ print('PyTorch reports {} available CUDA devices'.format(torch.cuda.device_count()))
474
+ if not gpu_available:
475
+ try:
476
+ # mps backend only available in torch >= 1.12.0
477
+ if torch.backends.mps.is_built and torch.backends.mps.is_available():
478
+ gpu_available = True
479
+ print('PyTorch reports Metal Performance Shaders are available')
480
+ except AttributeError:
481
+ pass
482
+ return gpu_available
483
+
484
+
485
+ def load_detector(model_file,
486
+ force_cpu=False,
487
+ force_model_download=False,
488
+ detector_options=None,
489
+ verbose=False):
364
490
  r"""
365
491
  Loads a TF or PT detector, depending on the extension of model_file.
366
492
 
@@ -372,6 +498,9 @@ def load_detector(model_file, force_cpu=False, force_model_download=False):
372
498
  force_model_download (bool, optional): force downloading the model file if
373
499
  a named model (e.g. "MDV5A") is supplied, even if the local file already
374
500
  exists
501
+ detector_options (dict, optional): key/value pairs that are interpreted differently
502
+ by different detectors
503
+ verbose (bool, optional): enable additional debug output
375
504
 
376
505
  Returns:
377
506
  object: loaded detector object
@@ -381,25 +510,48 @@ def load_detector(model_file, force_cpu=False, force_model_download=False):
381
510
  model_file = try_download_known_detector(model_file,
382
511
  force_download=force_model_download)
383
512
 
384
- print('GPU available: {}'.format(is_gpu_available(model_file)))
513
+ if verbose:
514
+ print('GPU available: {}'.format(is_gpu_available(model_file)))
385
515
 
386
516
  start_time = time.time()
517
+
387
518
  if model_file.endswith('.pb'):
519
+
388
520
  from megadetector.detection.tf_detector import TFDetector
389
521
  if force_cpu:
390
522
  raise ValueError('force_cpu is not currently supported for TF detectors, ' + \
391
523
  'use CUDA_VISIBLE_DEVICES=-1 instead')
392
- detector = TFDetector(model_file)
524
+ detector = TFDetector(model_file, detector_options)
525
+
393
526
  elif model_file.endswith('.pt'):
527
+
394
528
  from megadetector.detection.pytorch_detector import PTDetector
395
- detector = PTDetector(model_file, force_cpu, USE_MODEL_NATIVE_CLASSES)
529
+
530
+ # Prepare options specific to the PTDetector class
531
+ if detector_options is None:
532
+ detector_options = {}
533
+ if 'force_cpu' in detector_options:
534
+ if force_cpu != detector_options['force_cpu']:
535
+ print('Warning: over-riding force_cpu parameter ({}) based on detector_options ({})'.format(
536
+ force_cpu,detector_options['force_cpu']))
537
+ else:
538
+ detector_options['force_cpu'] = force_cpu
539
+ detector_options['use_model_native_classes'] = USE_MODEL_NATIVE_CLASSES
540
+ detector = PTDetector(model_file, detector_options, verbose=verbose)
541
+
396
542
  else:
543
+
397
544
  raise ValueError('Unrecognized model format: {}'.format(model_file))
545
+
398
546
  elapsed = time.time() - start_time
399
- print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
547
+
548
+ if verbose:
549
+ print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
400
550
 
401
551
  return detector
402
552
 
553
+ # ...def load_detector(...)
554
+
403
555
 
404
556
  #%% Main function
405
557
 
@@ -413,8 +565,8 @@ def load_and_run_detector(model_file,
413
565
  image_size=None,
414
566
  label_font_size=DEFAULT_LABEL_FONT_SIZE,
415
567
  augment=False,
416
- force_model_download=False
417
- ):
568
+ force_model_download=False,
569
+ detector_options=None):
418
570
  r"""
419
571
  Loads and runs a detector on target images, and visualizes the results.
420
572
 
@@ -438,6 +590,8 @@ def load_and_run_detector(model_file,
438
590
  force_model_download (bool, optional): force downloading the model file if
439
591
  a named model (e.g. "MDV5A") is supplied, even if the local file already
440
592
  exists
593
+ detector_options (dict, optional): key/value pairs that are interpreted differently
594
+ by different detectors
441
595
  """
442
596
 
443
597
  if len(image_file_names) == 0:
@@ -447,7 +601,7 @@ def load_and_run_detector(model_file,
447
601
  # Possibly automatically download the model
448
602
  model_file = try_download_known_detector(model_file, force_download=force_model_download)
449
603
 
450
- detector = load_detector(model_file)
604
+ detector = load_detector(model_file, detector_options=detector_options)
451
605
 
452
606
  detection_results = []
453
607
  time_load = []
@@ -516,7 +670,7 @@ def load_and_run_detector(model_file,
516
670
  time_load.append(elapsed)
517
671
 
518
672
  except Exception as e:
519
- print('Image {} cannot be loaded. Exception: {}'.format(im_file, e))
673
+ print('Image {} cannot be loaded, error: {}'.format(im_file, str(e)))
520
674
  result = {
521
675
  'file': im_file,
522
676
  'failure': FAILURE_IMAGE_OPEN
@@ -539,7 +693,7 @@ def load_and_run_detector(model_file,
539
693
  time_infer.append(elapsed)
540
694
 
541
695
  except Exception as e:
542
- print('An error occurred while running the detector on image {}. Exception: {}'.format(im_file, e))
696
+ print('An error occurred while running the detector on image {}: {}'.format(im_file, str(e)))
543
697
  continue
544
698
 
545
699
  try:
@@ -560,7 +714,8 @@ def load_and_run_detector(model_file,
560
714
  label_map=DEFAULT_DETECTOR_LABEL_MAP,
561
715
  confidence_threshold=render_confidence_threshold,
562
716
  thickness=box_thickness, expansion=box_expansion,
563
- label_font_size=label_font_size)
717
+ label_font_size=label_font_size,
718
+ box_sort_order='confidence')
564
719
  output_full_path = input_file_to_detection_file(im_file)
565
720
  image.save(output_full_path)
566
721
 
@@ -576,8 +731,8 @@ def load_and_run_detector(model_file,
576
731
  std_dev_time_load = humanfriendly.format_timespan(statistics.stdev(time_load))
577
732
  std_dev_time_infer = humanfriendly.format_timespan(statistics.stdev(time_infer))
578
733
  else:
579
- std_dev_time_load = 'not available'
580
- std_dev_time_infer = 'not available'
734
+ std_dev_time_load = 'not available (<=1 image processed)'
735
+ std_dev_time_infer = 'not available (<=1 image processed)'
581
736
  print('On average, for each image,')
582
737
  print('- loading took {}, std dev is {}'.format(humanfriendly.format_timespan(ave_time_load),
583
738
  std_dev_time_load))
@@ -587,12 +742,13 @@ def load_and_run_detector(model_file,
587
742
  # ...def load_and_run_detector()
588
743
 
589
744
 
590
- def download_model(model_name,force_download=False):
745
+ def _download_model(model_name,force_download=False):
591
746
  """
592
747
  Downloads one of the known models to local temp space if it hasn't already been downloaded.
593
748
 
594
749
  Args:
595
- model_name (str): a known model string, e.g. "MDV5A"
750
+ model_name (str): a known model string, e.g. "MDV5A". Returns None if this string is not
751
+ a known model name.
596
752
  force_download (bool, optional): whether to download the model even if the local target
597
753
  file already exists
598
754
  """
@@ -609,10 +765,10 @@ def download_model(model_name,force_download=False):
609
765
  os.chmod(model_tempdir,0o777)
610
766
  except Exception:
611
767
  pass
612
- if model_name.upper() not in downloadable_models:
768
+ if model_name.lower() not in known_models:
613
769
  print('Unrecognized downloadable model {}'.format(model_name))
614
770
  return None
615
- url = downloadable_models[model_name.upper()]
771
+ url = known_models[model_name.lower()]['url']
616
772
  destination_filename = os.path.join(model_tempdir,url.split('/')[-1])
617
773
  local_file = download_url(url, destination_filename=destination_filename, progress_updater=None,
618
774
  force_download=force_download, verbose=True)
@@ -620,7 +776,7 @@ def download_model(model_name,force_download=False):
620
776
  return local_file
621
777
 
622
778
 
623
- def try_download_known_detector(detector_file,force_download=False):
779
+ def try_download_known_detector(detector_file,force_download=False,verbose=False):
624
780
  """
625
781
  Checks whether detector_file is really the name of a known model, in which case we will
626
782
  either read the actual filename from the corresponding environment variable or download
@@ -631,22 +787,37 @@ def try_download_known_detector(detector_file,force_download=False):
631
787
  case this function is a no-op)
632
788
  force_download (bool, optional): whether to download the model even if the local target
633
789
  file already exists
790
+ verbose (bool, optional): enable additional debug output
634
791
 
635
792
  Returns:
636
793
  str: the local filename to which the model was downloaded, or the same string that
637
794
  was passed in, if it's not recognized as a well-known model name
638
795
  """
639
796
 
640
- if detector_file.upper() in downloadable_models:
797
+ model_string = detector_file.lower()
798
+
799
+ # If this is a short model string (e.g. "MDV5A"), convert to a canonical version
800
+ # string (e.g. "v5a.0.0")
801
+ if model_string in model_string_to_model_version:
802
+
803
+ if verbose:
804
+ print('Converting short string {} to canonical version string {}'.format(
805
+ model_string,
806
+ model_string_to_model_version[model_string]))
807
+ model_string = model_string_to_model_version[model_string]
808
+
809
+ if model_string in known_models:
810
+
641
811
  if detector_file in os.environ:
642
812
  fn = os.environ[detector_file]
643
813
  print('Reading MD location from environment variable {}: {}'.format(
644
814
  detector_file,fn))
645
815
  detector_file = fn
646
816
  else:
647
- print('Downloading model {}'.format(detector_file))
648
- detector_file = download_model(detector_file,force_download=force_download)
817
+ detector_file = _download_model(model_string,force_download=force_download)
818
+
649
819
  return detector_file
820
+
650
821
 
651
822
 
652
823
 
@@ -745,13 +916,21 @@ def main():
745
916
  action='store_true',
746
917
  help=('If a named model (e.g. "MDV5A") is supplied, force a download of that model even if the ' +\
747
918
  'local file already exists.'))
919
+
920
+ parser.add_argument(
921
+ '--detector_options',
922
+ nargs='*',
923
+ metavar='KEY=VALUE',
924
+ default='',
925
+ help='Detector-specific options, as a space-separated list of key-value pairs')
748
926
 
749
927
  if len(sys.argv[1:]) == 0:
750
928
  parser.print_help()
751
929
  parser.exit()
752
930
 
753
931
  args = parser.parse_args()
754
-
932
+ detector_options = parse_kvp_list(args.detector_options)
933
+
755
934
  # If the specified detector file is really the name of a known model, find
756
935
  # (and possibly download) that model
757
936
  args.detector_file = try_download_known_detector(args.detector_file,
@@ -786,7 +965,7 @@ def main():
786
965
  else:
787
966
  # but for a single image, args.image_dir is also None
788
967
  args.output_dir = os.path.dirname(args.image_file)
789
-
968
+
790
969
  load_and_run_detector(model_file=args.detector_file,
791
970
  image_file_names=image_file_names,
792
971
  output_dir=args.output_dir,
@@ -797,18 +976,34 @@ def main():
797
976
  image_size=args.image_size,
798
977
  label_font_size=args.label_font_size,
799
978
  augment=args.augment,
800
- # Don't download the model *again*
801
- force_model_download=False)
979
+ # If --force_model_download was specified, we already handled it
980
+ force_model_download=False,
981
+ detector_options=detector_options)
802
982
 
803
983
  if __name__ == '__main__':
804
984
  main()
805
985
 
806
986
 
807
- #%% Interactive driver
987
+ #%% Interactive driver(s)
808
988
 
809
989
  if False:
810
990
 
811
- #%%
991
+ pass
992
+
993
+ #%% Test model download
994
+
995
+ r"""
996
+ cd i:\models\all_models_in_the_wild
997
+ i:
998
+ python -m http.server 8181
999
+ """
1000
+
1001
+ model_name = 'redwood'
1002
+ try_download_known_detector(model_name,force_download=True,verbose=True)
1003
+
1004
+
1005
+ #%% Load and run detector
1006
+
812
1007
  model_file = r'c:\temp\models\md_v4.1.0.pb'
813
1008
  image_file_names = path_utils.find_images(r'c:\temp\demo_images\ssverymini')
814
1009
  output_dir = r'c:\temp\demo_images\ssverymini'