megadetector 5.0.7__py3-none-any.whl → 5.0.8__py3-none-any.whl

This diff reflects the changes between publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.


Files changed (48)
  1. api/batch_processing/data_preparation/manage_local_batch.py +28 -14
  2. api/batch_processing/postprocessing/combine_api_outputs.py +2 -2
  3. api/batch_processing/postprocessing/compare_batch_results.py +1 -1
  4. api/batch_processing/postprocessing/convert_output_format.py +24 -6
  5. api/batch_processing/postprocessing/load_api_results.py +1 -3
  6. api/batch_processing/postprocessing/md_to_labelme.py +118 -51
  7. api/batch_processing/postprocessing/merge_detections.py +30 -5
  8. api/batch_processing/postprocessing/postprocess_batch_results.py +24 -12
  9. api/batch_processing/postprocessing/remap_detection_categories.py +163 -0
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +15 -12
  11. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +2 -2
  12. data_management/cct_json_utils.py +7 -2
  13. data_management/coco_to_labelme.py +263 -0
  14. data_management/coco_to_yolo.py +7 -4
  15. data_management/databases/integrity_check_json_db.py +68 -59
  16. data_management/databases/subset_json_db.py +1 -1
  17. data_management/get_image_sizes.py +44 -26
  18. data_management/importers/animl_results_to_md_results.py +1 -3
  19. data_management/importers/noaa_seals_2019.py +1 -1
  20. data_management/labelme_to_coco.py +252 -143
  21. data_management/labelme_to_yolo.py +95 -52
  22. data_management/lila/create_lila_blank_set.py +106 -23
  23. data_management/lila/download_lila_subset.py +133 -65
  24. data_management/lila/generate_lila_per_image_labels.py +1 -1
  25. data_management/lila/lila_common.py +8 -38
  26. data_management/read_exif.py +65 -16
  27. data_management/remap_coco_categories.py +84 -0
  28. data_management/resize_coco_dataset.py +3 -22
  29. data_management/wi_download_csv_to_coco.py +239 -0
  30. data_management/yolo_to_coco.py +283 -83
  31. detection/run_detector_batch.py +12 -3
  32. detection/run_inference_with_yolov5_val.py +10 -3
  33. detection/run_tiled_inference.py +2 -2
  34. detection/tf_detector.py +2 -1
  35. detection/video_utils.py +1 -1
  36. md_utils/ct_utils.py +22 -3
  37. md_utils/md_tests.py +11 -2
  38. md_utils/path_utils.py +206 -32
  39. md_utils/url_utils.py +66 -1
  40. md_utils/write_html_image_list.py +12 -3
  41. md_visualization/visualization_utils.py +363 -72
  42. md_visualization/visualize_db.py +33 -10
  43. {megadetector-5.0.7.dist-info → megadetector-5.0.8.dist-info}/METADATA +10 -12
  44. {megadetector-5.0.7.dist-info → megadetector-5.0.8.dist-info}/RECORD +47 -44
  45. {megadetector-5.0.7.dist-info → megadetector-5.0.8.dist-info}/WHEEL +1 -1
  46. md_visualization/visualize_megadb.py +0 -183
  47. {megadetector-5.0.7.dist-info → megadetector-5.0.8.dist-info}/LICENSE +0 -0
  48. {megadetector-5.0.7.dist-info → megadetector-5.0.8.dist-info}/top_level.txt +0 -0
@@ -245,7 +245,8 @@ def process_images(im_files, detector, confidence_threshold, use_image_queue=Fal
  quiet=False, image_size=None, checkpoint_queue=None, include_image_size=False,
  include_image_timestamp=False, include_exif_data=False):
  """
- Runs MegaDetector over a list of image files.
+ Runs MegaDetector over a list of image files. As of 3/2024, this entry point is used when the
+ image queue is enabled, but not in the standard inference path (which loops over process_image()).

  Args
  - im_files: list of str, paths to image files
@@ -269,7 +270,7 @@ def process_images(im_files, detector, confidence_threshold, use_image_queue=Fal
  include_image_size=include_image_size,
  include_image_timestamp=include_image_timestamp,
  include_exif_data=include_exif_data)
- else:
+ else:
  results = []
  for im_file in im_files:
  result = process_image(im_file, detector, confidence_threshold,
@@ -662,7 +663,7 @@ def get_image_datetime(image):

  def write_results_to_file(results, output_file, relative_path_base=None,
  detector_file=None, info=None, include_max_conf=False,
- custom_metadata=None):
+ custom_metadata=None, force_forward_slashes=True):
  """
  Writes list of detection results to JSON output file. Format matches:

@@ -692,6 +693,14 @@ def write_results_to_file(results, output_file, relative_path_base=None,
  results_relative.append(r_relative)
  results = results_relative

+ if force_forward_slashes:
+ results_converted = []
+ for r in results:
+ r_converted = copy.copy(r)
+ r_converted['file'] = r_converted['file'].replace('\\','/')
+ results_converted.append(r_converted)
+ results = results_converted
+
  # The typical case: we need to build the 'info' struct
  if info is None:

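For context on the new force_forward_slashes behavior, here is a minimal usage sketch. It assumes the module is importable as detection.run_detector_batch, per the file list above, and uses a made-up results list and output path, so treat it as illustrative rather than canonical.

    from detection.run_detector_batch import write_results_to_file

    # Hypothetical result entry with a Windows-style path
    results = [{'file': 'camera01\\day1\\IMG_0001.JPG', 'detections': []}]

    # With force_forward_slashes=True (the new default), the 'file' values are
    # written as 'camera01/day1/IMG_0001.JPG' in the output JSON
    write_results_to_file(results, 'md_results.json', detector_file='md_v5a.0.0.pt')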
@@ -105,6 +105,8 @@ class YoloInferenceOptions:
  treat_copy_failures_as_warnings = False

  save_yolo_debug_output = False
+
+ recursive = True


  #%% Main function
@@ -203,7 +205,7 @@ def run_inference_with_yolo_val(options):
  ##%% Enumerate images

  if os.path.isdir(options.input_folder):
- image_files_absolute = path_utils.find_images(options.input_folder,recursive=True)
+ image_files_absolute = path_utils.find_images(options.input_folder,recursive=options.recursive)
  else:
  assert os.path.isfile(options.input_folder)
  with open(options.input_folder,'r') as f:
@@ -381,7 +383,7 @@ def run_inference_with_yolo_val(options):
  # YOLO console output contains lots of ANSI escape codes, remove them for easier parsing
  yolo_console_output = [string_utils.remove_ansi_codes(s) for s in yolo_console_output]

- # Find errors that occrred during the initial corruption check; these will not be included in the
+ # Find errors that occurred during the initial corruption check; these will not be included in the
  # output. Errors that occur during inference will be handled separately.
  yolo_read_failures = []

@@ -518,7 +520,7 @@ def main():
  help='inference batch size (default {})'.format(options.batch_size))
  parser.add_argument(
  '--half_precision_enabled', default=None, type=int,
- help='use half-precision-inference (1 or 0) (default is the underlying model\'s default, probably half for YOLOv8 and full for YOLOv8')
+ help='use half-precision-inference (1 or 0) (default is the underlying model\'s default, probably full for YOLOv8 and half for YOLOv5')
  parser.add_argument(
  '--device_string', default=options.device_string, type=str,
  help='CUDA device specifier, typically "0" or "1" for CUDA devices, "mps" for M1/M2 devices, or "cpu" (default {})'.format(options.device_string))
@@ -553,6 +555,10 @@ def main():
  '--save_yolo_debug_output', action='store_true',
  help='write yolo console output to a text file in the results folder, along with additional debug files')

+ parser.add_argument(
+ '--nonrecursive', action='store_true',
+ help='Disable recursive folder processing')
+
  parser.add_argument(
  '--preview_yolo_command_only', action='store_true',
  help='don\'t run inference, just preview the YOLO inference command (still creates symlinks)')
@@ -592,6 +598,7 @@ def main():
  if args.yolo_dataset_file is not None:
  options.yolo_category_id_to_name = args.yolo_dataset_file

+ options.recursive = (not options.nonrecursive)
  options.remove_symlink_folder = (not options.no_remove_symlink_folder)
  options.remove_yolo_results_folder = (not options.no_remove_yolo_results_folder)
  options.use_symlinks = (not options.no_use_symlinks)
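A sketch of the new recursive option at the API level (the CLI equivalent is the new --nonrecursive flag). Only input_folder and recursive appear in the diff above; model_filename and output_file are assumed attribute names, and the paths are placeholders.

    from detection.run_inference_with_yolov5_val import (
        YoloInferenceOptions, run_inference_with_yolo_val)

    options = YoloInferenceOptions()
    options.input_folder = '/data/camera-traps'         # placeholder folder
    options.model_filename = 'md_v5a.0.0.pt'            # assumed attribute name
    options.output_file = '/data/camera-traps-md.json'  # assumed attribute name
    options.recursive = False                           # new in 5.0.8; previously always recursive

    run_inference_with_yolo_val(options)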
@@ -823,12 +823,12 @@
  '--overwrite_handling',
  type=str,
  default='skip',
- help=('behavior when the targt file exists (skip/overwrite/error) (default skip)'))
+ help=('Behavior when the target file exists (skip/overwrite/error) (default skip)'))
  parser.add_argument(
  '--image_list',
  type=str,
  default=None,
- help=('a .json list of relative filenames (or absolute paths contained within image_folder) to include'))
+ help=('A .json list of relative filenames (or absolute paths contained within image_folder) to include'))

  if len(sys.argv[1:]) == 0:
  parser.print_help()
detection/tf_detector.py CHANGED
@@ -122,7 +122,8 @@ class TFDetector:
122
122
  detection_threshold: confidence above which to include the detection proposal
123
123
 
124
124
  Returns:
125
- A dict with the following fields, see the 'images' key in https://github.com/agentmorris/MegaDetector/tree/master/api/batch_processing#batch-processing-api-output-format
125
+ A dict with the following fields, see the 'images' key in:
126
+ https://github.com/agentmorris/MegaDetector/tree/master/api/batch_processing#batch-processing-api-output-format
126
127
  - 'file' (always present)
127
128
  - 'max_detection_conf'
128
129
  - 'detections', which is a list of detection objects containing keys 'category', 'conf' and 'bbox'
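For reference, a per-image dict in the batch-output format linked above looks roughly like the following (a hypothetical example; coordinates are normalized [x_min, y_min, width, height], and category '1' is 'animal').

    example_image_result = {
        'file': 'camera01/IMG_0001.JPG',
        'max_detection_conf': 0.971,
        'detections': [
            {'category': '1', 'conf': 0.971, 'bbox': [0.412, 0.318, 0.207, 0.261]}
        ]
    }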
detection/video_utils.py CHANGED
@@ -310,7 +310,7 @@ def video_folder_to_frames(input_folder:str, output_folder_base:str,

  class FrameToVideoOptions:

- # zero-indexed
+ # One-indexed, i.e. "1" means "use the confidence value from the highest-confidence frame"
  nth_highest_confidence = 1

  # 'error' or 'skip_with_warning'
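To illustrate the corrected comment, a quick sketch of the one-indexed semantics; this is illustrative logic with made-up values, not the library's implementation.

    # Per-frame confidences for one category in a single video
    frame_confidences = [0.91, 0.84, 0.95, 0.20]

    nth_highest_confidence = 1  # one-indexed: 1 == highest-confidence frame

    # Selecting the nth-highest value: 1 -> 0.95, 2 -> 0.91, etc.
    video_confidence = sorted(frame_confidences, reverse=True)[nth_highest_confidence - 1]
    print(video_confidence)  # 0.95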
md_utils/ct_utils.py CHANGED
@@ -39,9 +39,13 @@ def truncate_float_array(xs, precision=3):

  def truncate_float(x, precision=3):
  """
- Truncates a floating-point value to a specific number of significant digits.
+ Truncates the fractional portion of a floating-point value to a specific number of
+ floating-point digits.

- For example: truncate_float(0.0003214884) --> 0.000321
+ For example:
+
+ truncate_float(0.0003214884) --> 0.000321
+ truncate_float(1.0003214884) --> 1.000321

  This function is primarily used to achieve a certain float representation
  before exporting to JSON.
@@ -58,13 +62,18 @@ def truncate_float(x, precision=3):

  return 0

+ elif (x > 1):
+
+ fractional_component = x - 1.0
+ return 1 + truncate_float(fractional_component)
+
  else:

  # Determine the factor, which shifts the decimal point of x
  # just behind the last significant digit.
  factor = math.pow(10, precision - 1 - math.floor(math.log10(abs(x))))

- # Shift decimal point by multiplicatipon with factor, flooring, and
+ # Shift decimal point by multiplication with factor, flooring, and
  # division by factor.
  return math.floor(x * factor)/factor

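A small usage sketch of the revised behavior, assuming md_utils is a top-level package in this wheel as the file list suggests:

    from md_utils.ct_utils import truncate_float

    # Values below 1 are truncated to [precision] significant digits...
    print(truncate_float(0.0003214884))               # 0.000321
    # ...and as of this release, values above 1 keep their integer part and have
    # the same truncation applied to the fractional part.
    print(truncate_float(1.0003214884))               # 1.000321
    print(truncate_float(0.0003214884, precision=2))  # 0.00032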
@@ -174,6 +183,7 @@ def convert_xywh_to_xyxy(api_bbox):
  Converts an xywh bounding box to an xyxy bounding box.

  Note that this is also different from the TensorFlow Object Detection API coords format.
+
  Args:
  api_bbox: bbox output by the batch processing API [x_min, y_min, width_of_box, height_of_box]

@@ -352,6 +362,15 @@ def split_list_into_n_chunks(L, n, chunk_strategy='greedy'):
  raise ValueError('Invalid chunk strategy: {}'.format(chunk_strategy))


+ def sort_dictionary_by_key(d,reverse=False):
+ """
+ Sorts the dictionary [d] by key.
+ """
+
+ d = dict(sorted(d.items(),reverse=reverse))
+ return d
+
+
  def sort_dictionary_by_value(d,sort_values=None,reverse=False):
  """
  Sorts the dictionary [d] by value. If sort_values is None, uses d.values(),
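A quick sketch of the new helper alongside the existing sort_dictionary_by_value, again assuming the md_utils.ct_utils import path:

    from md_utils.ct_utils import sort_dictionary_by_key, sort_dictionary_by_value

    counts = {'deer': 10, 'bear': 3, 'coyote': 7}

    print(sort_dictionary_by_key(counts))
    # {'bear': 3, 'coyote': 7, 'deer': 10}

    print(sort_dictionary_by_value(counts, reverse=True))
    # {'deer': 10, 'coyote': 7, 'bear': 3}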
md_utils/md_tests.py CHANGED
@@ -86,11 +86,14 @@ def get_expected_results_filename(gpu_is_available):
  return 'md-test-results-{}-{}.json'.format(hw_string,pt_string)


- def download_test_data(options):
+ def download_test_data(options=None):
  """
  Download the test zipfile if necessary, unzip if necessary.
  """
-
+
+ if options is None:
+ options = MDTestOptions()
+
  if options.scratch_dir is None:
  tempdir_base = tempfile.gettempdir()
  scratch_dir = os.path.join(tempdir_base,'md-tests')
@@ -160,6 +163,8 @@ def download_test_data(options):
  options.test_videos = [fn for fn in test_files if os.path.splitext(fn.lower())[1] in ('.mp4','.avi')]
  options.test_videos = [fn for fn in options.test_videos if 'rendered' not in fn]

+ print('Finished unzipping and enumerating test data')
+
  # ...def download_test_data(...)


@@ -840,6 +845,10 @@ def main():
  type=str,
  default=None,
  help='Working directory for CLI tests')
+
+ # token used for linting
+ #
+ # no_arguments_required

  args = parser.parse_args()

md_utils/path_utils.py CHANGED
@@ -12,11 +12,17 @@
  import glob
  import ntpath
  import os
+ import sys
+ import platform
  import posixpath
  import string
  import json
+ import shutil
  import unicodedata
  import zipfile
+ import webbrowser
+ import subprocess
+ import re

  from zipfile import ZipFile
  from datetime import datetime
@@ -43,6 +49,8 @@ def recursive_file_list(base_dir, convert_slashes=True,
  \ to /
  """

+ assert os.path.isdir(base_dir), '{} is not a folder'.format(base_dir)
+
  all_files = []

  if recursive:
@@ -219,23 +227,6 @@ def safe_create_link(link_exists,link_new):
  os.symlink(link_exists,link_new)


- def get_file_sizes(base_dir, convert_slashes=True):
- """
- Get sizes recursively for all files in base_dir, returning a dict mapping
- relative filenames to size.
- """
-
- relative_filenames = recursive_file_list(base_dir, convert_slashes=convert_slashes,
- return_relative_paths=True)
-
- fn_to_size = {}
- for fn_relative in tqdm(relative_filenames):
- fn_abs = os.path.join(base_dir,fn_relative)
- fn_to_size[fn_relative] = os.path.getsize(fn_abs)
-
- return fn_to_size
-
-
  #%% Image-related path functions

  def is_image_file(s: str, img_extensions: Container[str] = IMG_EXTENSIONS
@@ -267,10 +258,12 @@ def find_images(dirname: str, recursive: bool = False,
  """
  Finds all files in a directory that look like image file names. Returns
  absolute paths unless return_relative_paths is set. Uses the OS-native
- path separator unless convert_slahes is set, in which case will always
+ path separator unless convert_slashes is set, in which case will always
  use '/'.
  """

+ assert os.path.isdir(dirname), '{} is not a folder'.format(dirname)
+
  if recursive:
  strings = glob.glob(os.path.join(dirname, '**', '*.*'), recursive=True)
  else:
@@ -342,8 +335,6 @@ def flatten_path(pathname: str, separator_chars: str = SEPARATOR_CHARS) -> str:

  #%% Platform-independent way to open files in their associated application

- import sys,subprocess,platform,re
-
  def environment_is_wsl():
  """
  Returns True if we're running in WSL
@@ -373,13 +364,35 @@ def wsl_path_to_windows_path(filename):
  return None
  return result.stdout.strip()

-
- def open_file(filename,attempt_to_open_in_wsl_host=False):
+
+ def open_file(filename, attempt_to_open_in_wsl_host=False, browser_name=None):
  """
- Opens [filename] in the native OS file handler. If attempt_to_open_in_wsl_host
- is True, and we're in WSL, attempts to open [filename] in Windows.
+ Opens [filename] in the default OS file handler for this file type.
+
+ If attempt_to_open_in_wsl_host is True, and we're in WSL, attempts to open
+ [filename] in the Windows host environment.
+
+ If browser_name is not None, uses the webbrowser module to open the filename
+ in the specified browser; see https://docs.python.org/3/library/webbrowser.html
+ for supported browsers. Falls back to the default file handler if webbrowser.open()
+ fails. In this case, attempt_to_open_in_wsl_host is ignored unless webbrowser.open() fails.
+
+ If browser_name is 'default', use the system default. This is different from the
+ parameter to webbrowser.get(), where None implies the system default.
  """

+ if browser_name is not None:
+ if browser_name == 'chrome':
+ browser_name = 'google-chrome'
+ elif browser_name == 'default':
+ browser_name = None
+ try:
+ result = webbrowser.get(using=browser_name).open(filename)
+ except Exception:
+ result = False
+ if result:
+ return
+
  if sys.platform == 'win32':

  os.startfile(filename)
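A short sketch of the new browser_name parameter; the path is a placeholder and browser availability is environment-dependent, so this just shows the call pattern.

    from md_utils.path_utils import open_file

    # Open an HTML report in Chrome if the webbrowser module can find it;
    # falls back to the OS default handler if webbrowser.open() fails
    open_file('/tmp/md-preview/index.html', browser_name='chrome')

    # 'default' forces the system-default browser via the webbrowser module
    open_file('/tmp/md-preview/index.html', browser_name='default')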
@@ -437,6 +450,107 @@ def read_list_from_file(filename: str) -> List[str]:
  return file_list


+ def _copy_file(input_output_tuple,overwrite=True,verbose=False):
+ assert len(input_output_tuple) == 2
+ source_fn = input_output_tuple[0]
+ target_fn = input_output_tuple[1]
+ if (not overwrite) and (os.path.isfile(target_fn)):
+ if verbose:
+ print('Skipping existing file {}'.format(target_fn))
+ return
+ os.makedirs(os.path.dirname(target_fn),exist_ok=True)
+ shutil.copyfile(source_fn,target_fn)
+
+
+ def parallel_copy_files(input_file_to_output_file, max_workers=16,
+ use_threads=True, overwrite=False, verbose=False):
+ """
+ Copy files from source to target according to the dict input_file_to_output_file.
+ """
+
+ n_workers = min(max_workers,len(input_file_to_output_file))
+
+ # Package the dictionary as a set of 2-tuples
+ input_output_tuples = []
+ for input_fn in input_file_to_output_file:
+ input_output_tuples.append((input_fn,input_file_to_output_file[input_fn]))
+
+ if use_threads:
+ pool = ThreadPool(n_workers)
+ else:
+ pool = Pool(n_workers)
+
+ with tqdm(total=len(input_output_tuples)) as pbar:
+ for i,_ in enumerate(pool.imap_unordered(partial(_copy_file,overwrite=overwrite,verbose=verbose),
+ input_output_tuples)):
+ pbar.update()
+
+ # ...def parallel_copy_files(...)
+
+
+ def get_file_sizes(base_dir, convert_slashes=True):
+ """
+ Get sizes recursively for all files in base_dir, returning a dict mapping
+ relative filenames to size.
+
+ TODO: merge the functionality here with parallel_get_file_sizes, which uses slightly
+ different semantics.
+ """
+
+ relative_filenames = recursive_file_list(base_dir, convert_slashes=convert_slashes,
+ return_relative_paths=True)
+
+ fn_to_size = {}
+ for fn_relative in tqdm(relative_filenames):
+ fn_abs = os.path.join(base_dir,fn_relative)
+ fn_to_size[fn_relative] = os.path.getsize(fn_abs)
+
+ return fn_to_size
+
+
+ def _get_file_size(filename,verbose=False):
+ """
+ Internal function for safely getting the size of a file. Returns a (filename,size)
+ tuple, where size is None if there is an error.
+ """
+
+ try:
+ size = os.path.getsize(filename)
+ except Exception as e:
+ if verbose:
+ print('Error reading file size for {}: {}'.format(filename,str(e)))
+ size = None
+ return (filename,size)
+
+
+ def parallel_get_file_sizes(filenames, max_workers=16,
+ use_threads=True, verbose=False,
+ recursive=True):
+ """
+ Return a dictionary mapping every file in [filenames] to the corresponding file size,
+ or None for errors. If [filenames] is a folder, will enumerate the folder (optionally recursively).
+ """
+
+ n_workers = min(max_workers,len(filenames))
+
+ if isinstance(filenames,str) and os.path.isdir(filenames):
+ filenames = recursive_file_list(filenames,recursive=recursive,return_relative_paths=False)
+
+ if use_threads:
+ pool = ThreadPool(n_workers)
+ else:
+ pool = Pool(n_workers)
+
+ resize_results = list(tqdm(pool.imap(
+ partial(_get_file_size,verbose=verbose),filenames), total=len(filenames)))
+
+ to_return = {}
+ for r in resize_results:
+ to_return[r[0]] = r[1]
+
+ return to_return
+
+
  #%% Zip functions

  def zip_file(input_fn, output_fn=None, overwrite=False, verbose=False, compresslevel=9):
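A usage sketch for the new copy/size helpers; the folder names are placeholders and the md_utils.path_utils import path is assumed from the file list.

    from md_utils.path_utils import parallel_copy_files, parallel_get_file_sizes

    # Copy files according to a source-to-destination mapping;
    # overwrite=False skips files that already exist at the destination
    input_file_to_output_file = {
        '/data/in/a.jpg': '/data/out/a.jpg',
        '/data/in/b.jpg': '/data/out/b.jpg'
    }
    parallel_copy_files(input_file_to_output_file, max_workers=8, overwrite=False)

    # Either a list of files or a folder can be passed; unreadable files map to None
    size_by_file = parallel_get_file_sizes('/data/out', max_workers=8)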
@@ -454,7 +568,7 @@ def zip_file(input_fn, output_fn=None, overwrite=False, verbose=False, compressl
  return

  if verbose:
- print('Zipping {} to {}'.format(input_fn,output_fn))
+ print('Zipping {} to {} with level {}'.format(input_fn,output_fn,compresslevel))

  with ZipFile(output_fn,'w',zipfile.ZIP_DEFLATED) as zipf:
  zipf.write(input_fn,arcname=basename,compresslevel=compresslevel,
@@ -463,9 +577,37 @@
  return output_fn


+ def zip_files_into_single_zipfile(input_files, output_fn, arc_name_base,
+ overwrite=False, verbose=False, compresslevel=9):
+ """
+ Zip all the files in [input_files] into [output_fn]. Archive names are relative to
+ arc_name_base.
+ """
+
+ if not overwrite:
+ if os.path.isfile(output_fn):
+ print('Zip file {} exists, skipping'.format(output_fn))
+ return
+
+ if verbose:
+ print('Zipping {} files to {} (compression level {})'.format(
+ len(input_files),output_fn,compresslevel))
+
+ with ZipFile(output_fn,'w',zipfile.ZIP_DEFLATED) as zipf:
+ for input_fn_abs in tqdm(input_files,disable=(not verbose)):
+ input_fn_relative = os.path.relpath(input_fn_abs,arc_name_base)
+ zipf.write(input_fn_abs,
+ arcname=input_fn_relative,
+ compresslevel=compresslevel,
+ compress_type=zipfile.ZIP_DEFLATED)
+
+ return output_fn
+
+
  def zip_folder(input_folder, output_fn=None, overwrite=False, verbose=False, compresslevel=9):
  """
- Recursively zip everything in [input_folder], storing outputs as relative paths.
+ Recursively zip everything in [input_folder] into a single zipfile, storing outputs as relative
+ paths.

  Defaults to writing to [input_folder].zip
  """
@@ -474,10 +616,13 @@ def zip_folder(input_folder, output_fn=None, overwrite=False, verbose=False, com
  output_fn = input_folder + '.zip'

  if not overwrite:
- assert not os.path.isfile(output_fn), 'Zip file {} exists'.format(output_fn)
+ if os.path.isfile(output_fn):
+ print('Zip file {} exists, skipping'.format(output_fn))
+ return

  if verbose:
- print('Zipping {} to {}'.format(input_folder,output_fn))
+ print('Zipping {} to {} (compression level {})'.format(
+ input_folder,output_fn,compresslevel))

  relative_filenames = recursive_file_list(input_folder,return_relative_paths=True)

@@ -492,7 +637,8 @@ def zip_folder(input_folder, output_fn=None, overwrite=False, verbose=False, com
  return output_fn


- def parallel_zip_files(input_files, max_workers=16, use_threads=True):
+ def parallel_zip_files(input_files, max_workers=16, use_threads=True, compresslevel=9,
+ overwrite=False, verbose=False):
  """
  Zip one or more files to separate output files in parallel, leaving the
  original files in place. Each file is zipped to [filename].zip.
@@ -506,12 +652,14 @@ def parallel_zip_files(input_files, max_workers=16, use_threads=True):
  pool = Pool(n_workers)

  with tqdm(total=len(input_files)) as pbar:
- for i,_ in enumerate(pool.imap_unordered(zip_file,input_files)):
+ for i,_ in enumerate(pool.imap_unordered(partial(zip_file,
+ output_fn=None,overwrite=overwrite,verbose=verbose,compresslevel=compresslevel),
+ input_files)):
  pbar.update()


  def parallel_zip_folders(input_folders, max_workers=16, use_threads=True,
- compresslevel=9, overwrite=False):
+ compresslevel=9, overwrite=False, verbose=False):
  """
  Zip one or more folders to separate output files in parallel, leaving the
  original folders in place. Each folder is zipped to [folder_name].zip.
@@ -526,11 +674,37 @@ def parallel_zip_folders(input_folders, max_workers=16, use_threads=True,

  with tqdm(total=len(input_folders)) as pbar:
  for i,_ in enumerate(pool.imap_unordered(
- partial(zip_folder,overwrite=overwrite,compresslevel=compresslevel),
+ partial(zip_folder,overwrite=overwrite,
+ compresslevel=compresslevel,verbose=verbose),
  input_folders)):
  pbar.update()


+ def zip_each_file_in_folder(folder_name,recursive=False,max_workers=16,use_threads=True,
+ compresslevel=9,overwrite=False,required_token=None,verbose=False,
+ exclude_zip=True):
+ """
+ Zip each file in [folder_name] to its own zipfile (filename.zip), optionally recursing. To zip a whole
+ folder into a single zipfile, use zip_folder().
+
+ If required_token is not None, include only files that contain that token.
+ """
+
+ assert os.path.isdir(folder_name), '{} is not a folder'.format(folder_name)
+
+ input_files = recursive_file_list(folder_name,recursive=recursive,return_relative_paths=False)
+
+ if required_token is not None:
+ input_files = [fn for fn in input_files if required_token in fn]
+
+ if exclude_zip:
+ input_files = [fn for fn in input_files if (not fn.endswith('.zip'))]
+
+ parallel_zip_files(input_files=input_files,max_workers=max_workers,
+ use_threads=use_threads,compresslevel=compresslevel,
+ overwrite=overwrite,verbose=verbose)
+
+
  def unzip_file(input_file, output_folder=None):
  """
  Unzip a zipfile to the specified output folder, defaulting to the same location as
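A sketch of the new and extended zip helpers, using placeholder paths and behavior per the docstrings above:

    from md_utils.path_utils import (zip_folder, zip_each_file_in_folder,
                                     zip_files_into_single_zipfile)

    # One zipfile for a whole folder tree (skips if /data/results.zip already exists)
    zip_folder('/data/results', overwrite=False, verbose=True)

    # One zipfile per matching file, e.g. results.json -> results.json.zip
    zip_each_file_in_folder('/data/results', required_token='.json', overwrite=False)

    # An explicit file list into a single archive, with names relative to /data
    zip_files_into_single_zipfile(['/data/results/a.json', '/data/results/b.json'],
                                  output_fn='/data/results.zip', arc_name_base='/data')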
md_utils/url_utils.py CHANGED
@@ -16,6 +16,7 @@ import requests

  from tqdm import tqdm
  from urllib.parse import urlparse
+ from multiprocessing.pool import ThreadPool

  url_utils_temp_dir = None
  max_path_len = 255
@@ -109,7 +110,14 @@ def download_url(url, destination_filename=None, progress_updater=None,

  def download_relative_filename(url, output_base, verbose=False):
  """
- Download a URL to output_base, preserving relative path
+ Download a URL to output_base, preserving relative path. Path is relative to
+ the site, so:
+
+ https://abc.com/xyz/123.txt
+
+ ...will get downloaded to:
+
+ output_base/xyz/123.txt
  """

  p = urlparse(url)
@@ -119,6 +127,63 @@ def download_relative_filename(url, output_base, verbose=False):
  download_url(url, destination_filename, verbose=verbose)


+ def parallel_download_urls(url_to_target_file,verbose=False,overwrite=False,
+ n_workers=20):
+ """
+ Download a list of URLs to local files. url_to_target_file should
+ be a dict mapping URLs to output files. Catches exceptions and reports
+ them in the returned "results" array.
+ """
+
+ def _do_parallelized_download(download_info,overwrite=False):
+ url = download_info['url']
+ target_file = download_info['target_file']
+ result = {'status':'unknown','url':url,'target_file':target_file}
+
+ if ((os.path.isfile(target_file)) and (not overwrite)):
+ result['status'] = 'skipped'
+ return result
+ try:
+ download_url(url=url,
+ destination_filename=target_file,
+ verbose=verbose, force_download=overwrite)
+ except Exception as e:
+ print('Warning: error downloading URL {}: {}'.format(
+ url,str(e)))
+ result['status'] = 'error: {}'.format(str(e))
+ return result
+
+ result['status'] = 'success'
+ return result
+
+ all_download_info = []
+ for url in url_to_target_file:
+ download_info = {}
+ download_info['url'] = url
+ download_info['target_file'] = url_to_target_file[url]
+ all_download_info.append(download_info)
+
+ print('Downloading {} images on {} workers'.format(
+ len(all_download_info),n_workers))
+
+ if n_workers <= 1:
+
+ results = []
+
+ for download_info in tqdm(all_download_info):
+ result = _do_parallelized_download(download_info,overwrite=overwrite)
+ results.append(result)
+
+ else:
+
+ pool = ThreadPool(n_workers)
+ results = list(tqdm(pool.imap(lambda download_info: _do_parallelized_download(
+ download_info,overwrite=overwrite),all_download_info),
+ total=len(all_download_info)))
+
+ return results
+
+
  def test_urls(urls, error_on_failure=True):
  """
  Verify that a list of URLs is available (returns status 200). By default,
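Finally, a usage sketch for the new parallel downloader; the URLs and paths are placeholders, and the md_utils.url_utils import path is assumed per the file list.

    from md_utils.url_utils import parallel_download_urls

    url_to_target_file = {
        'https://example.com/images/0001.jpg': '/data/downloads/0001.jpg',
        'https://example.com/images/0002.jpg': '/data/downloads/0002.jpg'
    }

    # Each result dict reports 'success', 'skipped' (file exists and overwrite=False),
    # or 'error: ...' for that URL
    results = parallel_download_urls(url_to_target_file, n_workers=10, overwrite=False)
    print(sum(r['status'] == 'success' for r in results))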