megadetector 5.0.5__py3-none-any.whl → 5.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of megadetector has been flagged as potentially problematic; see the registry listing for more details.

Files changed (132)
  1. api/batch_processing/data_preparation/manage_local_batch.py +302 -263
  2. api/batch_processing/data_preparation/manage_video_batch.py +81 -2
  3. api/batch_processing/postprocessing/add_max_conf.py +1 -0
  4. api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
  5. api/batch_processing/postprocessing/compare_batch_results.py +110 -60
  6. api/batch_processing/postprocessing/load_api_results.py +56 -70
  7. api/batch_processing/postprocessing/md_to_coco.py +1 -1
  8. api/batch_processing/postprocessing/md_to_labelme.py +2 -1
  9. api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
  11. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
  12. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  13. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
  14. api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
  15. api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
  16. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
  17. classification/prepare_classification_script.py +191 -191
  18. data_management/coco_to_yolo.py +68 -45
  19. data_management/databases/integrity_check_json_db.py +7 -5
  20. data_management/generate_crops_from_cct.py +3 -3
  21. data_management/get_image_sizes.py +8 -6
  22. data_management/importers/add_timestamps_to_icct.py +79 -0
  23. data_management/importers/animl_results_to_md_results.py +160 -0
  24. data_management/importers/auckland_doc_test_to_json.py +4 -4
  25. data_management/importers/auckland_doc_to_json.py +1 -1
  26. data_management/importers/awc_to_json.py +5 -5
  27. data_management/importers/bellevue_to_json.py +5 -5
  28. data_management/importers/carrizo_shrubfree_2018.py +5 -5
  29. data_management/importers/carrizo_trail_cam_2017.py +5 -5
  30. data_management/importers/cct_field_adjustments.py +2 -3
  31. data_management/importers/channel_islands_to_cct.py +4 -4
  32. data_management/importers/ena24_to_json.py +5 -5
  33. data_management/importers/helena_to_cct.py +10 -10
  34. data_management/importers/idaho-camera-traps.py +12 -12
  35. data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
  36. data_management/importers/jb_csv_to_json.py +4 -4
  37. data_management/importers/missouri_to_json.py +1 -1
  38. data_management/importers/noaa_seals_2019.py +1 -1
  39. data_management/importers/pc_to_json.py +5 -5
  40. data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
  41. data_management/importers/prepare_zsl_imerit.py +5 -5
  42. data_management/importers/rspb_to_json.py +4 -4
  43. data_management/importers/save_the_elephants_survey_A.py +5 -5
  44. data_management/importers/save_the_elephants_survey_B.py +6 -6
  45. data_management/importers/snapshot_safari_importer.py +9 -9
  46. data_management/importers/snapshot_serengeti_lila.py +9 -9
  47. data_management/importers/timelapse_csv_set_to_json.py +5 -7
  48. data_management/importers/ubc_to_json.py +4 -4
  49. data_management/importers/umn_to_json.py +4 -4
  50. data_management/importers/wellington_to_json.py +1 -1
  51. data_management/importers/wi_to_json.py +2 -2
  52. data_management/importers/zamba_results_to_md_results.py +181 -0
  53. data_management/labelme_to_coco.py +35 -7
  54. data_management/labelme_to_yolo.py +229 -0
  55. data_management/lila/add_locations_to_island_camera_traps.py +1 -1
  56. data_management/lila/add_locations_to_nacti.py +147 -0
  57. data_management/lila/create_lila_blank_set.py +474 -0
  58. data_management/lila/create_lila_test_set.py +2 -1
  59. data_management/lila/create_links_to_md_results_files.py +106 -0
  60. data_management/lila/download_lila_subset.py +46 -21
  61. data_management/lila/generate_lila_per_image_labels.py +23 -14
  62. data_management/lila/get_lila_annotation_counts.py +17 -11
  63. data_management/lila/lila_common.py +14 -11
  64. data_management/lila/test_lila_metadata_urls.py +116 -0
  65. data_management/ocr_tools.py +829 -0
  66. data_management/resize_coco_dataset.py +13 -11
  67. data_management/yolo_output_to_md_output.py +84 -12
  68. data_management/yolo_to_coco.py +38 -20
  69. detection/process_video.py +36 -14
  70. detection/pytorch_detector.py +23 -8
  71. detection/run_detector.py +76 -19
  72. detection/run_detector_batch.py +178 -63
  73. detection/run_inference_with_yolov5_val.py +326 -57
  74. detection/run_tiled_inference.py +153 -43
  75. detection/video_utils.py +34 -8
  76. md_utils/ct_utils.py +172 -1
  77. md_utils/md_tests.py +372 -51
  78. md_utils/path_utils.py +167 -39
  79. md_utils/process_utils.py +26 -7
  80. md_utils/split_locations_into_train_val.py +215 -0
  81. md_utils/string_utils.py +10 -0
  82. md_utils/url_utils.py +0 -2
  83. md_utils/write_html_image_list.py +9 -26
  84. md_visualization/plot_utils.py +12 -8
  85. md_visualization/visualization_utils.py +106 -7
  86. md_visualization/visualize_db.py +16 -8
  87. md_visualization/visualize_detector_output.py +208 -97
  88. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
  89. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
  90. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
  91. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
  92. taxonomy_mapping/map_new_lila_datasets.py +43 -39
  93. taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
  94. taxonomy_mapping/preview_lila_taxonomy.py +27 -27
  95. taxonomy_mapping/species_lookup.py +33 -13
  96. taxonomy_mapping/taxonomy_csv_checker.py +7 -5
  97. api/synchronous/api_core/yolov5/detect.py +0 -252
  98. api/synchronous/api_core/yolov5/export.py +0 -607
  99. api/synchronous/api_core/yolov5/hubconf.py +0 -146
  100. api/synchronous/api_core/yolov5/models/__init__.py +0 -0
  101. api/synchronous/api_core/yolov5/models/common.py +0 -738
  102. api/synchronous/api_core/yolov5/models/experimental.py +0 -104
  103. api/synchronous/api_core/yolov5/models/tf.py +0 -574
  104. api/synchronous/api_core/yolov5/models/yolo.py +0 -338
  105. api/synchronous/api_core/yolov5/train.py +0 -670
  106. api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
  107. api/synchronous/api_core/yolov5/utils/activations.py +0 -103
  108. api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
  109. api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
  110. api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
  111. api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
  112. api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
  113. api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
  114. api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
  115. api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
  116. api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
  117. api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
  118. api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
  119. api/synchronous/api_core/yolov5/utils/general.py +0 -1018
  120. api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
  121. api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
  122. api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
  123. api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
  124. api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
  125. api/synchronous/api_core/yolov5/utils/loss.py +0 -234
  126. api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
  127. api/synchronous/api_core/yolov5/utils/plots.py +0 -489
  128. api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
  129. api/synchronous/api_core/yolov5/val.py +0 -394
  130. md_utils/matlab_porting_tools.py +0 -97
  131. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
  132. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
md_utils/path_utils.py CHANGED
@@ -21,7 +21,8 @@ import zipfile
 from zipfile import ZipFile
 from datetime import datetime
 from typing import Container, Iterable, List, Optional, Tuple, Sequence
-from multiprocessing.pool import ThreadPool
+from multiprocessing.pool import Pool, ThreadPool
+from functools import partial
 from tqdm import tqdm
 
 IMG_EXTENSIONS = ('.jpg', '.jpeg', '.gif', '.png', '.tif', '.tiff', '.bmp')
@@ -34,30 +35,51 @@ CHAR_LIMIT = 255
 
 #%% General path functions
 
-def recursive_file_list(base_dir, convert_slashes=True, return_relative_paths=False):
-    """
+def recursive_file_list(base_dir, convert_slashes=True,
+                        return_relative_paths=False, sort_files=True,
+                        recursive=True):
+    r"""
     Enumerate files (not directories) in [base_dir], optionally converting
     \ to /
     """
 
     all_files = []
 
-    for root, _, filenames in os.walk(base_dir):
-        for filename in filenames:
-            full_path = os.path.join(root, filename)
-            if convert_slashes:
-                full_path = full_path.replace('\\', '/')
-            all_files.append(full_path)
-
+    if recursive:
+        for root, _, filenames in os.walk(base_dir):
+            for filename in filenames:
+                full_path = os.path.join(root, filename)
+                all_files.append(full_path)
+    else:
+        all_files_relative = os.listdir(base_dir)
+        all_files = [os.path.join(base_dir,fn) for fn in all_files_relative]
+        all_files = [fn for fn in all_files if os.path.isfile(fn)]
+
     if return_relative_paths:
         all_files = [os.path.relpath(fn,base_dir) for fn in all_files]
+
+    if convert_slashes:
+        all_files = [fn.replace('\\', '/') for fn in all_files]
+
+    if sort_files:
+        all_files = sorted(all_files)
 
-    all_files = sorted(all_files)
     return all_files
 
 
-def split_path(path: str) -> List[str]:
+def file_list(base_dir, convert_slashes=True, return_relative_paths=False, sort_files=True,
+              recursive=False):
+    """
+    Trivial wrapper for recursive_file_list, which was a poor function name choice at the
+    time; it doesn't really make sense to have a "recursive" option in a function called
+    "recursive_file_list".
     """
+
+    return recursive_file_list(base_dir,convert_slashes,return_relative_paths,sort_files,
+                               recursive=recursive)
+
+
+def split_path(path: str) -> List[str]:
+    r"""
     Splits [path] into all its constituent tokens.
 
     Non-recursive version of:
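As a rough usage sketch of the reworked enumeration functions (the directory path here is hypothetical; `md_utils.path_utils` is the module path shown in this diff):

```python
from md_utils import path_utils

# Hypothetical input folder
base_dir = '/data/camera-traps'

# Recursive enumeration (the default), sorted, with '\' converted to '/'
all_files = path_utils.recursive_file_list(base_dir)

# Top-level files only, via the new file_list wrapper
top_level_files = path_utils.file_list(base_dir, recursive=False)

# Relative paths, skipping the sort
relative_files = path_utils.recursive_file_list(base_dir,
                                                return_relative_paths=True,
                                                sort_files=False)
```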
@@ -87,7 +109,7 @@ def split_path(path: str) -> List[str]:
 
 
 def fileparts(path: str) -> Tuple[str, str, str]:
-    """
+    r"""
     Breaks down a path into the directory path, filename, and extension.
 
     Note that the '.' lives with the extension, and separators are removed.
@@ -185,8 +207,9 @@ def safe_create_link(link_exists,link_new):
     and if it has a different target than link_exists, remove and re-create
     it.
 
-    Errors of link_new already exists but it's not a link.
-    """
+    Errors if link_new already exists but it's not a link.
+    """
+
     if os.path.exists(link_new) or os.path.islink(link_new):
         assert os.path.islink(link_new)
         if not os.readlink(link_new) == link_exists:
@@ -196,6 +219,23 @@ def safe_create_link(link_exists,link_new):
         os.symlink(link_exists,link_new)
 
 
+def get_file_sizes(base_dir, convert_slashes=True):
+    """
+    Get sizes recursively for all files in base_dir, returning a dict mapping
+    relative filenames to size.
+    """
+
+    relative_filenames = recursive_file_list(base_dir, convert_slashes=convert_slashes,
+                                             return_relative_paths=True)
+
+    fn_to_size = {}
+    for fn_relative in tqdm(relative_filenames):
+        fn_abs = os.path.join(base_dir,fn_relative)
+        fn_to_size[fn_relative] = os.path.getsize(fn_abs)
+
+    return fn_to_size
+
+
 #%% Image-related path functions
 
 def is_image_file(s: str, img_extensions: Container[str] = IMG_EXTENSIONS
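A minimal sketch of the new get_file_sizes helper (the folder name is made up):

```python
from md_utils.path_utils import get_file_sizes

# Maps relative filenames to sizes in bytes,
# e.g. {'2023/IMG_0001.JPG': 4182003, ...}
fn_to_size = get_file_sizes('/data/camera-traps')
total_gb = sum(fn_to_size.values()) / 1e9
```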
@@ -221,10 +261,14 @@ def find_image_strings(strings: Iterable[str]) -> List[str]:
     return [s for s in strings if is_image_file(s)]
 
 
-def find_images(dirname: str, recursive: bool = False, return_relative_paths: bool = False) -> List[str]:
+def find_images(dirname: str, recursive: bool = False,
+                return_relative_paths: bool = False,
+                convert_slashes: bool = False) -> List[str]:
     """
     Finds all files in a directory that look like image file names. Returns
-    absolute paths unless return_relative_paths is set.
+    absolute paths unless return_relative_paths is set. Uses the OS-native
+    path separator unless convert_slashes is set, in which case '/' is
+    always used.
     """
 
     if recursive:
@@ -238,6 +282,10 @@ def find_images(dirname: str, recursive: bool = False, return_relative_paths: bool = False) -> List[str]:
         image_files = [os.path.relpath(fn,dirname) for fn in image_files]
 
     image_files = sorted(image_files)
+
+    if convert_slashes:
+        image_files = [fn.replace('\\', '/') for fn in image_files]
+
     return image_files
 
 
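The new convert_slashes option makes it easy to get '/'-separated paths regardless of OS; a hypothetical call might look like:

```python
from md_utils.path_utils import find_images

# Recursive image enumeration with forward slashes, even on Windows
image_files = find_images('/data/camera-traps', recursive=True,
                          return_relative_paths=True, convert_slashes=True)
```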
@@ -245,11 +293,11 @@ def find_images(dirname: str, recursive: bool = False, return_relative_paths: bool = False) -> List[str]:
 
 def clean_filename(filename: str, allow_list: str = VALID_FILENAME_CHARS,
                    char_limit: int = CHAR_LIMIT, force_lower: bool = False) -> str:
-    """
+    r"""
     Removes non-ASCII and other invalid filename characters (on any
     reasonable OS) from a filename, then trims to a maximum length.
 
-    Does not allow :\/, use clean_path if you want to preserve those.
+    Does not allow :\/ by default, use clean_path if you want to preserve those.
 
     Adapted from
     https://gist.github.com/wassname/1393c4a57cfcbf03641dbc31886123b8
@@ -294,30 +342,72 @@ def flatten_path(pathname: str, separator_chars: str = SEPARATOR_CHARS) -> str:
 
 #%% Platform-independent way to open files in their associated application
 
-import sys,subprocess
-
-def open_file(filename):
-    if sys.platform == "win32":
-        os.startfile(filename)
-    else:
-        opener = "open" if sys.platform == "darwin" else "xdg-open"
-        subprocess.call([opener, filename])
+import sys,subprocess,platform,re
 
+def environment_is_wsl():
+    """
+    Returns True if we're running in WSL
+    """
+
+    if sys.platform not in ('linux','posix'):
+        return False
+    platform_string = ' '.join(platform.uname()).lower()
+    return 'microsoft' in platform_string and 'wsl' in platform_string
+
 
-#%% zipfile management functions
+def wsl_path_to_windows_path(filename):
+    r"""
+    Converts a WSL path to a Windows path, or returns None if that's not possible. E.g.
+    converts:
+
+    /mnt/e/a/b/c
+
+    ...to:
+
+    e:\a\b\c
+    """
+
+    result = subprocess.run(['wslpath', '-w', filename], text=True, capture_output=True)
+    if result.returncode != 0:
+        print('Could not convert path {} from WSL to Windows'.format(filename))
+        return None
+    return result.stdout.strip()
+
 
-def unzip_file(input_file, output_folder=None):
+def open_file(filename,attempt_to_open_in_wsl_host=False):
     """
-    Unzip a zipfile to the specified output folder, defaulting to the same location as
-    the input file
+    Opens [filename] in the native OS file handler. If attempt_to_open_in_wsl_host
+    is True, and we're in WSL, attempts to open [filename] in Windows.
     """
 
-    if output_folder is None:
-        output_folder = os.path.dirname(input_file)
+    if sys.platform == 'win32':
 
-    with zipfile.ZipFile(input_file, 'r') as zf:
-        zf.extractall(output_folder)
+        os.startfile(filename)
 
+    elif sys.platform == 'darwin':
+
+        opener = 'open'
+        subprocess.call([opener, filename])
+
+    elif attempt_to_open_in_wsl_host and environment_is_wsl():
+
+        windows_path = wsl_path_to_windows_path(filename)
+
+        # Fall back to xdg-open if we couldn't convert the path
+        if windows_path is None:
+            subprocess.call(['xdg-open', filename])
+            return
+
+        if os.path.isdir(filename):
+            subprocess.run(["explorer.exe", windows_path])
+        else:
+            os.system("cmd.exe /C start %s" % (re.escape(windows_path)))
+
+    else:
+
+        opener = 'xdg-open'
+        subprocess.call([opener, filename])
+
 
 #%% File list functions
 
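A hedged sketch of the new WSL-aware behavior (the paths are made up, and this only does something interesting when running inside WSL):

```python
from md_utils.path_utils import (open_file, environment_is_wsl,
                                 wsl_path_to_windows_path)

if environment_is_wsl():
    # E.g. '/mnt/e/a/b/c' -> 'e:\\a\\b\\c'; returns None if conversion fails
    print(wsl_path_to_windows_path('/mnt/e/a/b/c'))

# Opens the file with the Windows host's default handler when running
# under WSL; otherwise uses os.startfile/open/xdg-open as appropriate
open_file('/mnt/e/results/index.html', attempt_to_open_in_wsl_host=True)
```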
@@ -393,7 +483,7 @@ def zip_folder(input_folder, output_fn=None, overwrite=False, verbose=False, compresslevel=9):
     relative_filenames = recursive_file_list(input_folder,return_relative_paths=True)
 
     with ZipFile(output_fn,'w',zipfile.ZIP_DEFLATED) as zipf:
-        for input_fn_relative in relative_filenames:
+        for input_fn_relative in tqdm(relative_filenames,disable=(not verbose)):
             input_fn_abs = os.path.join(input_folder,input_fn_relative)
             zipf.write(input_fn_abs,
                        arcname=input_fn_relative,
@@ -403,14 +493,53 @@ def zip_folder(input_folder, output_fn=None, overwrite=False, verbose=False, compresslevel=9):
     return output_fn
 
 
-def parallel_zip_files(input_files,max_workers=16):
+def parallel_zip_files(input_files, max_workers=16, use_threads=True):
     """
     Zip one or more files to separate output files in parallel, leaving the
-    original files in place.
+    original files in place. Each file is zipped to [filename].zip.
     """
 
     n_workers = min(max_workers,len(input_files))
-    pool = ThreadPool(n_workers)
+
+    if use_threads:
+        pool = ThreadPool(n_workers)
+    else:
+        pool = Pool(n_workers)
+
     with tqdm(total=len(input_files)) as pbar:
         for i,_ in enumerate(pool.imap_unordered(zip_file,input_files)):
             pbar.update()
+
+
+def parallel_zip_folders(input_folders, max_workers=16, use_threads=True,
+                         compresslevel=9, overwrite=False):
+    """
+    Zip one or more folders to separate output files in parallel, leaving the
+    original folders in place. Each folder is zipped to [folder_name].zip.
+    """
+
+    n_workers = min(max_workers,len(input_folders))
+
+    if use_threads:
+        pool = ThreadPool(n_workers)
+    else:
+        pool = Pool(n_workers)
+
+    with tqdm(total=len(input_folders)) as pbar:
+        for i,_ in enumerate(pool.imap_unordered(
+                partial(zip_folder,overwrite=overwrite,compresslevel=compresslevel),
+                input_folders)):
+            pbar.update()
+
+
+def unzip_file(input_file, output_folder=None):
+    """
+    Unzip a zipfile to the specified output folder, defaulting to the same location as
+    the input file
+    """
+
+    if output_folder is None:
+        output_folder = os.path.dirname(input_file)
+
+    with zipfile.ZipFile(input_file, 'r') as zf:
+        zf.extractall(output_folder)
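A short sketch of the parallel zip helpers (file and folder names are hypothetical):

```python
from md_utils.path_utils import parallel_zip_files, parallel_zip_folders

# Each file is zipped in place to [filename].zip
parallel_zip_files(['/data/results_a.json', '/data/results_b.json'],
                   max_workers=4)

# use_threads=False swaps the thread pool for a process pool, which can
# help when compression is CPU-bound; each folder becomes [folder_name].zip
parallel_zip_folders(['/data/folder1', '/data/folder2'],
                     use_threads=False, overwrite=True)
```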
md_utils/process_utils.py CHANGED
@@ -17,14 +17,28 @@ import subprocess
 
 os.environ["PYTHONUNBUFFERED"] = "1"
 
-def execute(cmd):
+def execute(cmd,encoding=None,errors=None,env=None,verbose=False):
     """
     Run [cmd] (a single string) in a shell, yielding each line of output to the caller.
+
+    The "encoding", "errors", and "env" parameters are passed directly to subprocess.Popen().
+
+    "verbose" only impacts output about process management; it is not related to printing
+    output from the child process.
     """
-
+
+    if verbose:
+        if encoding is not None:
+            print('Launching child process with non-default encoding {}'.format(encoding))
+        if errors is not None:
+            print('Launching child process with non-default text error handling {}'.format(errors))
+        if env is not None:
+            print('Launching child process with non-default environment {}'.format(str(env)))
+
     # https://stackoverflow.com/questions/4417546/constantly-print-subprocess-output-while-process-is-running
     popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-                             shell=True, universal_newlines=True)
+                             shell=True, universal_newlines=True, encoding=encoding,
+                             errors=errors, env=env)
     for stdout_line in iter(popen.stdout.readline, ""):
         yield stdout_line
     popen.stdout.close()
@@ -33,22 +47,27 @@ def execute(cmd):
         raise subprocess.CalledProcessError(return_code, cmd)
 
 
-def execute_and_print(cmd,print_output=True):
+def execute_and_print(cmd,print_output=True,encoding=None,errors=None,env=None,verbose=False):
     """
     Run [cmd] (a single string) in a shell, capturing and printing output. Returns
     a dictionary with fields "status" and "output".
+
+    The "encoding", "errors", and "env" parameters are passed directly to subprocess.Popen().
+
+    "verbose" only impacts output about process management; it is not related to printing
+    output from the child process.
     """
 
     to_return = {'status':'unknown','output':''}
-    output=[]
+    output = []
     try:
-        for s in execute(cmd):
+        for s in execute(cmd,encoding=encoding,errors=errors,env=env,verbose=verbose):
             output.append(s)
             if print_output:
                 print(s,end='',flush=True)
         to_return['status'] = 0
     except subprocess.CalledProcessError as cpe:
-        print('execute_and_print caught error: {}'.format(cpe.output))
+        print('execute_and_print caught error: {} ({})'.format(cpe.output,str(cpe)))
         to_return['status'] = cpe.returncode
         to_return['output'] = output
 
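A quick sketch of the extended signatures (the command is arbitrary):

```python
from md_utils.process_utils import execute_and_print

# Prints child-process output line by line as it's produced, then returns
# {'status': ..., 'output': ...}; encoding/errors/env go straight to Popen
result = execute_and_print('ls -l /tmp', encoding='utf-8', errors='replace',
                           verbose=True)
assert result['status'] == 0
print(''.join(result['output']))
```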
md_utils/split_locations_into_train_val.py ADDED
@@ -0,0 +1,215 @@
+########
+#
+# split_locations_into_train_val.py
+#
+# Split a list of location IDs into training and validation, targeting a specific
+# train/val split for each category, but allowing some categories to be tighter or looser
+# than others. Does nothing particularly clever, just randomly splits locations into
+# train/val lots of times using the target val fraction, and picks the one that meets the
+# specified constraints and minimizes weighted error, where "error" is defined as the
+# sum of each class's absolute divergence from the target val fraction.
+#
+########
+
+#%% Imports/constants
+
+import random
+import numpy as np
+
+from collections import defaultdict
+from md_utils.ct_utils import sort_dictionary_by_value
+from tqdm import tqdm
+
+
+#%% Main function
+
+def split_locations_into_train_val(location_to_category_counts,
+                                   n_random_seeds=10000,
+                                   target_val_fraction=0.15,
+                                   category_to_max_allowable_error=None,
+                                   category_to_error_weight=None,
+                                   default_max_allowable_error=0.1):
+    """
+    Split a list of location IDs into training and validation, targeting a specific
+    train/val split for each category, but allowing some categories to be tighter or looser
+    than others. Does nothing particularly clever, just randomly splits locations into
+    train/val lots of times using the target val fraction, and picks the one that meets the
+    specified constraints and minimizes weighted error, where "error" is defined as the
+    sum of each class's absolute divergence from the target val fraction.
+
+    location_to_category_counts should be a dict mapping location IDs to dicts,
+    with each dict mapping a category name to a count. Any categories not present in a
+    particular dict are assumed to have a count of zero for that location.
+
+    If not None, category_to_max_allowable_error should be a dict mapping category names
+    to maximum allowable errors. These are hard constraints, but you can specify a subset
+    of categories. Categories not included here have a maximum error of Inf.
+
+    If not None, category_to_error_weight should be a dict mapping category names to
+    error weights. You can specify a subset of categories. Categories not included here
+    have a weight of 1.0.
+
+    default_max_allowable_error is the maximum allowable error for categories not present in
+    category_to_max_allowable_error. Set to None (or >= 1.0) to disable hard constraints for
+    categories not present in category_to_max_allowable_error.
+
+    Returns val_locations,category_to_val_fraction.
+    """
+
+    location_ids = list(location_to_category_counts.keys())
+
+    n_val_locations = int(target_val_fraction*len(location_ids))
+
+    if category_to_max_allowable_error is None:
+        category_to_max_allowable_error = {}
+
+    if category_to_error_weight is None:
+        category_to_error_weight = {}
+
+    # Category ID to total count; the total count is used only for printouts
+    category_id_to_count = {}
+    for location_id in location_to_category_counts:
+        for category_id in location_to_category_counts[location_id].keys():
+            if category_id not in category_id_to_count:
+                category_id_to_count[category_id] = 0
+            category_id_to_count[category_id] += \
+                location_to_category_counts[location_id][category_id]
+
+    category_ids = set(category_id_to_count.keys())
+
+    print('Splitting {} categories over {} locations'.format(
+        len(category_ids),len(location_ids)))
+
+    # random_seed = 0
+    def compute_seed_errors(random_seed):
+        """
+        Compute the per-category error for a specific random seed.
+
+        Returns weighted_average_error,category_to_val_fraction.
+        """
+
+        # Randomly split into train/val
+        random.seed(random_seed)
+        val_locations = random.sample(location_ids,k=n_val_locations)
+        val_locations_set = set(val_locations)
+
+        # For each category, measure the % of images that went into the val set
+        category_to_val_fraction = defaultdict(float)
+
+        for category_id in category_ids:
+            category_val_count = 0
+            category_train_count = 0
+            for location_id in location_to_category_counts:
+                if category_id not in location_to_category_counts[location_id]:
+                    location_category_count = 0
+                else:
+                    location_category_count = location_to_category_counts[location_id][category_id]
+                if location_id in val_locations_set:
+                    category_val_count += location_category_count
+                else:
+                    category_train_count += location_category_count
+            category_val_fraction = category_val_count / (category_val_count + category_train_count)
+            category_to_val_fraction[category_id] = category_val_fraction
+
+        # Absolute deviation from the target val fraction for each category
+        category_errors = {}
+        weighted_category_errors = {}
+
+        # category = next(iter(category_to_val_fraction))
+        for category in category_to_val_fraction:
+
+            category_val_fraction = category_to_val_fraction[category]
+
+            category_error = abs(category_val_fraction-target_val_fraction)
+            category_errors[category] = category_error
+
+            category_weight = 1.0
+            if category in category_to_error_weight:
+                category_weight = category_to_error_weight[category]
+            weighted_category_error = category_error * category_weight
+            weighted_category_errors[category] = weighted_category_error
+
+        weighted_average_error = np.mean(list(weighted_category_errors.values()))
+
+        return weighted_average_error,weighted_category_errors,category_to_val_fraction
+
+    # ...def compute_seed_errors(...)
+
+    # This will only include random seeds that satisfy the hard constraints
+    random_seed_to_weighted_average_error = {}
+
+    # random_seed = 0
+    for random_seed in tqdm(range(0,n_random_seeds)):
+
+        weighted_average_error,weighted_category_errors,category_to_val_fraction = \
+            compute_seed_errors(random_seed)
+
+        seed_satisfies_hard_constraints = True
+
+        for category in category_to_val_fraction:
+            if category in category_to_max_allowable_error:
+                max_allowable_error = category_to_max_allowable_error[category]
+            else:
+                if default_max_allowable_error is None:
+                    continue
+                max_allowable_error = default_max_allowable_error
+            val_fraction = category_to_val_fraction[category]
+            category_error = abs(val_fraction - target_val_fraction)
+            if category_error > max_allowable_error:
+                seed_satisfies_hard_constraints = False
+                break
+
+        if seed_satisfies_hard_constraints:
+            random_seed_to_weighted_average_error[random_seed] = weighted_average_error
+
+    # ...for each random seed
+
+    assert len(random_seed_to_weighted_average_error) > 0, \
+        'No random seed met all the hard constraints'
+
+    print('\n{} of {} random seeds satisfied hard constraints'.format(
+        len(random_seed_to_weighted_average_error),n_random_seeds))
+
+    min_error = None
+    min_error_seed = None
+
+    for random_seed in random_seed_to_weighted_average_error.keys():
+        error_metric = random_seed_to_weighted_average_error[random_seed]
+        if min_error is None or error_metric < min_error:
+            min_error = error_metric
+            min_error_seed = random_seed
+
+    random.seed(min_error_seed)
+    val_locations = random.sample(location_ids,k=n_val_locations)
+    train_locations = []
+    for location_id in location_ids:
+        if location_id not in val_locations:
+            train_locations.append(location_id)
+
+    print('\nVal locations:\n')
+    for loc in val_locations:
+        print('{}'.format(loc))
+    print('')
+
+    weighted_average_error,weighted_category_errors,category_to_val_fraction = \
+        compute_seed_errors(min_error_seed)
+
+    random_seed = min_error_seed
+
+    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,reverse=True)
+    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,
+                                                        sort_values=category_id_to_count,
+                                                        reverse=True)
+
+    print('Val fractions by category:\n')
+
+    for category in category_to_val_fraction:
+        print('{} ({}) {:.2f}'.format(
+            category,category_id_to_count[category],
+            category_to_val_fraction[category]))
+
+    return val_locations,category_to_val_fraction
+
+# ...def split_locations_into_train_val(...)
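A hypothetical usage sketch of the new splitter (category names and counts are invented; with 20 locations and target_val_fraction=0.15, three locations go to val):

```python
import random
from md_utils.split_locations_into_train_val import split_locations_into_train_val

# Toy input: 20 locations with per-category image counts; categories
# missing from a location's dict count as zero
random.seed(0)
location_to_category_counts = {
    'loc_{:03d}'.format(i): {'deer': random.randint(50, 500),
                             'bear': random.randint(0, 20)}
    for i in range(20)
}

val_locations, category_to_val_fraction = split_locations_into_train_val(
    location_to_category_counts,
    n_random_seeds=1000,
    target_val_fraction=0.15,
    category_to_max_allowable_error={'bear': 0.1},  # hard constraint on the rarer class
    category_to_error_weight={'bear': 2.0})         # weight the rarer class's error 2x
```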
md_utils/string_utils.py CHANGED
@@ -57,3 +57,13 @@ def human_readable_to_bytes(size):
         bytes = 0
 
     return bytes
+
+
+def remove_ansi_codes(s):
+    """
+    Remove ANSI escape codes from a string.
+
+    https://stackoverflow.com/questions/14693701/how-can-i-remove-the-ansi-escape-sequences-from-a-string-in-python#14693789
+    """
+    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+    return ansi_escape.sub('', s)
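For example (the string contents are made up):

```python
from md_utils.string_utils import remove_ansi_codes

colored = '\x1b[31mERROR:\x1b[0m 3 files failed'
print(remove_ansi_codes(colored))  # 'ERROR: 3 files failed'
```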
md_utils/url_utils.py CHANGED
@@ -140,5 +140,3 @@ def test_urls(urls, error_on_failure=True):
 
     return status_codes
 
-
-