megadetector 5.0.11__py3-none-any.whl → 5.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/score.py +439 -0
- megadetector/api/batch_processing/api_core/server.py +294 -0
- megadetector/api/batch_processing/api_core/server_api_config.py +97 -0
- megadetector/api/batch_processing/api_core/server_app_config.py +55 -0
- megadetector/api/batch_processing/api_core/server_batch_job_manager.py +220 -0
- megadetector/api/batch_processing/api_core/server_job_status_table.py +149 -0
- megadetector/api/batch_processing/api_core/server_orchestration.py +360 -0
- megadetector/api/batch_processing/api_core/server_utils.py +88 -0
- megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +46 -0
- megadetector/api/batch_processing/api_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_support/summarize_daily_activity.py +152 -0
- megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/api/synchronous/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +152 -0
- megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +263 -0
- megadetector/api/synchronous/api_core/animal_detection_api/config.py +35 -0
- megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
- megadetector/api/synchronous/api_core/tests/load_test.py +110 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +627 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +855 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +607 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +699 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +506 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +34 -0
- megadetector/data_management/camtrap_dp_to_coco.py +237 -0
- megadetector/data_management/cct_json_utils.py +404 -0
- megadetector/data_management/cct_to_md.py +176 -0
- megadetector/data_management/cct_to_wi.py +289 -0
- megadetector/data_management/coco_to_labelme.py +283 -0
- megadetector/data_management/coco_to_yolo.py +662 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +33 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +206 -0
- megadetector/data_management/databases/integrity_check_json_db.py +493 -0
- megadetector/data_management/databases/subset_json_db.py +115 -0
- megadetector/data_management/generate_crops_from_cct.py +149 -0
- megadetector/data_management/get_image_sizes.py +189 -0
- megadetector/data_management/importers/add_nacti_sizes.py +52 -0
- megadetector/data_management/importers/add_timestamps_to_icct.py +79 -0
- megadetector/data_management/importers/animl_results_to_md_results.py +158 -0
- megadetector/data_management/importers/auckland_doc_test_to_json.py +373 -0
- megadetector/data_management/importers/auckland_doc_to_json.py +201 -0
- megadetector/data_management/importers/awc_to_json.py +191 -0
- megadetector/data_management/importers/bellevue_to_json.py +273 -0
- megadetector/data_management/importers/cacophony-thermal-importer.py +793 -0
- megadetector/data_management/importers/carrizo_shrubfree_2018.py +269 -0
- megadetector/data_management/importers/carrizo_trail_cam_2017.py +289 -0
- megadetector/data_management/importers/cct_field_adjustments.py +58 -0
- megadetector/data_management/importers/channel_islands_to_cct.py +913 -0
- megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +180 -0
- megadetector/data_management/importers/eMammal/eMammal_helpers.py +249 -0
- megadetector/data_management/importers/eMammal/make_eMammal_json.py +223 -0
- megadetector/data_management/importers/ena24_to_json.py +276 -0
- megadetector/data_management/importers/filenames_to_json.py +386 -0
- megadetector/data_management/importers/helena_to_cct.py +283 -0
- megadetector/data_management/importers/idaho-camera-traps.py +1407 -0
- megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +294 -0
- megadetector/data_management/importers/jb_csv_to_json.py +150 -0
- megadetector/data_management/importers/mcgill_to_json.py +250 -0
- megadetector/data_management/importers/missouri_to_json.py +490 -0
- megadetector/data_management/importers/nacti_fieldname_adjustments.py +79 -0
- megadetector/data_management/importers/noaa_seals_2019.py +181 -0
- megadetector/data_management/importers/pc_to_json.py +365 -0
- megadetector/data_management/importers/plot_wni_giraffes.py +123 -0
- megadetector/data_management/importers/prepare-noaa-fish-data-for-lila.py +359 -0
- megadetector/data_management/importers/prepare_zsl_imerit.py +131 -0
- megadetector/data_management/importers/rspb_to_json.py +356 -0
- megadetector/data_management/importers/save_the_elephants_survey_A.py +320 -0
- megadetector/data_management/importers/save_the_elephants_survey_B.py +329 -0
- megadetector/data_management/importers/snapshot_safari_importer.py +758 -0
- megadetector/data_management/importers/snapshot_safari_importer_reprise.py +665 -0
- megadetector/data_management/importers/snapshot_serengeti_lila.py +1067 -0
- megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +150 -0
- megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +153 -0
- megadetector/data_management/importers/sulross_get_exif.py +65 -0
- megadetector/data_management/importers/timelapse_csv_set_to_json.py +490 -0
- megadetector/data_management/importers/ubc_to_json.py +399 -0
- megadetector/data_management/importers/umn_to_json.py +507 -0
- megadetector/data_management/importers/wellington_to_json.py +263 -0
- megadetector/data_management/importers/wi_to_json.py +442 -0
- megadetector/data_management/importers/zamba_results_to_md_results.py +181 -0
- megadetector/data_management/labelme_to_coco.py +547 -0
- megadetector/data_management/labelme_to_yolo.py +272 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/add_locations_to_island_camera_traps.py +97 -0
- megadetector/data_management/lila/add_locations_to_nacti.py +147 -0
- megadetector/data_management/lila/create_lila_blank_set.py +558 -0
- megadetector/data_management/lila/create_lila_test_set.py +152 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +178 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +516 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +170 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +300 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +132 -0
- megadetector/data_management/ocr_tools.py +870 -0
- megadetector/data_management/read_exif.py +809 -0
- megadetector/data_management/remap_coco_categories.py +84 -0
- megadetector/data_management/remove_exif.py +66 -0
- megadetector/data_management/rename_images.py +187 -0
- megadetector/data_management/resize_coco_dataset.py +189 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +446 -0
- megadetector/data_management/yolo_to_coco.py +676 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/detector_training/__init__.py +0 -0
- megadetector/detection/detector_training/model_main_tf2.py +114 -0
- megadetector/detection/process_video.py +846 -0
- megadetector/detection/pytorch_detector.py +355 -0
- megadetector/detection/run_detector.py +779 -0
- megadetector/detection/run_detector_batch.py +1219 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1087 -0
- megadetector/detection/run_tiled_inference.py +934 -0
- megadetector/detection/tf_detector.py +192 -0
- megadetector/detection/video_utils.py +698 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +64 -0
- megadetector/postprocessing/categorize_detections_by_size.py +165 -0
- megadetector/postprocessing/classification_postprocessing.py +716 -0
- megadetector/postprocessing/combine_api_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +966 -0
- megadetector/postprocessing/convert_output_format.py +396 -0
- megadetector/postprocessing/load_api_results.py +195 -0
- megadetector/postprocessing/md_to_coco.py +310 -0
- megadetector/postprocessing/md_to_labelme.py +330 -0
- megadetector/postprocessing/merge_detections.py +412 -0
- megadetector/postprocessing/postprocess_batch_results.py +1908 -0
- megadetector/postprocessing/remap_detection_categories.py +170 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +660 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +211 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +83 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1635 -0
- megadetector/postprocessing/separate_detections_into_folders.py +730 -0
- megadetector/postprocessing/subset_json_detector_output.py +700 -0
- megadetector/postprocessing/top_folders_to_bottom.py +223 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +150 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +142 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +588 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +219 -0
- megadetector/taxonomy_mapping/species_lookup.py +834 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/azure_utils.py +178 -0
- megadetector/utils/ct_utils.py +613 -0
- megadetector/utils/directory_listing.py +246 -0
- megadetector/utils/md_tests.py +1164 -0
- megadetector/utils/path_utils.py +1045 -0
- megadetector/utils/process_utils.py +160 -0
- megadetector/utils/sas_blob_utils.py +509 -0
- megadetector/utils/split_locations_into_train_val.py +228 -0
- megadetector/utils/string_utils.py +92 -0
- megadetector/utils/url_utils.py +323 -0
- megadetector/utils/write_html_image_list.py +225 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +293 -0
- megadetector/visualization/render_images_with_thumbnails.py +275 -0
- megadetector/visualization/visualization_utils.py +1536 -0
- megadetector/visualization/visualize_db.py +552 -0
- megadetector/visualization/visualize_detector_output.py +405 -0
- {megadetector-5.0.11.dist-info → megadetector-5.0.13.dist-info}/LICENSE +0 -0
- {megadetector-5.0.11.dist-info → megadetector-5.0.13.dist-info}/METADATA +2 -2
- megadetector-5.0.13.dist-info/RECORD +201 -0
- megadetector-5.0.13.dist-info/top_level.txt +1 -0
- megadetector-5.0.11.dist-info/RECORD +0 -5
- megadetector-5.0.11.dist-info/top_level.txt +0 -1
- {megadetector-5.0.11.dist-info → megadetector-5.0.13.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
split_locations_into_train_val.py
|
|
4
|
+
|
|
5
|
+
Splits a list of location IDs into training and validation, targeting a specific
|
|
6
|
+
train/val split for each category, but allowing some categories to be tighter or looser
|
|
7
|
+
than others. Does nothing particularly clever, just randomly splits locations into
|
|
8
|
+
train/val lots of times using the target val fraction, and picks the one that meets the
|
|
9
|
+
specified constraints and minimizes weighted error, where "error" is defined as the
|
|
10
|
+
mean of each category's (weighted) absolute divergence from the target val fraction.
|
|
11
|
+
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
#%% Imports/constants
|
|
15
|
+
|
|
16
|
+
import random
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from collections import defaultdict
|
|
20
|
+
from megadetector.utils.ct_utils import sort_dictionary_by_value
|
|
21
|
+
from tqdm import tqdm
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
#%% Main function
|
|
25
|
+
|
|
26
|
+
def split_locations_into_train_val(location_to_category_counts,
                                   n_random_seeds=10000,
                                   target_val_fraction=0.15,
                                   category_to_max_allowable_error=None,
                                   category_to_error_weight=None,
                                   default_max_allowable_error=0.1):
    """
    Splits a list of location IDs into training and validation, targeting a specific
    train/val split for each category, but allowing some categories to be tighter or looser
    than others. Does nothing particularly clever, just randomly splits locations into
    train/val lots of times using the target val fraction, and picks the one that meets the
    specified constraints and minimizes weighted error, where "error" is defined as the
    mean of each category's (weighted) absolute divergence from the target val fraction.

    Each candidate split is drawn from a private random.Random instance seeded with the
    candidate seed, so this function does not disturb the global state of the [random]
    module.

    Categories whose total count across all locations is zero are ignored (with a
    warning), since their val fraction is undefined.

    Args:
        location_to_category_counts (dict): a dict mapping location IDs to dicts,
            with each dict mapping a category name to a count. Any categories not present
            in a particular dict are assumed to have a count of zero for that location.

            For example:

            .. code-block:: none

                {'location-000': {'bear':4,'wolf':10},
                 'location-001': {'bear':12,'elk':20}}

        n_random_seeds (int, optional): number of random seeds to try, always starting from
            zero (i.e., seeds 0 through n_random_seeds-1)
        target_val_fraction (float, optional): fraction of images containing each species we'd
            like to put in the val split
        category_to_max_allowable_error (dict, optional): a dict mapping category names
            to maximum allowable errors. These are hard constraints (i.e., we will error
            if we can't meet them). Does not need to include all categories; categories not
            included will be assigned a maximum error according to [default_max_allowable_error].
            If this is None, no hard constraints are applied.
        category_to_error_weight (dict, optional): a dict mapping category names to
            error weights. You can specify a subset of categories; categories not included here
            have a weight of 1.0. If None, all categories have the same weight.
        default_max_allowable_error (float, optional): the maximum allowable error for categories not
            present in [category_to_max_allowable_error]. Set to None (or >= 1.0) to disable hard
            constraints for categories not present in [category_to_max_allowable_error]

    Returns:
        tuple: A two-element tuple:
            - list of location IDs in the val split
            - a dict mapping category names to the fraction of images in the val split,
              sorted by each category's total count (descending)

    Raises:
        AssertionError: if no random seed satisfies the hard constraints
    """

    location_ids = list(location_to_category_counts.keys())

    n_val_locations = int(target_val_fraction*len(location_ids))

    if category_to_max_allowable_error is None:
        category_to_max_allowable_error = {}

    if category_to_error_weight is None:
        category_to_error_weight = {}

    # Category ID to total count across all locations; used as the denominator of each
    # category's val fraction, and for sorting/printing the results
    category_id_to_count = {}
    for location_id in location_ids:
        for category_id, count in location_to_category_counts[location_id].items():
            category_id_to_count[category_id] = \
                category_id_to_count.get(category_id, 0) + count

    # A category with a total count of zero has an undefined val fraction (0/0), which
    # would otherwise raise ZeroDivisionError; leave such categories out of the split
    zero_count_categories = [c for c, n in category_id_to_count.items() if n == 0]
    if len(zero_count_categories) > 0:
        print('Warning: ignoring {} categories with a total count of zero'.format(
            len(zero_count_categories)))
    category_ids = set(c for c, n in category_id_to_count.items() if n != 0)

    print('Splitting {} categories over {} locations'.format(
        len(category_ids),len(location_ids)))

    def sample_val_locations(random_seed):
        """
        Returns the list of val locations for [random_seed]. Uses a private RNG, which
        yields the same sample as random.seed() + random.sample() without touching the
        global [random] state.
        """

        rng = random.Random(random_seed)
        return rng.sample(location_ids,k=n_val_locations)

    def compute_seed_errors(random_seed):
        """
        Computes the per-category error for a specific random seed.

        Returns a three-element tuple:
            - the mean of the weighted per-category errors
            - a dict mapping each category to its weighted error
            - a dict mapping each category to its val fraction
        """

        val_locations = sample_val_locations(random_seed)

        # Count each category's images that landed in the val split. Only val locations
        # need to be visited, since a category's total count is already known.
        category_to_val_count = defaultdict(int)
        for location_id in val_locations:
            for category_id, count in location_to_category_counts[location_id].items():
                category_to_val_count[category_id] += count

        category_to_val_fraction = {}
        weighted_category_errors = {}

        for category_id in category_ids:

            category_val_fraction = \
                category_to_val_count[category_id] / category_id_to_count[category_id]
            category_to_val_fraction[category_id] = category_val_fraction

            # Absolute deviation from the target val fraction, scaled by this category's weight
            category_error = abs(category_val_fraction-target_val_fraction)
            category_weight = category_to_error_weight.get(category_id,1.0)
            weighted_category_errors[category_id] = category_error * category_weight

        weighted_average_error = np.mean(list(weighted_category_errors.values()))

        return weighted_average_error,weighted_category_errors,category_to_val_fraction

    # ... def compute_seed_errors(...)

    # This will only include random seeds that satisfy the hard constraints
    random_seed_to_weighted_average_error = {}

    for random_seed in tqdm(range(0,n_random_seeds)):

        weighted_average_error,_,category_to_val_fraction = \
            compute_seed_errors(random_seed)

        seed_satisfies_hard_constraints = True

        for category, val_fraction in category_to_val_fraction.items():
            if category in category_to_max_allowable_error:
                max_allowable_error = category_to_max_allowable_error[category]
            elif default_max_allowable_error is None:
                # No hard constraint applies to this category
                continue
            else:
                max_allowable_error = default_max_allowable_error
            if abs(val_fraction - target_val_fraction) > max_allowable_error:
                seed_satisfies_hard_constraints = False
                break

        if seed_satisfies_hard_constraints:
            random_seed_to_weighted_average_error[random_seed] = weighted_average_error

    # ...for each random seed

    assert len(random_seed_to_weighted_average_error) > 0, \
        'No random seed met all the hard constraints'

    print('\n{} of {} random seeds satisfied hard constraints'.format(
        len(random_seed_to_weighted_average_error),n_random_seeds))

    # min() keeps the first (lowest) seed among ties, since seeds were inserted in order
    min_error_seed = min(random_seed_to_weighted_average_error,
                         key=random_seed_to_weighted_average_error.get)

    val_locations = sample_val_locations(min_error_seed)

    print('\nVal locations:\n')
    for loc in val_locations:
        print('{}'.format(loc))
    print('')

    _,_,category_to_val_fraction = compute_seed_errors(min_error_seed)

    # Sort by val fraction, then re-sort by total count (both descending)
    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,reverse=True)
    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,
                                                        sort_values=category_id_to_count,
                                                        reverse=True)

    print('Val fractions by category:\n')

    for category in category_to_val_fraction:
        print('{} ({}) {:.2f}'.format(
            category,category_id_to_count[category],
            category_to_val_fraction[category]))

    return val_locations,category_to_val_fraction

