megadetector 5.0.5__py3-none-any.whl → 5.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/batch_processing/data_preparation/manage_local_batch.py +302 -263
- api/batch_processing/data_preparation/manage_video_batch.py +81 -2
- api/batch_processing/postprocessing/add_max_conf.py +1 -0
- api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
- api/batch_processing/postprocessing/compare_batch_results.py +110 -60
- api/batch_processing/postprocessing/load_api_results.py +56 -70
- api/batch_processing/postprocessing/md_to_coco.py +1 -1
- api/batch_processing/postprocessing/md_to_labelme.py +2 -1
- api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
- api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
- api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
- api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
- api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
- api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
- api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
- api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
- classification/prepare_classification_script.py +191 -191
- data_management/coco_to_yolo.py +68 -45
- data_management/databases/integrity_check_json_db.py +7 -5
- data_management/generate_crops_from_cct.py +3 -3
- data_management/get_image_sizes.py +8 -6
- data_management/importers/add_timestamps_to_icct.py +79 -0
- data_management/importers/animl_results_to_md_results.py +160 -0
- data_management/importers/auckland_doc_test_to_json.py +4 -4
- data_management/importers/auckland_doc_to_json.py +1 -1
- data_management/importers/awc_to_json.py +5 -5
- data_management/importers/bellevue_to_json.py +5 -5
- data_management/importers/carrizo_shrubfree_2018.py +5 -5
- data_management/importers/carrizo_trail_cam_2017.py +5 -5
- data_management/importers/cct_field_adjustments.py +2 -3
- data_management/importers/channel_islands_to_cct.py +4 -4
- data_management/importers/ena24_to_json.py +5 -5
- data_management/importers/helena_to_cct.py +10 -10
- data_management/importers/idaho-camera-traps.py +12 -12
- data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
- data_management/importers/jb_csv_to_json.py +4 -4
- data_management/importers/missouri_to_json.py +1 -1
- data_management/importers/noaa_seals_2019.py +1 -1
- data_management/importers/pc_to_json.py +5 -5
- data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
- data_management/importers/prepare_zsl_imerit.py +5 -5
- data_management/importers/rspb_to_json.py +4 -4
- data_management/importers/save_the_elephants_survey_A.py +5 -5
- data_management/importers/save_the_elephants_survey_B.py +6 -6
- data_management/importers/snapshot_safari_importer.py +9 -9
- data_management/importers/snapshot_serengeti_lila.py +9 -9
- data_management/importers/timelapse_csv_set_to_json.py +5 -7
- data_management/importers/ubc_to_json.py +4 -4
- data_management/importers/umn_to_json.py +4 -4
- data_management/importers/wellington_to_json.py +1 -1
- data_management/importers/wi_to_json.py +2 -2
- data_management/importers/zamba_results_to_md_results.py +181 -0
- data_management/labelme_to_coco.py +35 -7
- data_management/labelme_to_yolo.py +229 -0
- data_management/lila/add_locations_to_island_camera_traps.py +1 -1
- data_management/lila/add_locations_to_nacti.py +147 -0
- data_management/lila/create_lila_blank_set.py +474 -0
- data_management/lila/create_lila_test_set.py +2 -1
- data_management/lila/create_links_to_md_results_files.py +106 -0
- data_management/lila/download_lila_subset.py +46 -21
- data_management/lila/generate_lila_per_image_labels.py +23 -14
- data_management/lila/get_lila_annotation_counts.py +17 -11
- data_management/lila/lila_common.py +14 -11
- data_management/lila/test_lila_metadata_urls.py +116 -0
- data_management/ocr_tools.py +829 -0
- data_management/resize_coco_dataset.py +13 -11
- data_management/yolo_output_to_md_output.py +84 -12
- data_management/yolo_to_coco.py +38 -20
- detection/process_video.py +36 -14
- detection/pytorch_detector.py +23 -8
- detection/run_detector.py +76 -19
- detection/run_detector_batch.py +178 -63
- detection/run_inference_with_yolov5_val.py +326 -57
- detection/run_tiled_inference.py +153 -43
- detection/video_utils.py +34 -8
- md_utils/ct_utils.py +172 -1
- md_utils/md_tests.py +372 -51
- md_utils/path_utils.py +167 -39
- md_utils/process_utils.py +26 -7
- md_utils/split_locations_into_train_val.py +215 -0
- md_utils/string_utils.py +10 -0
- md_utils/url_utils.py +0 -2
- md_utils/write_html_image_list.py +9 -26
- md_visualization/plot_utils.py +12 -8
- md_visualization/visualization_utils.py +106 -7
- md_visualization/visualize_db.py +16 -8
- md_visualization/visualize_detector_output.py +208 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
- taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
- taxonomy_mapping/map_new_lila_datasets.py +43 -39
- taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
- taxonomy_mapping/preview_lila_taxonomy.py +27 -27
- taxonomy_mapping/species_lookup.py +33 -13
- taxonomy_mapping/taxonomy_csv_checker.py +7 -5
- api/synchronous/api_core/yolov5/detect.py +0 -252
- api/synchronous/api_core/yolov5/export.py +0 -607
- api/synchronous/api_core/yolov5/hubconf.py +0 -146
- api/synchronous/api_core/yolov5/models/__init__.py +0 -0
- api/synchronous/api_core/yolov5/models/common.py +0 -738
- api/synchronous/api_core/yolov5/models/experimental.py +0 -104
- api/synchronous/api_core/yolov5/models/tf.py +0 -574
- api/synchronous/api_core/yolov5/models/yolo.py +0 -338
- api/synchronous/api_core/yolov5/train.py +0 -670
- api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
- api/synchronous/api_core/yolov5/utils/activations.py +0 -103
- api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
- api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
- api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
- api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
- api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
- api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
- api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
- api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
- api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
- api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
- api/synchronous/api_core/yolov5/utils/general.py +0 -1018
- api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
- api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
- api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
- api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
- api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
- api/synchronous/api_core/yolov5/utils/loss.py +0 -234
- api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
- api/synchronous/api_core/yolov5/utils/plots.py +0 -489
- api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
- api/synchronous/api_core/yolov5/val.py +0 -394
- md_utils/matlab_porting_tools.py +0 -97
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
- {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
md_utils/path_utils.py
CHANGED
@@ -21,7 +21,8 @@ import zipfile
 from zipfile import ZipFile
 from datetime import datetime
 from typing import Container, Iterable, List, Optional, Tuple, Sequence
-from multiprocessing.pool import ThreadPool
+from multiprocessing.pool import Pool, ThreadPool
+from functools import partial
 from tqdm import tqdm
 
 IMG_EXTENSIONS = ('.jpg', '.jpeg', '.gif', '.png', '.tif', '.tiff', '.bmp')

@@ -34,30 +35,51 @@ CHAR_LIMIT = 255
 
 #%% General path functions
 
-def recursive_file_list(base_dir, convert_slashes=True,
-                        return_relative_paths=False):
+def recursive_file_list(base_dir, convert_slashes=True,
+                        return_relative_paths=False, sort_files=True,
+                        recursive=True):
     r"""
     Enumerate files (not directories) in [base_dir], optionally converting
     \ to /
     """
 
     all_files = []
 
-    for root, _, filenames in os.walk(base_dir):
-        for filename in filenames:
-            full_path = os.path.join(root, filename)
-            if convert_slashes:
-                full_path = full_path.replace('\\', '/')
-            all_files.append(full_path)
+    if recursive:
+        for root, _, filenames in os.walk(base_dir):
+            for filename in filenames:
+                full_path = os.path.join(root, filename)
+                all_files.append(full_path)
+    else:
+        all_files_relative = os.listdir(base_dir)
+        all_files = [os.path.join(base_dir,fn) for fn in all_files_relative]
+        all_files = [fn for fn in all_files if os.path.isfile(fn)]
 
     if return_relative_paths:
         all_files = [os.path.relpath(fn,base_dir) for fn in all_files]
+
+    if convert_slashes:
+        all_files = [fn.replace('\\', '/') for fn in all_files]
+
+    if sort_files:
+        all_files = sorted(all_files)
 
-    all_files = sorted(all_files)
     return all_files
 
 
-def split_path(path: str) -> List[str]:
+def file_list(base_dir, convert_slashes=True, return_relative_paths=False, sort_files=True,
+              recursive=False):
+    """
+    Trivial wrapper for recursive_file_list, which was a poor function name choice at the time,
+    it doesn't really make sense to have a "recursive" option in a function called "recursive_file_list".
     """
+
+    return recursive_file_list(base_dir,convert_slashes,return_relative_paths,sort_files,
+                               recursive=recursive)
+
+
+def split_path(path: str) -> List[str]:
+    r"""
     Splits [path] into all its constituent tokens.
 
     Non-recursive version of:
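For orientation, a minimal usage sketch of the reworked enumeration API (the folder path is hypothetical):

    from md_utils import path_utils

    # Recursive enumeration, returning sorted relative paths with '/' separators
    rel_files = path_utils.recursive_file_list('/data/camera-traps',
                                               return_relative_paths=True)

    # The new file_list() wrapper is the same function with recursive=False
    # by default, i.e. it enumerates only the top level of the folder
    top_level_files = path_utils.file_list('/data/camera-traps')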
@@ -87,7 +109,7 @@ def split_path(path: str) -> List[str]:
 
 
 def fileparts(path: str) -> Tuple[str, str, str]:
-    """
+    r"""
     Breaks down a path into the directory path, filename, and extension.
 
     Note that the '.' lives with the extension, and separators are removed.

@@ -185,8 +207,9 @@ def safe_create_link(link_exists,link_new):
     and if it has a different target than link_exists, remove and re-create
     it.
 
-    Errors
-    """
+    Errors if link_new already exists but it's not a link.
+    """
+
     if os.path.exists(link_new) or os.path.islink(link_new):
         assert os.path.islink(link_new)
         if not os.readlink(link_new) == link_exists:

@@ -196,6 +219,23 @@ def safe_create_link(link_exists,link_new):
         os.symlink(link_exists,link_new)
 
 
+def get_file_sizes(base_dir, convert_slashes=True):
+    """
+    Get sizes recursively for all files in base_dir, returning a dict mapping
+    relative filenames to size.
+    """
+
+    relative_filenames = recursive_file_list(base_dir, convert_slashes=convert_slashes,
+                                             return_relative_paths=True)
+
+    fn_to_size = {}
+    for fn_relative in tqdm(relative_filenames):
+        fn_abs = os.path.join(base_dir,fn_relative)
+        fn_to_size[fn_relative] = os.path.getsize(fn_abs)
+
+    return fn_to_size
+
+
 #%% Image-related path functions
 
 def is_image_file(s: str, img_extensions: Container[str] = IMG_EXTENSIONS

@@ -221,10 +261,14 @@ def find_image_strings(strings: Iterable[str]) -> List[str]:
     return [s for s in strings if is_image_file(s)]
 
 
-def find_images(dirname: str, recursive: bool = False,
-                return_relative_paths: bool = False) -> List[str]:
+def find_images(dirname: str, recursive: bool = False,
+                return_relative_paths: bool = False,
+                convert_slashes: bool = False) -> List[str]:
     """
     Finds all files in a directory that look like image file names. Returns
-    absolute paths unless return_relative_paths is set.
+    absolute paths unless return_relative_paths is set. Uses the OS-native
+    path separator unless convert_slahes is set, in which case will always
+    use '/'.
     """
 
     if recursive:

@@ -238,6 +282,10 @@ def find_images(dirname: str, recursive: bool = False, return_relative_paths: bool = False) -> List[str]:
         image_files = [os.path.relpath(fn,dirname) for fn in image_files]
 
     image_files = sorted(image_files)
+
+    if convert_slashes:
+        image_files = [fn.replace('\\', '/') for fn in image_files]
+
     return image_files
 
 

@@ -245,11 +293,11 @@ def find_images(dirname: str, recursive: bool = False, return_relative_paths: bool = False) -> List[str]:
 
 def clean_filename(filename: str, allow_list: str = VALID_FILENAME_CHARS,
                    char_limit: int = CHAR_LIMIT, force_lower: bool = False) -> str:
-    """
+    r"""
     Removes non-ASCII and other invalid filename characters (on any
     reasonable OS) from a filename, then trims to a maximum length.
 
-    Does not allow
+    Does not allow :\/ by default, use clean_path if you want to preserve those.
 
     Adapted from
     https://gist.github.com/wassname/1393c4a57cfcbf03641dbc31886123b8

@@ -294,30 +342,71 @@ def flatten_path(pathname: str, separator_chars: str = SEPARATOR_CHARS) -> str:
 
 #%% Platform-independent way to open files in their associated application
 
-import sys,subprocess
-
-def open_file(filename):
-    if sys.platform == "win32":
-        os.startfile(filename)
-    else:
-        opener = "open" if sys.platform == "darwin" else "xdg-open"
-        subprocess.call([opener, filename])
+import sys,subprocess,platform,re
 
+def environment_is_wsl():
+    """
+    Returns True if we're running in WSL
+    """
+
+    if sys.platform not in ('linux','posix'):
+        return False
+    platform_string = ' '.join(platform.uname()).lower()
+    return 'microsoft' in platform_string and 'wsl' in platform_string
+
+
+def wsl_path_to_windows_path(filename):
+    """
+    Converts a WSL path to a Windows path, or returns None if that's not possible. E.g.
+    converts:
+
+    /mnt/e/a/b/c
+
+    ...to:
+
+    e:\a\b\c
+    """
+
+    result = subprocess.run(['wslpath', '-w', filename], text=True, capture_output=True)
+    if result.returncode != 0:
+        print('Could not convert path {} from WSL to Windows'.format(filename))
+        return None
+    return result.stdout.strip()
+
 
-def unzip_file(input_file, output_folder=None):
-    """
-    Unzip a zipfile to the specified output folder, defaulting to the same location as
-    the input file
-    """
-
-    if output_folder is None:
-        output_folder = os.path.dirname(input_file)
-
-    with zipfile.ZipFile(input_file, 'r') as zf:
-        zf.extractall(output_folder)
+def open_file(filename,attempt_to_open_in_wsl_host=False):
+    """
+    Opens [filename] in the native OS file handler. If attempt_to_open_in_wsl_host
+    is True, and we're in WSL, attempts to open [filename] in Windows.
+    """
+
+    if sys.platform == 'win32':
+
+        os.startfile(filename)
+
+    elif sys.platform == 'darwin':
+
+        opener = 'open'
+        subprocess.call([opener, filename])
+
+    elif attempt_to_open_in_wsl_host and environment_is_wsl():
+
+        windows_path = wsl_path_to_windows_path(filename)
+
+        # Fall back to xdg-open
+        if windows_path is None:
+            subprocess.call(['xdg-open', filename])
+
+        if os.path.isdir(filename):
+            subprocess.run(["explorer.exe", windows_path])
+        else:
+            os.system("cmd.exe /C start %s" % (re.escape(windows_path)))
+
+    else:
+
+        opener = 'xdg-open'
+        subprocess.call([opener, filename])
 
 #%% File list functions
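A brief sketch of the new WSL-aware behavior (the path is hypothetical):

    from md_utils.path_utils import open_file

    # On Windows and macOS this opens the file in its default application, as before.
    # Under WSL, the new flag converts the path with wslpath and opens the file in
    # the Windows host, falling back to xdg-open if the path can't be converted.
    open_file('/mnt/e/results/index.html', attempt_to_open_in_wsl_host=True)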
@@ -393,7 +482,7 @@ def zip_folder(input_folder, output_fn=None, overwrite=False, verbose=False, compresslevel=9):
     relative_filenames = recursive_file_list(input_folder,return_relative_paths=True)
 
     with ZipFile(output_fn,'w',zipfile.ZIP_DEFLATED) as zipf:
-        for input_fn_relative in relative_filenames:
+        for input_fn_relative in tqdm(relative_filenames,disable=(not verbose)):
             input_fn_abs = os.path.join(input_folder,input_fn_relative)
             zipf.write(input_fn_abs,
                        arcname=input_fn_relative,

@@ -403,14 +492,53 @@ def zip_folder(input_folder, output_fn=None, overwrite=False, verbose=False, compresslevel=9):
     return output_fn
 
 
-def parallel_zip_files(input_files,max_workers=16):
+def parallel_zip_files(input_files, max_workers=16, use_threads=True):
     """
     Zip one or more files to separate output files in parallel, leaving the
-    original files in place.
+    original files in place. Each file is zipped to [filename].zip.
     """
 
     n_workers = min(max_workers,len(input_files))
-    pool = ThreadPool(n_workers)
+
+    if use_threads:
+        pool = ThreadPool(n_workers)
+    else:
+        pool = Pool(n_workers)
+
     with tqdm(total=len(input_files)) as pbar:
         for i,_ in enumerate(pool.imap_unordered(zip_file,input_files)):
             pbar.update()
+
+
+def parallel_zip_folders(input_folders, max_workers=16, use_threads=True,
+                         compresslevel=9, overwrite=False):
+    """
+    Zip one or more folders to separate output files in parallel, leaving the
+    original folders in place. Each folder is zipped to [folder_name].zip.
+    """
+
+    n_workers = min(max_workers,len(input_folders))
+
+    if use_threads:
+        pool = ThreadPool(n_workers)
+    else:
+        pool = Pool(n_workers)
+
+    with tqdm(total=len(input_folders)) as pbar:
+        for i,_ in enumerate(pool.imap_unordered(
+                partial(zip_folder,overwrite=overwrite,compresslevel=compresslevel),
+                input_folders)):
+            pbar.update()
+
+
+def unzip_file(input_file, output_folder=None):
+    """
+    Unzip a zipfile to the specified output folder, defaulting to the same location as
+    the input file
+    """
+
+    if output_folder is None:
+        output_folder = os.path.dirname(input_file)
+
+    with zipfile.ZipFile(input_file, 'r') as zf:
+        zf.extractall(output_folder)
md_utils/process_utils.py
CHANGED
@@ -17,14 +17,28 @@ import subprocess
 
 os.environ["PYTHONUNBUFFERED"] = "1"
 
-def execute(cmd):
+def execute(cmd,encoding=None,errors=None,env=None,verbose=False):
     """
     Run [cmd] (a single string) in a shell, yielding each line of output to the caller.
+
+    The "encoding", "errors", and "env" parameters are passed directly to subprocess.Popen().
+
+    "verbose" only impacts output about process management, it is not related to printing
+    output from the child process.
     """
-
+
+    if verbose:
+        if encoding is not None:
+            print('Launching child process with non-default encoding {}'.format(encoding))
+        if errors is not None:
+            print('Launching child process with non-default text error handling {}'.format(errors))
+        if env is not None:
+            print('Launching child process with non-default environment {}'.format(str(env)))
+
     # https://stackoverflow.com/questions/4417546/constantly-print-subprocess-output-while-process-is-running
     popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-                             shell=True, universal_newlines=True)
+                             shell=True, universal_newlines=True, encoding=encoding,
+                             errors=errors, env=env)
     for stdout_line in iter(popen.stdout.readline, ""):
         yield stdout_line
     popen.stdout.close()

@@ -33,22 +47,27 @@ def execute(cmd):
         raise subprocess.CalledProcessError(return_code, cmd)
 
 
-def execute_and_print(cmd,print_output=True):
+def execute_and_print(cmd,print_output=True,encoding=None,errors=None,env=None,verbose=False):
     """
     Run [cmd] (a single string) in a shell, capturing and printing output. Returns
     a dictionary with fields "status" and "output".
+
+    The "encoding", "errors", and "env" parameters are passed directly to subprocess.Popen().
+
+    "verbose" only impacts output about process management, it is not related to printing
+    output from the child process.
     """
 
     to_return = {'status':'unknown','output':''}
-    output=[]
+    output = []
     try:
-        for s in execute(cmd):
+        for s in execute(cmd,encoding=encoding,errors=errors,env=env,verbose=verbose):
             output.append(s)
             if print_output:
                 print(s,end='',flush=True)
         to_return['status'] = 0
     except subprocess.CalledProcessError as cpe:
-        print('execute_and_print caught error: {}'.format(cpe.output))
+        print('execute_and_print caught error: {} ({})'.format(cpe.output,str(cpe)))
         to_return['status'] = cpe.returncode
         to_return['output'] = output
 
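A short sketch of the extended call (the command here is arbitrary):

    from md_utils.process_utils import execute_and_print

    # Output is streamed line by line while the child process runs; the returned
    # dict has fields "status" (the return code) and "output" (the captured lines)
    result = execute_and_print('echo hello', encoding='utf-8', errors='replace',
                               verbose=True)
    assert result['status'] == 0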
md_utils/split_locations_into_train_val.py
ADDED

@@ -0,0 +1,215 @@
+########
+#
+# split_locations_into_train_val.py
+#
+# Split a list of location IDs into training and validation, targeting a specific
+# train/val split for each category, but allowing some categories to be tighter or looser
+# than others. Does nothing particularly clever, just randomly splits locations into
+# train/val lots of times using the target val fraction, and picks the one that meets the
+# specified constraints and minimizes weighted error, where "error" is defined as the
+# sum of each class's absolute divergence from the target val fraction.
+#
+########
+
+#%% Imports/constants
+
+import random
+import numpy as np
+
+from collections import defaultdict
+from md_utils.ct_utils import sort_dictionary_by_value
+from tqdm import tqdm
+
+
+#%% Main function
+
+def split_locations_into_train_val(location_to_category_counts,
+                                   n_random_seeds=10000,
+                                   target_val_fraction=0.15,
+                                   category_to_max_allowable_error=None,
+                                   category_to_error_weight=None,
+                                   default_max_allowable_error=0.1):
+    """
+    Split a list of location IDs into training and validation, targeting a specific
+    train/val split for each category, but allowing some categories to be tighter or looser
+    than others. Does nothing particularly clever, just randomly splits locations into
+    train/val lots of times using the target val fraction, and picks the one that meets the
+    specified constraints and minimizes weighted error, where "error" is defined as the
+    sum of each class's absolute divergence from the target val fraction.
+
+    location_to_category_counts should be a dict mapping location IDs to dicts,
+    with each dict mapping a category name to a count. Any categories not present in a
+    particular dict are assumed to have a count of zero for that location.
+
+    If not None, category_to_max_allowable_error should be a dict mapping category names
+    to maximum allowable errors. These are hard constraints, but you can specify a subset
+    of categories. Categories not included here have a maximum error of Inf.
+
+    If not None, category_to_error_weight should be a dict mapping category names to
+    error weights. You can specify a subset of categories. Categories not included here
+    have a weight of 1.0.
+
+    default_max_allowable_error is the maximum allowable error for categories not present in
+    category_to_max_allowable_error. Set to None (or >= 1.0) to disable hard constraints for
+    categories not present in category_to_max_allowable_error
+
+    returns val_locations,category_to_val_fraction
+
+    """
+
+    location_ids = list(location_to_category_counts.keys())
+
+    n_val_locations = int(target_val_fraction*len(location_ids))
+
+    if category_to_max_allowable_error is None:
+        category_to_max_allowable_error = {}
+
+    if category_to_error_weight is None:
+        category_to_error_weight = {}
+
+    # category ID to total count; the total count is used only for printouts
+    category_id_to_count = {}
+    for location_id in location_to_category_counts:
+        for category_id in location_to_category_counts[location_id].keys():
+            if category_id not in category_id_to_count:
+                category_id_to_count[category_id] = 0
+            category_id_to_count[category_id] += \
+                location_to_category_counts[location_id][category_id]
+
+    category_ids = set(category_id_to_count.keys())
+
+    print('Splitting {} categories over {} locations'.format(
+        len(category_ids),len(location_ids)))
+
+    # random_seed = 0
+    def compute_seed_errors(random_seed):
+        """
+        Compute the per-category error for a specific random seed.
+
+        returns weighted_average_error,category_to_val_fraction
+        """
+
+        # Randomly split into train/val
+        random.seed(random_seed)
+        val_locations = random.sample(location_ids,k=n_val_locations)
+        val_locations_set = set(val_locations)
+
+        # For each category, measure the % of images that went into the val set
+        category_to_val_fraction = defaultdict(float)
+
+        for category_id in category_ids:
+            category_val_count = 0
+            category_train_count = 0
+            for location_id in location_to_category_counts:
+                if category_id not in location_to_category_counts[location_id]:
+                    location_category_count = 0
+                else:
+                    location_category_count = location_to_category_counts[location_id][category_id]
+                if location_id in val_locations_set:
+                    category_val_count += location_category_count
+                else:
+                    category_train_count += location_category_count
+            category_val_fraction = category_val_count / (category_val_count + category_train_count)
+            category_to_val_fraction[category_id] = category_val_fraction
+
+        # Absolute deviation from the target val fraction for each categorys
+        category_errors = {}
+        weighted_category_errors = {}
+
+        # category = next(iter(category_to_val_fraction))
+        for category in category_to_val_fraction:
+
+            category_val_fraction = category_to_val_fraction[category]
+
+            category_error = abs(category_val_fraction-target_val_fraction)
+            category_errors[category] = category_error
+
+            category_weight = 1.0
+            if category in category_to_error_weight:
+                category_weight = category_to_error_weight[category]
+            weighted_category_error = category_error * category_weight
+            weighted_category_errors[category] = weighted_category_error
+
+        weighted_average_error = np.mean(list(weighted_category_errors.values()))
+
+        return weighted_average_error,weighted_category_errors,category_to_val_fraction
+
+    # ... def compute_seed_errors(...)
+
+    # This will only include random seeds that satisfy the hard constraints
+    random_seed_to_weighted_average_error = {}
+
+    # random_seed = 0
+    for random_seed in tqdm(range(0,n_random_seeds)):
+
+        weighted_average_error,weighted_category_errors,category_to_val_fraction = \
+            compute_seed_errors(random_seed)
+
+        seed_satisfies_hard_constraints = True
+
+        for category in category_to_val_fraction:
+            if category in category_to_max_allowable_error:
+                max_allowable_error = category_to_max_allowable_error[category]
+            else:
+                if default_max_allowable_error is None:
+                    continue
+                max_allowable_error = default_max_allowable_error
+            val_fraction = category_to_val_fraction[category]
+            category_error = abs(val_fraction - target_val_fraction)
+            if category_error > max_allowable_error:
+                seed_satisfies_hard_constraints = False
+                break
+
+        if seed_satisfies_hard_constraints:
+            random_seed_to_weighted_average_error[random_seed] = weighted_average_error
+
+    # ...for each random seed
+
+    assert len(random_seed_to_weighted_average_error) > 0, \
+        'No random seed met all the hard constraints'
+
+    print('\n{} of {} random seeds satisfied hard constraints'.format(
+        len(random_seed_to_weighted_average_error),n_random_seeds))
+
+    min_error = None
+    min_error_seed = None
+
+    for random_seed in random_seed_to_weighted_average_error.keys():
+        error_metric = random_seed_to_weighted_average_error[random_seed]
+        if min_error is None or error_metric < min_error:
+            min_error = error_metric
+            min_error_seed = random_seed
+
+    random.seed(min_error_seed)
+    val_locations = random.sample(location_ids,k=n_val_locations)
+    train_locations = []
+    for location_id in location_ids:
+        if location_id not in val_locations:
+            train_locations.append(location_id)
+
+    print('\nVal locations:\n')
+    for loc in val_locations:
+        print('{}'.format(loc))
+    print('')
+
+    weighted_average_error,weighted_category_errors,category_to_val_fraction = \
+        compute_seed_errors(min_error_seed)
+
+    random_seed = min_error_seed
+
+    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,reverse=True)
+    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,
+                                                        sort_values=category_id_to_count,
+                                                        reverse=True)
+
+
+    print('Val fractions by category:\n')
+
+    for category in category_to_val_fraction:
+        print('{} ({}) {:.2f}'.format(
+            category,category_id_to_count[category],
+            category_to_val_fraction[category]))
+
+    return val_locations,category_to_val_fraction
+
+# ...def split_locations_into_train_val(...)
md_utils/string_utils.py
CHANGED
@@ -57,3 +57,13 @@ def human_readable_to_bytes(size):
         bytes = 0
 
     return bytes
+
+
+def remove_ansi_codes(s):
+    """
+    Remove ANSI escape codes from a string.
+
+    https://stackoverflow.com/questions/14693701/how-can-i-remove-the-ansi-escape-sequences-from-a-string-in-python#14693789
+    """
+    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+    return ansi_escape.sub('', s)
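For example (assuming re is already imported at the top of string_utils):

    from md_utils.string_utils import remove_ansi_codes

    colored = '\x1b[31merror:\x1b[0m something failed'
    print(remove_ansi_codes(colored))
    # -> 'error: something failed'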