megadetector-5.0.11-py3-none-any.whl → megadetector-5.0.12-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

This version of megadetector has been flagged as a potentially problematic release.
Files changed (201)
  1. megadetector/api/__init__.py +0 -0
  2. megadetector/api/batch_processing/__init__.py +0 -0
  3. megadetector/api/batch_processing/api_core/__init__.py +0 -0
  4. megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
  5. megadetector/api/batch_processing/api_core/batch_service/score.py +439 -0
  6. megadetector/api/batch_processing/api_core/server.py +294 -0
  7. megadetector/api/batch_processing/api_core/server_api_config.py +98 -0
  8. megadetector/api/batch_processing/api_core/server_app_config.py +55 -0
  9. megadetector/api/batch_processing/api_core/server_batch_job_manager.py +220 -0
  10. megadetector/api/batch_processing/api_core/server_job_status_table.py +152 -0
  11. megadetector/api/batch_processing/api_core/server_orchestration.py +360 -0
  12. megadetector/api/batch_processing/api_core/server_utils.py +92 -0
  13. megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
  14. megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +46 -0
  15. megadetector/api/batch_processing/api_support/__init__.py +0 -0
  16. megadetector/api/batch_processing/api_support/summarize_daily_activity.py +152 -0
  17. megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
  18. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  19. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  20. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  21. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +126 -0
  22. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  23. megadetector/api/synchronous/__init__.py +0 -0
  24. megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  25. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +152 -0
  26. megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +266 -0
  27. megadetector/api/synchronous/api_core/animal_detection_api/config.py +35 -0
  28. megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
  29. megadetector/api/synchronous/api_core/tests/load_test.py +110 -0
  30. megadetector/classification/__init__.py +0 -0
  31. megadetector/classification/aggregate_classifier_probs.py +108 -0
  32. megadetector/classification/analyze_failed_images.py +227 -0
  33. megadetector/classification/cache_batchapi_outputs.py +198 -0
  34. megadetector/classification/create_classification_dataset.py +627 -0
  35. megadetector/classification/crop_detections.py +516 -0
  36. megadetector/classification/csv_to_json.py +226 -0
  37. megadetector/classification/detect_and_crop.py +855 -0
  38. megadetector/classification/efficientnet/__init__.py +9 -0
  39. megadetector/classification/efficientnet/model.py +415 -0
  40. megadetector/classification/efficientnet/utils.py +610 -0
  41. megadetector/classification/evaluate_model.py +520 -0
  42. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  43. megadetector/classification/json_to_azcopy_list.py +63 -0
  44. megadetector/classification/json_validator.py +699 -0
  45. megadetector/classification/map_classification_categories.py +276 -0
  46. megadetector/classification/merge_classification_detection_output.py +506 -0
  47. megadetector/classification/prepare_classification_script.py +194 -0
  48. megadetector/classification/prepare_classification_script_mc.py +228 -0
  49. megadetector/classification/run_classifier.py +287 -0
  50. megadetector/classification/save_mislabeled.py +110 -0
  51. megadetector/classification/train_classifier.py +827 -0
  52. megadetector/classification/train_classifier_tf.py +725 -0
  53. megadetector/classification/train_utils.py +323 -0
  54. megadetector/data_management/__init__.py +0 -0
  55. megadetector/data_management/annotations/__init__.py +0 -0
  56. megadetector/data_management/annotations/annotation_constants.py +34 -0
  57. megadetector/data_management/camtrap_dp_to_coco.py +239 -0
  58. megadetector/data_management/cct_json_utils.py +395 -0
  59. megadetector/data_management/cct_to_md.py +176 -0
  60. megadetector/data_management/cct_to_wi.py +289 -0
  61. megadetector/data_management/coco_to_labelme.py +272 -0
  62. megadetector/data_management/coco_to_yolo.py +662 -0
  63. megadetector/data_management/databases/__init__.py +0 -0
  64. megadetector/data_management/databases/add_width_and_height_to_db.py +33 -0
  65. megadetector/data_management/databases/combine_coco_camera_traps_files.py +206 -0
  66. megadetector/data_management/databases/integrity_check_json_db.py +477 -0
  67. megadetector/data_management/databases/subset_json_db.py +115 -0
  68. megadetector/data_management/generate_crops_from_cct.py +149 -0
  69. megadetector/data_management/get_image_sizes.py +189 -0
  70. megadetector/data_management/importers/add_nacti_sizes.py +52 -0
  71. megadetector/data_management/importers/add_timestamps_to_icct.py +79 -0
  72. megadetector/data_management/importers/animl_results_to_md_results.py +158 -0
  73. megadetector/data_management/importers/auckland_doc_test_to_json.py +373 -0
  74. megadetector/data_management/importers/auckland_doc_to_json.py +201 -0
  75. megadetector/data_management/importers/awc_to_json.py +191 -0
  76. megadetector/data_management/importers/bellevue_to_json.py +273 -0
  77. megadetector/data_management/importers/cacophony-thermal-importer.py +796 -0
  78. megadetector/data_management/importers/carrizo_shrubfree_2018.py +269 -0
  79. megadetector/data_management/importers/carrizo_trail_cam_2017.py +289 -0
  80. megadetector/data_management/importers/cct_field_adjustments.py +58 -0
  81. megadetector/data_management/importers/channel_islands_to_cct.py +913 -0
  82. megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +180 -0
  83. megadetector/data_management/importers/eMammal/eMammal_helpers.py +249 -0
  84. megadetector/data_management/importers/eMammal/make_eMammal_json.py +223 -0
  85. megadetector/data_management/importers/ena24_to_json.py +276 -0
  86. megadetector/data_management/importers/filenames_to_json.py +386 -0
  87. megadetector/data_management/importers/helena_to_cct.py +283 -0
  88. megadetector/data_management/importers/idaho-camera-traps.py +1407 -0
  89. megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +294 -0
  90. megadetector/data_management/importers/jb_csv_to_json.py +150 -0
  91. megadetector/data_management/importers/mcgill_to_json.py +250 -0
  92. megadetector/data_management/importers/missouri_to_json.py +490 -0
  93. megadetector/data_management/importers/nacti_fieldname_adjustments.py +79 -0
  94. megadetector/data_management/importers/noaa_seals_2019.py +181 -0
  95. megadetector/data_management/importers/pc_to_json.py +365 -0
  96. megadetector/data_management/importers/plot_wni_giraffes.py +123 -0
  97. megadetector/data_management/importers/prepare-noaa-fish-data-for-lila.py +359 -0
  98. megadetector/data_management/importers/prepare_zsl_imerit.py +131 -0
  99. megadetector/data_management/importers/rspb_to_json.py +356 -0
  100. megadetector/data_management/importers/save_the_elephants_survey_A.py +320 -0
  101. megadetector/data_management/importers/save_the_elephants_survey_B.py +329 -0
  102. megadetector/data_management/importers/snapshot_safari_importer.py +758 -0
  103. megadetector/data_management/importers/snapshot_safari_importer_reprise.py +665 -0
  104. megadetector/data_management/importers/snapshot_serengeti_lila.py +1067 -0
  105. megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +150 -0
  106. megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +153 -0
  107. megadetector/data_management/importers/sulross_get_exif.py +65 -0
  108. megadetector/data_management/importers/timelapse_csv_set_to_json.py +490 -0
  109. megadetector/data_management/importers/ubc_to_json.py +399 -0
  110. megadetector/data_management/importers/umn_to_json.py +507 -0
  111. megadetector/data_management/importers/wellington_to_json.py +263 -0
  112. megadetector/data_management/importers/wi_to_json.py +442 -0
  113. megadetector/data_management/importers/zamba_results_to_md_results.py +181 -0
  114. megadetector/data_management/labelme_to_coco.py +547 -0
  115. megadetector/data_management/labelme_to_yolo.py +272 -0
  116. megadetector/data_management/lila/__init__.py +0 -0
  117. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +97 -0
  118. megadetector/data_management/lila/add_locations_to_nacti.py +147 -0
  119. megadetector/data_management/lila/create_lila_blank_set.py +558 -0
  120. megadetector/data_management/lila/create_lila_test_set.py +152 -0
  121. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  122. megadetector/data_management/lila/download_lila_subset.py +178 -0
  123. megadetector/data_management/lila/generate_lila_per_image_labels.py +516 -0
  124. megadetector/data_management/lila/get_lila_annotation_counts.py +170 -0
  125. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  126. megadetector/data_management/lila/lila_common.py +300 -0
  127. megadetector/data_management/lila/test_lila_metadata_urls.py +132 -0
  128. megadetector/data_management/ocr_tools.py +874 -0
  129. megadetector/data_management/read_exif.py +681 -0
  130. megadetector/data_management/remap_coco_categories.py +84 -0
  131. megadetector/data_management/remove_exif.py +66 -0
  132. megadetector/data_management/resize_coco_dataset.py +189 -0
  133. megadetector/data_management/wi_download_csv_to_coco.py +246 -0
  134. megadetector/data_management/yolo_output_to_md_output.py +441 -0
  135. megadetector/data_management/yolo_to_coco.py +676 -0
  136. megadetector/detection/__init__.py +0 -0
  137. megadetector/detection/detector_training/__init__.py +0 -0
  138. megadetector/detection/detector_training/model_main_tf2.py +114 -0
  139. megadetector/detection/process_video.py +702 -0
  140. megadetector/detection/pytorch_detector.py +341 -0
  141. megadetector/detection/run_detector.py +779 -0
  142. megadetector/detection/run_detector_batch.py +1219 -0
  143. megadetector/detection/run_inference_with_yolov5_val.py +917 -0
  144. megadetector/detection/run_tiled_inference.py +934 -0
  145. megadetector/detection/tf_detector.py +189 -0
  146. megadetector/detection/video_utils.py +606 -0
  147. megadetector/postprocessing/__init__.py +0 -0
  148. megadetector/postprocessing/add_max_conf.py +64 -0
  149. megadetector/postprocessing/categorize_detections_by_size.py +163 -0
  150. megadetector/postprocessing/combine_api_outputs.py +249 -0
  151. megadetector/postprocessing/compare_batch_results.py +958 -0
  152. megadetector/postprocessing/convert_output_format.py +396 -0
  153. megadetector/postprocessing/load_api_results.py +195 -0
  154. megadetector/postprocessing/md_to_coco.py +310 -0
  155. megadetector/postprocessing/md_to_labelme.py +330 -0
  156. megadetector/postprocessing/merge_detections.py +401 -0
  157. megadetector/postprocessing/postprocess_batch_results.py +1902 -0
  158. megadetector/postprocessing/remap_detection_categories.py +170 -0
  159. megadetector/postprocessing/render_detection_confusion_matrix.py +660 -0
  160. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +211 -0
  161. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +83 -0
  162. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1631 -0
  163. megadetector/postprocessing/separate_detections_into_folders.py +730 -0
  164. megadetector/postprocessing/subset_json_detector_output.py +696 -0
  165. megadetector/postprocessing/top_folders_to_bottom.py +223 -0
  166. megadetector/taxonomy_mapping/__init__.py +0 -0
  167. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  168. megadetector/taxonomy_mapping/map_new_lila_datasets.py +150 -0
  169. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +142 -0
  170. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +590 -0
  171. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  172. megadetector/taxonomy_mapping/simple_image_download.py +219 -0
  173. megadetector/taxonomy_mapping/species_lookup.py +834 -0
  174. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  175. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  176. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  177. megadetector/utils/__init__.py +0 -0
  178. megadetector/utils/azure_utils.py +178 -0
  179. megadetector/utils/ct_utils.py +612 -0
  180. megadetector/utils/directory_listing.py +246 -0
  181. megadetector/utils/md_tests.py +968 -0
  182. megadetector/utils/path_utils.py +1044 -0
  183. megadetector/utils/process_utils.py +157 -0
  184. megadetector/utils/sas_blob_utils.py +509 -0
  185. megadetector/utils/split_locations_into_train_val.py +228 -0
  186. megadetector/utils/string_utils.py +92 -0
  187. megadetector/utils/url_utils.py +323 -0
  188. megadetector/utils/write_html_image_list.py +225 -0
  189. megadetector/visualization/__init__.py +0 -0
  190. megadetector/visualization/plot_utils.py +293 -0
  191. megadetector/visualization/render_images_with_thumbnails.py +275 -0
  192. megadetector/visualization/visualization_utils.py +1536 -0
  193. megadetector/visualization/visualize_db.py +550 -0
  194. megadetector/visualization/visualize_detector_output.py +405 -0
  195. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/METADATA +1 -1
  196. megadetector-5.0.12.dist-info/RECORD +199 -0
  197. megadetector-5.0.12.dist-info/top_level.txt +1 -0
  198. megadetector-5.0.11.dist-info/RECORD +0 -5
  199. megadetector-5.0.11.dist-info/top_level.txt +0 -1
  200. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/LICENSE +0 -0
  201. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/WHEEL +0 -0
megadetector/utils/split_locations_into_train_val.py
@@ -0,0 +1,228 @@

"""

split_locations_into_train_val.py

Splits a list of location IDs into training and validation, targeting a specific
train/val split for each category, but allowing some categories to be tighter or looser
than others. Does nothing particularly clever, just randomly splits locations into
train/val lots of times using the target val fraction, and picks the one that meets the
specified constraints and minimizes weighted error, where "error" is defined as the
sum of each class's absolute divergence from the target val fraction.

"""

#%% Imports/constants

import random
import numpy as np

from collections import defaultdict
from megadetector.utils.ct_utils import sort_dictionary_by_value
from tqdm import tqdm


#%% Main function

def split_locations_into_train_val(location_to_category_counts,
                                   n_random_seeds=10000,
                                   target_val_fraction=0.15,
                                   category_to_max_allowable_error=None,
                                   category_to_error_weight=None,
                                   default_max_allowable_error=0.1):
    """
    Splits a list of location IDs into training and validation, targeting a specific
    train/val split for each category, but allowing some categories to be tighter or
    looser than others. Does nothing particularly clever, just randomly splits locations
    into train/val lots of times using the target val fraction, and picks the one that
    meets the specified constraints and minimizes weighted error, where "error" is
    defined as the sum of each class's absolute divergence from the target val fraction.

    Args:
        location_to_category_counts (dict): a dict mapping location IDs to dicts,
            with each dict mapping a category name to a count. Any categories not
            present in a particular dict are assumed to have a count of zero for
            that location.

            For example:

            .. code-block:: none

                {'location-000': {'bear':4,'wolf':10},
                 'location-001': {'bear':12,'elk':20}}

        n_random_seeds (int, optional): number of random seeds to try, always
            starting from zero
        target_val_fraction (float, optional): fraction of images containing each
            species we'd like to put in the val split
        category_to_max_allowable_error (dict, optional): a dict mapping category names
            to maximum allowable errors. These are hard constraints (i.e., we will error
            if we can't meet them). Does not need to include all categories; categories
            not included will be assigned a maximum error according to
            [default_max_allowable_error]. If this is None, no hard constraints are
            applied.
        category_to_error_weight (dict, optional): a dict mapping category names to
            error weights. You can specify a subset of categories; categories not
            included here have a weight of 1.0. If None, all categories have the
            same weight.
        default_max_allowable_error (float, optional): the maximum allowable error for
            categories not present in [category_to_max_allowable_error]. Set to None
            (or >= 1.0) to disable hard constraints for categories not present in
            [category_to_max_allowable_error].

    Returns:
        tuple: a two-element tuple:
            - list of location IDs in the val split
            - a dict mapping category names to the fraction of images in the val split
    """

    location_ids = list(location_to_category_counts.keys())

    n_val_locations = int(target_val_fraction*len(location_ids))

    if category_to_max_allowable_error is None:
        category_to_max_allowable_error = {}

    if category_to_error_weight is None:
        category_to_error_weight = {}

    # Category ID to total count; the total count is used only for printouts
    category_id_to_count = {}
    for location_id in location_to_category_counts:
        for category_id in location_to_category_counts[location_id].keys():
            if category_id not in category_id_to_count:
                category_id_to_count[category_id] = 0
            category_id_to_count[category_id] += \
                location_to_category_counts[location_id][category_id]

    category_ids = set(category_id_to_count.keys())

    print('Splitting {} categories over {} locations'.format(
        len(category_ids),len(location_ids)))

    # random_seed = 0
    def compute_seed_errors(random_seed):
        """
        Computes the per-category error for a specific random seed.

        Returns weighted_average_error, weighted_category_errors, category_to_val_fraction.
        """

        # Randomly split into train/val
        random.seed(random_seed)
        val_locations = random.sample(location_ids,k=n_val_locations)
        val_locations_set = set(val_locations)

        # For each category, measure the fraction of images that went into the val set
        category_to_val_fraction = defaultdict(float)

        for category_id in category_ids:
            category_val_count = 0
            category_train_count = 0
            for location_id in location_to_category_counts:
                if category_id not in location_to_category_counts[location_id]:
                    location_category_count = 0
                else:
                    location_category_count = location_to_category_counts[location_id][category_id]
                if location_id in val_locations_set:
                    category_val_count += location_category_count
                else:
                    category_train_count += location_category_count
            category_val_fraction = category_val_count / (category_val_count + category_train_count)
            category_to_val_fraction[category_id] = category_val_fraction

        # Absolute deviation from the target val fraction for each category
        category_errors = {}
        weighted_category_errors = {}

        # category = next(iter(category_to_val_fraction))
        for category in category_to_val_fraction:

            category_val_fraction = category_to_val_fraction[category]

            category_error = abs(category_val_fraction-target_val_fraction)
            category_errors[category] = category_error

            category_weight = 1.0
            if category in category_to_error_weight:
                category_weight = category_to_error_weight[category]
            weighted_category_error = category_error * category_weight
            weighted_category_errors[category] = weighted_category_error

        weighted_average_error = np.mean(list(weighted_category_errors.values()))

        return weighted_average_error,weighted_category_errors,category_to_val_fraction

    # ...def compute_seed_errors(...)

    # This will only include random seeds that satisfy the hard constraints
    random_seed_to_weighted_average_error = {}

    # random_seed = 0
    for random_seed in tqdm(range(0,n_random_seeds)):

        weighted_average_error,weighted_category_errors,category_to_val_fraction = \
            compute_seed_errors(random_seed)

        seed_satisfies_hard_constraints = True

        for category in category_to_val_fraction:
            if category in category_to_max_allowable_error:
                max_allowable_error = category_to_max_allowable_error[category]
            else:
                if default_max_allowable_error is None:
                    continue
                max_allowable_error = default_max_allowable_error
            val_fraction = category_to_val_fraction[category]
            category_error = abs(val_fraction - target_val_fraction)
            if category_error > max_allowable_error:
                seed_satisfies_hard_constraints = False
                break

        if seed_satisfies_hard_constraints:
            random_seed_to_weighted_average_error[random_seed] = weighted_average_error

    # ...for each random seed

    assert len(random_seed_to_weighted_average_error) > 0, \
        'No random seed met all the hard constraints'

    print('\n{} of {} random seeds satisfied hard constraints'.format(
        len(random_seed_to_weighted_average_error),n_random_seeds))

    min_error = None
    min_error_seed = None

    for random_seed in random_seed_to_weighted_average_error.keys():
        error_metric = random_seed_to_weighted_average_error[random_seed]
        if min_error is None or error_metric < min_error:
            min_error = error_metric
            min_error_seed = random_seed

    random.seed(min_error_seed)
    val_locations = random.sample(location_ids,k=n_val_locations)
    train_locations = []
    for location_id in location_ids:
        if location_id not in val_locations:
            train_locations.append(location_id)

    print('\nVal locations:\n')
    for loc in val_locations:
        print('{}'.format(loc))
    print('')

    weighted_average_error,weighted_category_errors,category_to_val_fraction = \
        compute_seed_errors(min_error_seed)

    random_seed = min_error_seed

    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,reverse=True)
    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,
                                                        sort_values=category_id_to_count,
                                                        reverse=True)

    print('Val fractions by category:\n')

    for category in category_to_val_fraction:
        print('{} ({}) {:.2f}'.format(
            category,category_id_to_count[category],
            category_to_val_fraction[category]))

    return val_locations,category_to_val_fraction

# ...def split_locations_into_train_val(...)
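
For reference, here is a minimal usage sketch of the function above; the location IDs, category counts, and parameter choices are made up for illustration. With only a handful of locations, per-category hard constraints are unlikely to be satisfiable, so this sketch disables them via default_max_allowable_error=None:

    from megadetector.utils.split_locations_into_train_val import \
        split_locations_into_train_val

    # Hypothetical per-location category counts
    location_to_category_counts = {
        'location-000': {'bear': 4, 'wolf': 10},
        'location-001': {'bear': 12, 'elk': 20},
        'location-002': {'wolf': 7, 'elk': 3},
        'location-003': {'bear': 5, 'wolf': 2, 'elk': 9}
    }

    # Weight errors on 'wolf' twice as heavily as on other categories, and
    # disable the hard constraints (int(0.25 * 4) = 1 location goes to val)
    val_locations, category_to_val_fraction = split_locations_into_train_val(
        location_to_category_counts,
        n_random_seeds=100,
        target_val_fraction=0.25,
        category_to_error_weight={'wolf': 2.0},
        default_max_allowable_error=None)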
megadetector/utils/string_utils.py
@@ -0,0 +1,92 @@

"""

string_utils.py

Miscellaneous string utilities.

"""

#%% Imports

import re


#%% Functions

def is_float(s):
    """
    Checks whether [s] is an object (typically a string) that can be cast to a float.

    Args:
        s (object): object to evaluate

    Returns:
        bool: True if s successfully casts to a float, otherwise False
    """

    try:
        _ = float(s)
    except ValueError:
        return False
    return True


def human_readable_to_bytes(size):
    """
    Given a human-readable byte string (e.g. 2G, 10GB, 30MB, 20KB),
    returns the number of bytes. Will return 0 if the argument has
    unexpected form.

    https://gist.github.com/beugley/ccd69945346759eb6142272a6d69b4e0

    Args:
        size (str): string representing a size

    Returns:
        int or float: the corresponding size in bytes
    """

    size = re.sub(r'\s+', '', size)

    if (size[-1] == 'B'):
        size = size[:-1]

    if (size.isdigit()):
        bytes = int(size)
    elif (is_float(size)):
        bytes = float(size)
    else:
        bytes = size[:-1]
        unit = size[-1]
        try:
            bytes = float(bytes)
            if (unit == 'T'):
                bytes *= 1024*1024*1024*1024
            elif (unit == 'G'):
                bytes *= 1024*1024*1024
            elif (unit == 'M'):
                bytes *= 1024*1024
            elif (unit == 'K'):
                bytes *= 1024
            else:
                bytes = 0
        except ValueError:
            bytes = 0

    return bytes


def remove_ansi_codes(s):
    """
    Removes ANSI escape codes from a string.

    https://stackoverflow.com/questions/14693701/how-can-i-remove-the-ansi-escape-sequences-from-a-string-in-python#14693789

    Args:
        s (str): the string to de-ANSI-i-fy

    Returns:
        str: a copy of [s] without ANSI codes
    """

    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
    return ansi_escape.sub('', s)
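
A few illustrative calls follow; the values are arbitrary, chosen to show the unit handling and the fallback-to-zero behavior:

    from megadetector.utils.string_utils import \
        is_float, human_readable_to_bytes, remove_ansi_codes

    assert is_float('3.14')
    assert not is_float('3.14.15')

    # A trailing 'B' is optional, units are powers of 1024, and
    # unrecognized strings fall back to 0
    assert human_readable_to_bytes('2GB') == 2 * 1024 * 1024 * 1024
    assert human_readable_to_bytes('20KB') == 20 * 1024
    assert human_readable_to_bytes('banana') == 0

    # Strip terminal color codes from captured console output
    assert remove_ansi_codes('\x1b[31mred\x1b[0m') == 'red'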
megadetector/utils/url_utils.py
@@ -0,0 +1,323 @@

"""

url_utils.py

Frequently-used functions for downloading or manipulating URLs.

"""

#%% Imports and constants

import os
import re
import urllib.request
import tempfile
import requests

from functools import partial
from tqdm import tqdm
from urllib.parse import urlparse
from multiprocessing.pool import ThreadPool
from multiprocessing.pool import Pool

url_utils_temp_dir = None
max_path_len = 255


#%% Download functions

class DownloadProgressBar():
    """
    Progress updater based on the progressbar2 package.

    https://stackoverflow.com/questions/37748105/how-to-use-progressbar-module-with-urlretrieve
    """

    def __init__(self):
        self.pbar = None

    def __call__(self, block_num, block_size, total_size):
        if not self.pbar:
            # This is a pretty random import I'd rather not depend on outside of the
            # rare case where it's used, so importing locally.
            #
            # pip install progressbar2
            import progressbar
            self.pbar = progressbar.ProgressBar(max_value=total_size)
            self.pbar.start()

        downloaded = block_num * block_size
        if downloaded < total_size:
            self.pbar.update(downloaded)
        else:
            self.pbar.finish()


def get_temp_folder(preferred_name='url_utils'):
    """
    Gets a temporary folder for use within this module.

    Args:
        preferred_name (str, optional): subfolder to use within the system temp folder

    Returns:
        str: the full path to the temporary subfolder
    """

    global url_utils_temp_dir

    if url_utils_temp_dir is None:
        url_utils_temp_dir = os.path.join(tempfile.gettempdir(),preferred_name)
        os.makedirs(url_utils_temp_dir,exist_ok=True)

    return url_utils_temp_dir


def download_url(url,
                 destination_filename=None,
                 progress_updater=None,
                 force_download=False,
                 verbose=True):
    """
    Downloads a URL to a file. If no file is specified, creates a temporary file,
    making a best effort to avoid filename collisions.

    Prints some diagnostic information and makes sure to omit SAS tokens from printouts.

    Args:
        url (str): the URL to download
        destination_filename (str, optional): the target filename; if None, will create
            a file in system temp space
        progress_updater (object or bool, optional): can be None, False, True, or a
            specific callable object. If None or False, no progress updates will be
            displayed. If True, a default progress bar will be created.
        force_download (bool, optional): download this file even if [destination_filename]
            exists
        verbose (bool, optional): enable additional debug console output

    Returns:
        str: the filename to which [url] was downloaded; the same as
        [destination_filename] if [destination_filename] was not None
    """

    if progress_updater is not None and isinstance(progress_updater,bool):
        if not progress_updater:
            progress_updater = None
        else:
            progress_updater = DownloadProgressBar()

    url_no_sas = url.split('?')[0]

    if destination_filename is None:
        target_folder = get_temp_folder()
        url_without_sas = url.split('?', 1)[0]

        # This does not guarantee uniqueness, hence "best effort" rather than a guarantee
        url_as_filename = re.sub(r'\W+', '', url_without_sas)
        n_folder_chars = len(url_utils_temp_dir)
        if len(url_as_filename) + n_folder_chars > max_path_len:
            print('Warning: truncating filename target to {} characters'.format(max_path_len))
            url_as_filename = url_as_filename[-1*(max_path_len-n_folder_chars):]
        destination_filename = \
            os.path.join(target_folder,url_as_filename)

    if (not force_download) and (os.path.isfile(destination_filename)):
        if verbose:
            print('Bypassing download of already-downloaded file {}'.format(
                os.path.basename(url_no_sas)))
    else:
        if verbose:
            print('Downloading file {} to {}'.format(
                os.path.basename(url_no_sas),destination_filename),end='')
        target_dir = os.path.dirname(destination_filename)
        os.makedirs(target_dir,exist_ok=True)
        urllib.request.urlretrieve(url, destination_filename, progress_updater)
        assert(os.path.isfile(destination_filename))
        nBytes = os.path.getsize(destination_filename)
        if verbose:
            print('...done, {} bytes.'.format(nBytes))

    return destination_filename


def download_relative_filename(url, output_base, verbose=False):
    """
    Downloads a URL to output_base, preserving the relative path. The path is
    relative to the site, so:

    https://abc.com/xyz/123.txt

    ...will get downloaded to:

    output_base/xyz/123.txt

    Args:
        url (str): the URL to download
        output_base (str): the base folder to which we should download this file
        verbose (bool, optional): enable additional debug console output

    Returns:
        str: the local destination filename
    """

    p = urlparse(url)

    # Remove the leading '/'
    assert p.path.startswith('/')
    relative_filename = p.path[1:]

    destination_filename = os.path.join(output_base,relative_filename)
    return download_url(url, destination_filename, verbose=verbose)


def _do_parallelized_download(download_info,overwrite=False,verbose=False):
    """
    Internal function for download parallelization.
    """

    url = download_info['url']
    target_file = download_info['target_file']
    result = {'status':'unknown','url':url,'target_file':target_file}

    if ((os.path.isfile(target_file)) and (not overwrite)):
        if verbose:
            print('Skipping existing file {}'.format(target_file))
        result['status'] = 'skipped'
        return result

    try:
        download_url(url=url,
                     destination_filename=target_file,
                     verbose=verbose,
                     force_download=overwrite)
    except Exception as e:
        print('Warning: error downloading URL {}: {}'.format(
            url,str(e)))
        result['status'] = 'error: {}'.format(str(e))
        return result

    result['status'] = 'success'
    return result


def parallel_download_urls(url_to_target_file,verbose=False,overwrite=False,
                           n_workers=20,pool_type='thread'):
    """
    Downloads a list of URLs to local files.

    Catches exceptions and reports them in the returned "results" array.

    Args:
        url_to_target_file (dict): a dict mapping URLs to local filenames
        verbose (bool, optional): enable additional debug console output
        overwrite (bool, optional): whether to overwrite existing local files
        n_workers (int, optional): number of concurrent workers, set to <=1 to disable
            parallelization
        pool_type (str, optional): worker type to use; should be 'thread' or 'process'

    Returns:
        list: list of dicts with keys:
            - 'url': the URL this item refers to
            - 'status': 'skipped', 'success', or a string starting with 'error'
            - 'target_file': the local filename to which we downloaded (or tried to
              download) this URL
    """

    all_download_info = []

    print('Preparing download list')
    for url in tqdm(url_to_target_file):
        download_info = {}
        download_info['url'] = url
        download_info['target_file'] = url_to_target_file[url]
        all_download_info.append(download_info)

    print('Downloading {} files on {} workers'.format(
        len(all_download_info),n_workers))

    if n_workers <= 1:

        results = []

        for download_info in tqdm(all_download_info):
            result = _do_parallelized_download(download_info,overwrite=overwrite,verbose=verbose)
            results.append(result)

    else:

        if pool_type == 'thread':
            pool = ThreadPool(n_workers)
        else:
            assert pool_type == 'process', 'Unsupported pool type {}'.format(pool_type)
            pool = Pool(n_workers)

        print('Starting a {} pool with {} workers'.format(pool_type,n_workers))

        results = list(tqdm(pool.imap(
            partial(_do_parallelized_download,overwrite=overwrite,verbose=verbose),
            all_download_info), total=len(all_download_info)))

    return results


def test_url(url, error_on_failure=True, timeout=None):
    """
    Tests the availability of [url], returning an HTTP status code.

    Args:
        url (str): URL to test
        error_on_failure (bool, optional): whether to error (vs. just returning an
            error code) if accessing this URL fails
        timeout (int, optional): timeout in seconds to wait before considering this
            access attempt to be a failure; see requests.head() for precise documentation

    Returns:
        int: HTTP status code (200 for success)
    """

    # r = requests.get(url, stream=True, verify=True, timeout=timeout)
    r = requests.head(url, stream=True, verify=True, timeout=timeout)

    if error_on_failure and r.status_code != 200:
        raise ValueError('Could not access {}: error {}'.format(url,r.status_code))
    return r.status_code


def test_urls(urls, error_on_failure=True, n_workers=1, pool_type='thread', timeout=None):
    """
    Verifies that URLs are available (i.e., return status 200). By default,
    errors if any URL is unavailable.

    Args:
        urls (list): list of URLs to test
        error_on_failure (bool, optional): whether to error (vs. just returning an
            error code) if accessing a URL fails
        n_workers (int, optional): number of concurrent workers, set to <=1 to disable
            parallelization
        pool_type (str, optional): worker type to use; should be 'thread' or 'process'
        timeout (int, optional): timeout in seconds to wait before considering this
            access attempt to be a failure; see requests.head() for precise documentation

    Returns:
        list: a list of HTTP status codes, the same length and order as [urls]
    """

    if n_workers <= 1:

        status_codes = []

        for url in tqdm(urls):
            # Use the same HEAD-based check as the parallel path
            status_code = test_url(url,error_on_failure=error_on_failure,timeout=timeout)
            status_codes.append(status_code)

    else:

        if pool_type == 'thread':
            pool = ThreadPool(n_workers)
        else:
            assert pool_type == 'process', 'Unsupported pool type {}'.format(pool_type)
            pool = Pool(n_workers)

        print('Starting a {} pool with {} workers'.format(pool_type,n_workers))

        status_codes = list(tqdm(pool.imap(
            partial(test_url,error_on_failure=error_on_failure,timeout=timeout),
            urls), total=len(urls)))

    return status_codes
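
And a sketch of typical usage for the download helpers above; the URLs and target paths are placeholders, not real endpoints:

    from megadetector.utils.url_utils import \
        download_url, parallel_download_urls, test_urls

    # Download a single file to a name derived from the URL in system temp space
    local_filename = download_url('https://example.com/images/img0001.jpg')

    # Download several URLs on a thread pool; per-URL failures are recorded
    # in the 'status' field of each result rather than raised
    url_to_target_file = {
        'https://example.com/images/img0001.jpg': '/tmp/downloads/img0001.jpg',
        'https://example.com/images/img0002.jpg': '/tmp/downloads/img0002.jpg'
    }
    results = parallel_download_urls(url_to_target_file, n_workers=4)
    errors = [r for r in results if r['status'].startswith('error')]

    # Check availability (HEAD requests) without raising on failure
    status_codes = test_urls(list(url_to_target_file.keys()), error_on_failure=False)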