# ...def split_locations_into_train_val(...)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
string_utils.py
|
|
4
|
+
|
|
5
|
+
Miscellaneous string utilities.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
#%% Imports
|
|
10
|
+
|
|
11
|
+
import re
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
#%% Functions
|
|
15
|
+
|
|
16
|
+
def is_float(s):
    """
    Checks whether [s] is an object (typically a string) that can be cast to a float

    Args:
        s (object): object to evaluate

    Returns:
        bool: True if s successfully casts to a float, otherwise False
    """

    try:
        _ = float(s)
    except (ValueError, TypeError, OverflowError):
        # ValueError: unparseable strings (e.g. 'abc')
        # TypeError: objects float() does not accept at all (e.g. None, lists)
        # OverflowError: ints too large to be represented as a float
        return False
    return True
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def human_readable_to_bytes(size):
    """
    Given a human-readable byte string (e.g. 2G, 10GB, 30MB, 20KB),
    returns the number of bytes. Will return 0 if the argument has
    unexpected form (including an empty string or a bare "B").

    Units are binary (K = 1024) and upper-case (K, M, G, T); a trailing "B" is optional.
    Whitespace anywhere in the string is ignored.

    https://gist.github.com/beugley/ccd69945346759eb6142272a6d69b4e0

    Args:
        size (str): string representing a size

    Returns:
        int or float: the corresponding size in bytes. This is an int for plain integer
        strings (and 0 for unparseable input), and a float otherwise.
    """

    size = re.sub(r'\s+', '', size)

    # The trailing 'B' in e.g. "10GB" carries no information
    if size.endswith('B'):
        size = size[:-1]

    # Nothing left to parse (e.g. "" or "B")
    if len(size) == 0:
        return 0

    if size.isdigit():
        return int(size)

    # A plain number without a unit, e.g. "1.5"
    try:
        return float(size)
    except ValueError:
        pass

    unit_to_multiplier = {
        'T': 1024*1024*1024*1024,
        'G': 1024*1024*1024,
        'M': 1024*1024,
        'K': 1024
    }

    multiplier = unit_to_multiplier.get(size[-1])
    if multiplier is None:
        return 0

    try:
        return float(size[:-1]) * multiplier
    except ValueError:
        return 0
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def remove_ansi_codes(s):
    """
    Strips ANSI escape sequences (e.g. terminal color/style codes) out of a string.

    https://stackoverflow.com/questions/14693701/how-can-i-remove-the-ansi-escape-sequences-from-a-string-in-python#14693789

    Args:
        s (str): the string to de-ANSI-i-fy

    Returns:
        str: A copy of [s] without ANSI codes
    """

    # ESC followed by either a single 7-bit C1 character, or a CSI sequence:
    # '[' + parameter bytes + intermediate bytes + one final byte
    return re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', s)
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
url_utils.py
|
|
4
|
+
|
|
5
|
+
Frequently-used functions for downloading or manipulating URLs
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
#%% Imports and constants
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import re
|
|
13
|
+
import urllib
|
|
14
|
+
import tempfile
|
|
15
|
+
import requests
|
|
16
|
+
|
|
17
|
+
from functools import partial
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
from urllib.parse import urlparse
|
|
20
|
+
from multiprocessing.pool import ThreadPool
|
|
21
|
+
from multiprocessing.pool import Pool
|
|
22
|
+
|
|
23
|
+
# Folder used for downloads when no destination filename is given; set lazily (and then
# reused) by get_temp_folder()
url_utils_temp_dir = None

# Maximum combined length of the temp folder path plus a URL-derived filename that
# download_url() allows before truncating the filename
max_path_len = 255
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
#%% Download functions
|
|
28
|
+
|
|
29
|
+
class DownloadProgressBar():
    """
    Progress updater based on the progressbar2 package, usable as the [reporthook]
    argument of urllib.request.urlretrieve().

    https://stackoverflow.com/questions/37748105/how-to-use-progressbar-module-with-urlretrieve
    """

    def __init__(self):
        # Created lazily on the first callback, when the total size is first reported
        self.pbar = None

    def __call__(self, block_num, block_size, total_size):
        """
        Progress callback.

        Args:
            block_num (int): number of blocks transferred so far
            block_size (int): block size in bytes
            total_size (int): total size in bytes; urlretrieve reports -1 when the
                server did not provide a size
        """

        if self.pbar is None:
            # This is a pretty random import I'd rather not depend on outside of the
            # rare case where it's used, so importing locally
            #
            # pip install progressbar2
            import progressbar
            if total_size > 0:
                max_value = total_size
            else:
                # Size unknown: a negative max_value isn't a meaningful bound, so use
                # an unbounded bar instead
                max_value = progressbar.UnknownLength
            self.pbar = progressbar.ProgressBar(max_value=max_value)
            self.pbar.start()

        downloaded = block_num * block_size

        # With an unknown size we can't tell when we're done, so just keep updating
        if total_size <= 0 or downloaded < total_size:
            self.pbar.update(downloaded)
        else:
            self.pbar.finish()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_temp_folder(preferred_name='url_utils'):
    """
    Returns the temporary folder used by this module, creating it if necessary.

    The folder is resolved on the first call and cached in the module-level
    [url_utils_temp_dir]. [preferred_name] therefore only has an effect on the first
    call; later calls return the cached path regardless of the argument.

    Args:
        preferred_name (str, optional): subfolder to use within the system temp folder

    Returns:
        str: the full path to the temporary subfolder
    """

    global url_utils_temp_dir

    # Already resolved on an earlier call
    if url_utils_temp_dir is not None:
        return url_utils_temp_dir

    url_utils_temp_dir = os.path.join(tempfile.gettempdir(),preferred_name)
    os.makedirs(url_utils_temp_dir,exist_ok=True)
    return url_utils_temp_dir
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def download_url(url,
                 destination_filename=None,
                 progress_updater=None,
                 force_download=False,
                 verbose=True):
    """
    Downloads a URL to a file. If no file is specified, creates a temporary file,
    making a best effort to avoid filename collisions.

    Prints some diagnostic information and makes sure to omit SAS tokens from printouts.

    If a download into a previously non-existent file fails, the partially-written file
    is removed before the exception propagates. Otherwise a later call would mistake the
    partial file for a completed download.

    Args:
        url (str): the URL to download
        destination_filename (str, optional): the target filename; if None, will create
            a file in system temp space
        progress_updater (object or bool, optional): can be "None", "False", "True", or a
            specific callable object. If None or False, no progress updated will be
            displayed. If True, a default progress bar will be created.
        force_download (bool, optional): download this file even if [destination_filename]
            exists.
        verbose (bool, optional): enable additional debug console output

    Returns:
        str: the filename to which [url] was downloaded, the same as [destination_filename]
        if [destination_filename] was not None
    """

    if isinstance(progress_updater,bool):
        progress_updater = DownloadProgressBar() if progress_updater else None

    # Everything after '?' (e.g. a SAS token) is kept out of printouts and filenames
    url_no_sas = url.split('?',1)[0]

    if destination_filename is None:
        target_folder = get_temp_folder()

        # This does not guarantee uniqueness, hence "semi-best-effort"
        url_as_filename = re.sub(r'\W+', '', url_no_sas)
        n_folder_chars = len(target_folder)
        if len(url_as_filename) + n_folder_chars > max_path_len:
            print('Warning: truncating filename target to {} characters'.format(max_path_len))
            # Keep the trailing characters of the URL-derived name
            url_as_filename = url_as_filename[-1*(max_path_len-n_folder_chars):]
        destination_filename = os.path.join(target_folder,url_as_filename)

    if (not force_download) and os.path.isfile(destination_filename):
        if verbose:
            print('Bypassing download of already-downloaded file {}'.format(os.path.basename(url_no_sas)))
    else:
        if verbose:
            print('Downloading file {} to {}'.format(os.path.basename(url_no_sas),destination_filename),end='')

        # A bare filename has no directory component, and os.makedirs('') would raise
        target_dir = os.path.dirname(destination_filename)
        if len(target_dir) > 0:
            os.makedirs(target_dir,exist_ok=True)

        # 'import urllib' alone does not guarantee that the urllib.request submodule is loaded
        import urllib.request

        file_existed_before = os.path.isfile(destination_filename)
        try:
            urllib.request.urlretrieve(url, destination_filename, progress_updater)
        except BaseException:
            # Don't leave a truncated file behind; only clean up files we created, so a
            # failed forced re-download never deletes a pre-existing file
            if (not file_existed_before) and os.path.isfile(destination_filename):
                try:
                    os.remove(destination_filename)
                except OSError:
                    pass
            raise

        assert(os.path.isfile(destination_filename))
        n_bytes = os.path.getsize(destination_filename)
        if verbose:
            print('...done, {} bytes.'.format(n_bytes))

    return destination_filename
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def download_relative_filename(url, output_base, verbose=False):
    """
    Downloads [url] into [output_base], recreating the URL's server-side path
    underneath it.  For example:

    https://abc.com/xyz/123.txt

    ...is written to:

    output_base/xyz/123.txt

    Args:
        url (str): the URL to download
        output_base (str): root folder under which the URL's path is recreated
        verbose (bool, optional): enable additional debug console output

    Returns:
        str: the local destination filename
    """

    url_path = urlparse(url).path

    # The parsed path is expected to be absolute (leading '/'); strip that
    # slash so os.path.join() treats the remainder as relative to output_base
    assert url_path.startswith('/')
    local_filename = os.path.join(output_base, url_path[1:])

    return download_url(url, local_filename, verbose=verbose)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _do_parallelized_download(download_info,overwrite=False,verbose=False):
    """
    Worker used to fetch a single URL during parallel downloading.

    Args:
        download_info (dict): must contain the keys 'url' and 'target_file'
        overwrite (bool, optional): download even if [target_file] already exists
        verbose (bool, optional): enable additional debug console output

    Returns:
        dict: with keys 'status', 'url', and 'target_file'; 'status' is
        'skipped', 'success', or 'error: <message>'.  Exceptions raised while
        downloading are caught and reported through 'status' rather than raised.
    """

    source_url = download_info['url']
    local_file = download_info['target_file']
    outcome = {'status':'unknown','url':source_url,'target_file':local_file}

    already_present = os.path.isfile(local_file)
    if already_present and not overwrite:
        if verbose:
            print('Skipping existing file {}'.format(local_file))
        outcome['status'] = 'skipped'
        return outcome

    try:
        download_url(url=source_url,
                     destination_filename=local_file,
                     verbose=verbose,
                     force_download=overwrite)
    except Exception as e:
        print('Warning: error downloading URL {}: {}'.format(
            source_url,str(e)))
        outcome['status'] = 'error: {}'.format(str(e))
    else:
        outcome['status'] = 'success'

    return outcome
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def parallel_download_urls(url_to_target_file,verbose=False,overwrite=False,
                           n_workers=20,pool_type='thread'):
    """
    Downloads a set of URLs to local files.

    Catches download exceptions and reports them in the returned "results" list.

    Args:
        url_to_target_file (dict): maps each URL to the local filename it
            should be downloaded to
        verbose (bool, optional): enable additional debug console output
        overwrite (bool, optional): whether to overwrite existing local files
        n_workers (int, optional): number of concurrent workers, set to <=1 to disable
            parallelization
        pool_type (str, optional): worker type to use; should be 'thread' or 'process'

    Returns:
        list: list of dicts, in the iteration order of [url_to_target_file], with keys:
        - 'url': the url this item refers to
        - 'status': 'skipped', 'success', or a string starting with 'error'
        - 'target_file': the local filename to which we downloaded (or tried to
          download) this URL
    """

    print('Preparing download list')
    all_download_info = [{'url':url,'target_file':url_to_target_file[url]}
                         for url in tqdm(url_to_target_file)]

    print('Downloading {} images on {} workers'.format(
        len(all_download_info),n_workers))

    worker_fn = partial(_do_parallelized_download,overwrite=overwrite,verbose=verbose)

    if n_workers <= 1:
        return [worker_fn(download_info) for download_info in tqdm(all_download_info)]

    if pool_type == 'thread':
        pool = ThreadPool(n_workers)
    else:
        assert pool_type == 'process', 'Unsupported pool type {}'.format(pool_type)
        pool = Pool(n_workers)

    print('Starting a {} pool with {} workers'.format(pool_type,n_workers))

    # Always release the pool's workers.  On success, close() lets queued
    # work drain before join(); on any failure (including Ctrl-C), terminate()
    # avoids blocking on outstanding downloads.
    try:
        results = list(tqdm(pool.imap(worker_fn, all_download_info),
                            total=len(all_download_info)))
        pool.close()
    except BaseException:
        pool.terminate()
        raise
    finally:
        pool.join()

    return results
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def test_url(url, error_on_failure=True, timeout=None):
    """
    Checks whether [url] is reachable by issuing an HTTP HEAD request.

    Args:
        url (str): URL to test
        error_on_failure (bool, optional): if True, raise when the response
            status is anything other than 200; if False, just return the status
        timeout (int, optional): timeout in seconds to wait before considering this
            access attempt to be a failure; passed through to requests.head()

    Returns:
        int: http status code (200 for success)

    Raises:
        ValueError: if [error_on_failure] is True and the status code is not 200
    """

    # NOTE(review): requests.head() does not follow redirects by default, so a
    # URL that redirects may report a 3xx code here; confirm this is intended.
    response = requests.head(url, stream=True, verify=True, timeout=timeout)
    status = response.status_code

    if (status != 200) and error_on_failure:
        raise ValueError('Could not access {}: error {}'.format(url,status))

    return status
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def test_urls(urls, error_on_failure=True, n_workers=1, pool_type='thread', timeout=None):
    """
    Verify that URLs are available (i.e., returns status 200).  By default,
    errors if any URL is unavailable.

    Args:
        urls (list): list of URLs to test
        error_on_failure (bool, optional): whether to error (vs. just returning an
            error code) if accessing a URL fails
        n_workers (int, optional): number of concurrent workers, set to <=1 to disable
            parallelization
        pool_type (str, optional): worker type to use; should be 'thread' or 'process'
        timeout (int, optional): timeout in seconds to wait before considering this
            access attempt to be a failure; see requests.head() for precise documentation

    Returns:
        list: a list of http status codes, the same length and order as [urls]

    Raises:
        ValueError: if [error_on_failure] is True and any URL does not return 200
    """

    if n_workers <= 1:

        status_codes = []

        for url in tqdm(urls):

            # NOTE(review): this serial path issues a full GET, while the
            # parallel path below goes through test_url() (HEAD); results can
            # differ for servers that treat the two differently.
            r = requests.get(url, timeout=timeout)

            if error_on_failure and r.status_code != 200:
                raise ValueError('Could not access {}: error {}'.format(url,r.status_code))
            status_codes.append(r.status_code)

        return status_codes

    if pool_type == 'thread':
        pool = ThreadPool(n_workers)
    else:
        assert pool_type == 'process', 'Unsupported pool type {}'.format(pool_type)
        pool = Pool(n_workers)

    print('Starting a {} pool with {} workers'.format(pool_type,n_workers))

    # Always release the pool's workers.  test_url() raises ValueError on the
    # first failing URL when error_on_failure is set, and imap() re-raises it
    # here.  In that case, terminate() stops the remaining work instead of
    # leaving the pool running.
    try:
        status_codes = list(tqdm(pool.imap(
            partial(test_url,error_on_failure=error_on_failure,timeout=timeout),
            urls), total=len(urls)))
        pool.close()
    except BaseException:
        pool.terminate()
        raise
    finally:
        pool.join()

    return status_codes
|