megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
process_utils.py
|
|
4
|
+
|
|
5
|
+
Run something at the command line and capture the output, based on:
|
|
6
|
+
|
|
7
|
+
https://stackoverflow.com/questions/4417546/constantly-print-subprocess-output-while-process-is-running
|
|
8
|
+
|
|
9
|
+
Includes handy example code for doing this on multiple processes/threads.
|
|
10
|
+
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
#%% Constants, imports, and environment
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import subprocess
|
|
17
|
+
|
|
18
|
+
def execute(cmd,encoding=None,errors=None,env=None,verbose=False):
    """
    Run [cmd] (a single string) in a shell, yielding each line of output to the caller.

    This is a generator function: nothing is launched until the caller starts iterating.
    The child's stderr is merged into its stdout, so lines from both streams are yielded.

    The "encoding", "errors", and "env" parameters are passed directly to subprocess.Popen().

    "verbose" only impacts output about process management, it is not related to printing
    output from the child process.

    Args:
        cmd (str): command to run
        encoding (str, optional): stdout encoding, see Popen() documentation
        errors (str, optional): error handling, see Popen() documentation
        env (dict, optional): environment variables, see Popen() documentation
        verbose (bool, optional): enable additional debug console output

    Yields:
        str: the next line of output from the child process, including its line ending

    Returns:
        int: the command's return code, always zero, otherwise a CalledProcessError is raised.
        Because this is a generator, the value is only visible as StopIteration.value.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero return code;
            raised only after all output has been consumed
    """

    # This modifies the current process's environment.  The child inherits it when
    # [env] is None; a caller-supplied [env] replaces the environment entirely.
    os.environ["PYTHONUNBUFFERED"] = "1"

    if verbose:
        if encoding is not None:
            print('Launching child process with non-default encoding {}'.format(encoding))
        if errors is not None:
            print('Launching child process with non-default text error handling {}'.format(errors))
        if env is not None:
            print('Launching child process with non-default environment {}'.format(str(env)))

    # https://stackoverflow.com/questions/4417546/constantly-print-subprocess-output-while-process-is-running
    popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                             shell=True, universal_newlines=True, encoding=encoding,
                             errors=errors, env=env)

    try:
        # readline() returns "" only at end-of-stream, which terminates the iterator
        for stdout_line in iter(popen.stdout.readline, ""):
            yield stdout_line
    finally:
        # Also runs when the caller abandons the generator early or raises while
        # consuming it, so the pipe is never leaked.  We deliberately don't wait()
        # here, to avoid blocking an early exit on a long-running child.
        popen.stdout.close()

    return_code = popen.wait()
    if return_code:
        raise subprocess.CalledProcessError(return_code, cmd)

    return return_code
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def execute_and_print(cmd,
                      print_output=True,
                      encoding=None,
                      errors=None,
                      env=None,
                      verbose=False,
                      catch_exceptions=True,
                      echo_command=False):
    """
    Run [cmd] (a single string) in a shell via execute(), collecting every line of
    output and optionally echoing each line to the console as it arrives.

    The "encoding", "errors", and "env" parameters are forwarded to execute(), which
    passes them to subprocess.Popen().

    "verbose" only impacts output about process management, it is not related to printing
    output from the child process.

    Args:
        cmd (str): command to run
        print_output (bool, optional): whether to print output from [cmd] (output is
            captured regardless of the value of print_output)
        encoding (str, optional): stdout encoding, see Popen() documentation
        errors (str, optional): error handling, see Popen() documentation
        env (dict, optional): environment variables, see Popen() documentation
        verbose (bool, optional): enable additional debug console output
        catch_exceptions (bool, optional): if True, a CalledProcessError is reported on the
            console and reflected in "status"; if False, it is re-raised
        echo_command (bool, optional): print the command before executing

    Returns:
        dict: a dictionary with fields "status" (the process return code) and "output"
        (a list of the lines the command wrote, in order)
    """

    if echo_command:
        print('Running command:\n{}\n'.format(cmd))

    captured_lines = []
    status = 'unknown'

    try:
        for line in execute(cmd,encoding=encoding,errors=errors,env=env,verbose=verbose):
            captured_lines.append(line)
            if print_output:
                print(line,end='',flush=True)
    except subprocess.CalledProcessError as cpe:
        if not catch_exceptions:
            raise
        print('execute_and_print caught error: {} ({})'.format(cpe.output,str(cpe)))
        status = cpe.returncode
    else:
        # The command ran to completion without a non-zero exit code
        status = 0

    return {'status':status,'output':captured_lines}
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
#%% Single-threaded test driver for execute_and_print
|
|
117
|
+
|
|
118
|
+
if False:

    pass

    #%%

    # Interactive smoke test: print, pause, print again, so line-by-line streaming
    # is visible.  NOTE(review): "ping" presumably stands in for "sleep" on Windows.
    if os.name != 'nt':
        demo_command = 'echo hello && sleep 1 && echo goodbye'
    else:
        demo_command = 'echo hello && ping -n 5 127.0.0.1 && echo goodbye'

    execute_and_print(demo_command)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
#%% Parallel test driver for execute_and_print
|
|
131
|
+
|
|
132
|
+
if False:

    pass

    #%%

    from multiprocessing.pool import ThreadPool as ThreadPool
    from multiprocessing.pool import Pool as Pool

    n_workers = 10

    # Should we use threads (vs. processes) for parallelization?
    use_threads = True

    test_data = ['a','b','c','d']

    def _process_sample(s):
        # Echo a single sample through the shell, printing the child's output
        return execute_and_print('echo ' + s,True)

    if n_workers == 1:

        results = [_process_sample(sample) for sample in test_data]

    else:

        # No point in starting more workers than there are samples
        n_threads = min(n_workers,len(test_data))

        if use_threads:
            print('Starting parallel thread pool with {} workers'.format(n_threads))
            pool = ThreadPool(n_threads)
        else:
            print('Starting parallel process pool with {} workers'.format(n_threads))
            pool = Pool(n_threads)

        try:
            results = list(pool.map(_process_sample,test_data))
        finally:
            # Release the workers even if a task raised
            pool.close()
            pool.join()

    for r in results:
        print(r)
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
split_locations_into_train_val.py
|
|
4
|
+
|
|
5
|
+
Splits a list of location IDs into training and validation, targeting a specific
|
|
6
|
+
train/val split for each category, but allowing some categories to be tighter or looser
|
|
7
|
+
than others. Does nothing particularly clever, just randomly splits locations into
|
|
8
|
+
train/val lots of times using the target val fraction, and picks the one that meets the
|
|
9
|
+
specified constraints and minimizes weighted error, where "error" is defined as the
|
|
10
|
+
sum of each class's absolute divergence from the target val fraction.
|
|
11
|
+
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
#%% Imports/constants
|
|
15
|
+
|
|
16
|
+
import random
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from collections import defaultdict
|
|
20
|
+
from megadetector.utils.ct_utils import sort_dictionary_by_value
|
|
21
|
+
from tqdm import tqdm
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
#%% Main function
|
|
25
|
+
|
|
26
|
+
def split_locations_into_train_val(location_to_category_counts,
                                   n_random_seeds=10000,
                                   target_val_fraction=0.15,
                                   category_to_max_allowable_error=None,
                                   category_to_error_weight=None,
                                   default_max_allowable_error=0.1,
                                   require_complete_coverage=True):
    """
    Splits a list of location IDs into training and validation, targeting a specific
    train/val split for each category, but allowing some categories to be tighter or looser
    than others.  Does nothing particularly clever, just randomly splits locations into
    train/val lots of times using the target val fraction, and picks the one that meets the
    specified constraints and minimizes weighted error, where "error" is defined as the
    mean over categories of each category's weighted absolute divergence from the target
    val fraction.

    Re-seeds Python's global random number generator (random.seed()) for every seed tried.

    Args:
        location_to_category_counts (dict): a dict mapping location IDs to dicts,
            with each dict mapping a category name to a count.  Any categories not present
            in a particular dict are assumed to have a count of zero for that location.
            Counts are expected to be integers.

            For example:

            .. code-block:: none

                {'location-000': {'bear':4,'wolf':10},
                 'location-001': {'bear':12,'elk':20}}

        n_random_seeds (int, optional): number of random seeds to try, always starting from zero
        target_val_fraction (float, optional): fraction of images containing each species we'd
            like to put in the val split
        category_to_max_allowable_error (dict, optional): a dict mapping category names
            to maximum allowable errors.  These are hard constraints (i.e., we will error
            if we can't meet them).  Does not need to include all categories; categories not
            included will be assigned a maximum error according to [default_max_allowable_error].
            If this is None, every category falls back to [default_max_allowable_error].
        category_to_error_weight (dict, optional): a dict mapping category names to
            error weights.  You can specify a subset of categories; categories not included here
            have a weight of 1.0.  If None, all categories have the same weight.
        default_max_allowable_error (float, optional): the maximum allowable error for categories not
            present in [category_to_max_allowable_error].  Set to None (or >= 1.0) to disable the
            maximum-error constraint for categories not present in [category_to_max_allowable_error]
        require_complete_coverage (bool, optional): require that every category appear in both train
            and val; this is enforced for every category, independent of the maximum-error settings

    Returns:
        tuple: A two-element tuple:
            - list of location IDs in the val split
            - a dict mapping category names to the fraction of images in the val split

    Raises:
        AssertionError: if no random seed satisfies all the hard constraints
    """

    location_ids = list(location_to_category_counts.keys())

    n_val_locations = int(target_val_fraction*len(location_ids))

    if category_to_max_allowable_error is None:
        category_to_max_allowable_error = {}

    if category_to_error_weight is None:
        category_to_error_weight = {}

    # category ID to total count, over all locations
    category_id_to_count = {}
    for category_counts in location_to_category_counts.values():
        for category_id,count in category_counts.items():
            category_id_to_count[category_id] = \
                category_id_to_count.get(category_id,0) + count

    category_ids = set(category_id_to_count.keys())

    print('Splitting {} categories over {} locations'.format(
        len(category_ids),len(location_ids)))

    # random_seed = 0
    def compute_seed_errors(random_seed):
        """
        Computes the per-category error for a specific random seed.

        returns weighted_average_error,weighted_category_errors,category_to_val_fraction
        """

        # Randomly split into train/val
        random.seed(random_seed)
        val_locations = random.sample(location_ids,k=n_val_locations)

        # Per-category totals are already known, so only the val locations need to be
        # visited to find out how much of each category landed in val.
        category_to_val_count = {}
        for location_id in val_locations:
            for category_id,count in location_to_category_counts[location_id].items():
                category_to_val_count[category_id] = \
                    category_to_val_count.get(category_id,0) + count

        # For each category, measure the % of images that went into the val set
        category_to_val_fraction = defaultdict(float)

        for category_id in category_ids:
            total_count = category_id_to_count[category_id]
            if total_count == 0:
                # A category with no images has no meaningful val fraction; skip it
                # rather than dividing by zero
                continue
            category_to_val_fraction[category_id] = \
                category_to_val_count.get(category_id,0) / total_count

        # Weighted absolute deviation from the target val fraction for each category
        weighted_category_errors = {}

        # category = next(iter(category_to_val_fraction))
        for category,category_val_fraction in category_to_val_fraction.items():
            category_error = abs(category_val_fraction-target_val_fraction)
            category_weight = category_to_error_weight.get(category,1.0)
            weighted_category_errors[category] = category_error * category_weight

        weighted_average_error = np.mean(list(weighted_category_errors.values()))

        return weighted_average_error,weighted_category_errors,category_to_val_fraction

    # ... def compute_seed_errors(...)

    # This will only include random seeds that satisfy the hard constraints
    random_seed_to_weighted_average_error = {}

    # random_seed = 0
    for random_seed in tqdm(range(0,n_random_seeds)):

        weighted_average_error,weighted_category_errors,category_to_val_fraction = \
            compute_seed_errors(random_seed)

        seed_satisfies_hard_constraints = True

        for category,val_fraction in category_to_val_fraction.items():

            # If necessary, verify that this category doesn't *only* appear in train or val.
            # This comes first so it also applies to categories that have no
            # maximum-error constraint.
            if require_complete_coverage:
                if (val_fraction == 0.0) or (val_fraction == 1.0):
                    seed_satisfies_hard_constraints = False
                    break

            if category in category_to_max_allowable_error:
                max_allowable_error = category_to_max_allowable_error[category]
            elif default_max_allowable_error is not None:
                max_allowable_error = default_max_allowable_error
            else:
                # No maximum-error constraint applies to this category
                continue

            # Check whether this category exceeds the hard maximum deviation
            category_error = abs(val_fraction - target_val_fraction)
            if category_error > max_allowable_error:
                seed_satisfies_hard_constraints = False
                break

        # ...for each category

        if seed_satisfies_hard_constraints:
            random_seed_to_weighted_average_error[random_seed] = weighted_average_error

    # ...for each random seed

    assert len(random_seed_to_weighted_average_error) > 0, \
        'No random seed met all the hard constraints'

    print('\n{} of {} random seeds satisfied hard constraints'.format(
        len(random_seed_to_weighted_average_error),n_random_seeds))

    # Seeds were inserted in ascending order and min() keeps the first minimum, so ties
    # go to the lowest seed
    min_error_seed = min(random_seed_to_weighted_average_error,
                         key=random_seed_to_weighted_average_error.get)

    # Re-create the winning split
    random.seed(min_error_seed)
    val_locations = random.sample(location_ids,k=n_val_locations)

    print('\nVal locations:\n')
    for loc in val_locations:
        print('{}'.format(loc))
    print('')

    weighted_average_error,weighted_category_errors,category_to_val_fraction = \
        compute_seed_errors(min_error_seed)

    # Report categories from most to least common
    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,
                                                        sort_values=category_id_to_count,
                                                        reverse=True)

    print('Val fractions by category:\n')

    for category in category_to_val_fraction:
        print('{} ({}) {:.2f}'.format(
            category,category_id_to_count[category],
            category_to_val_fraction[category]))

    return val_locations,category_to_val_fraction
|
|
236
|
+
|
|
237
|
+
# ...def split_locations_into_train_val(...)
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
string_utils.py
|
|
4
|
+
|
|
5
|
+
Miscellaneous string utilities.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
#%% Imports
|
|
10
|
+
|
|
11
|
+
import re
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
#%% Functions
|
|
15
|
+
|
|
16
|
+
def is_float(s):
    """
    Checks whether [s] is an object (typically a string) that can be cast to a float

    Args:
        s (object): object to evaluate

    Returns:
        bool: True if s successfully casts to a float, otherwise False
    """

    if s is None:
        return False

    try:
        _ = float(s)
    except (ValueError,TypeError,OverflowError):
        # ValueError: unparseable string; TypeError: unsupported type (e.g. a list);
        # OverflowError: an integer too large to represent as a float
        return False
    return True
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def is_int(s):
    """
    Checks whether [s] is an object (typically a string) that can be cast to an int

    Note that this tests castability, not integrality: a float such as 1.9 casts
    (by truncation), so it returns True, whereas the string "1.9" does not.

    Args:
        s (object): object to evaluate

    Returns:
        bool: True if s successfully casts to an int, otherwise False
    """

    if s is None:
        return False

    try:
        _ = int(s)
    except (ValueError,TypeError,OverflowError):
        # ValueError: unparseable string or NaN; TypeError: unsupported type
        # (e.g. a list); OverflowError: infinity
        return False
    return True
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def human_readable_to_bytes(size):
    """
    Given a human-readable byte string (e.g. 2G, 10GB, 30MB, 20KB),
    returns the number of bytes.  Will return 0 if the argument has
    unexpected form.

    Whitespace is ignored, units are powers of 1024, and unit letters are
    case-sensitive (upper-case K/M/G/T, with an optional trailing B).

    https://gist.github.com/beugley/ccd69945346759eb6142272a6d69b4e0

    Args:
        size (str): string representing a size

    Returns:
        int: the corresponding size in bytes.  An int is returned only for a plain
        string of digits; any input with a decimal point or a unit yields a float.
    """

    size = re.sub(r'\s+', '', size)

    if not size:
        return 0

    # Drop a single trailing "B", e.g. "10KB" -> "10K", "10B" -> "10"
    if size[-1] == 'B':
        size = size[:-1]

    # Handle the case where size was just "B"
    if not size:
        return 0

    # Plain integer.  isdecimal() rather than isdigit(): isdigit() also accepts
    # characters such as superscript digits, which int() rejects with a ValueError.
    if size.isdecimal():
        return int(size)

    # Plain float, e.g. "2.5"
    try:
        return float(size)
    except ValueError:
        pass

    # Otherwise, split into a numeric prefix and a trailing run of letters (the unit)
    split_at = len(size)
    while split_at > 0 and size[split_at-1].isalpha():
        split_at -= 1

    numeric_part = size[:split_at]
    unit = size[split_at:]

    # No unit (e.g. "K1"), or nothing but letters (e.g. "abc")
    if not unit or not numeric_part:
        return 0

    try:
        bytes_val = float(numeric_part)
    except ValueError:
        # e.g. "1.2.3K"
        return 0

    unit_to_multiplier = {
        'K': 1024,
        'M': 1024*1024,
        'G': 1024*1024*1024,
        'T': 1024*1024*1024*1024
    }

    if unit in unit_to_multiplier:
        return bytes_val * unit_to_multiplier[unit]

    # A second "B" left over after stripping the first (e.g. "1BB") is treated as bytes
    if unit == 'B':
        return bytes_val

    # Unknown unit, e.g. "1X" or "1KBB"
    return 0
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def remove_ansi_codes(s):
    """
    Strips ANSI escape sequences (e.g. terminal color codes) out of a string.

    https://stackoverflow.com/questions/14693701/how-can-i-remove-the-ansi-escape-sequences-from-a-string-in-python#14693789

    Args:
        s (str): the string to clean

    Returns:
        str: A copy of [s] with every matched escape sequence removed
    """

    # ESC followed by either a single character in the ranges @-Z or \-_, or by "["
    # plus any run of 0-? characters, any run of space-/ characters, and one final
    # character in @-~
    return re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', s)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
#%% Tests
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class TestStringUtils:
    """
    Tests for string_utils.py
    """


    def test_is_float(self):
        """
        Test the is_float function.
        """

        assert is_float("1.23")
        assert is_float("-0.5")
        assert is_float("0")
        assert is_float(1.23)
        assert is_float(0)
        assert not is_float("abc")
        assert not is_float("1.2.3")
        assert not is_float("")
        assert not is_float(None)
        assert not is_float("1,23")


    def test_is_int(self):
        """
        Test the is_int function.
        """

        assert is_int("5")
        assert is_int("-3")
        assert is_int("0")
        assert is_int(7)

        # Floats cast to int by truncation, so they count as castable...
        assert is_int(1.9)

        # ...but strings containing a decimal point do not
        assert not is_int("1.5")
        assert not is_int("abc")
        assert not is_int("")
        assert not is_int(None)


    def test_human_readable_to_bytes(self):
        """
        Test the human_readable_to_bytes function.
        """

        assert human_readable_to_bytes("10B") == 10
        assert human_readable_to_bytes("10") == 10
        assert human_readable_to_bytes("1K") == 1024
        assert human_readable_to_bytes("1KB") == 1024
        assert human_readable_to_bytes("1M") == 1024*1024
        assert human_readable_to_bytes("1MB") == 1024*1024
        assert human_readable_to_bytes("1G") == 1024*1024*1024
        assert human_readable_to_bytes("1GB") == 1024*1024*1024
        assert human_readable_to_bytes("1T") == 1024*1024*1024*1024
        assert human_readable_to_bytes("1TB") == 1024*1024*1024*1024

        assert human_readable_to_bytes("2.5K") == 2.5 * 1024
        assert human_readable_to_bytes("0.5MB") == 0.5 * 1024 * 1024

        # Test with spaces
        assert human_readable_to_bytes(" 2 G ") == 2 * 1024*1024*1024
        assert human_readable_to_bytes("500 KB") == 500 * 1024

        # Invalid inputs
        assert human_readable_to_bytes("abc") == 0
        assert human_readable_to_bytes("1X") == 0
        assert human_readable_to_bytes("1KBB") == 0
        assert human_readable_to_bytes("K1") == 0
        assert human_readable_to_bytes("") == 0
        assert human_readable_to_bytes("1.2.3K") == 0
        assert human_readable_to_bytes("B") == 0


    def test_remove_ansi_codes(self):
        """
        Test the remove_ansi_codes function.
        """

        assert remove_ansi_codes("text without codes") == "text without codes"
        assert remove_ansi_codes("\x1b[31mRed text\x1b[0m") == "Red text"
        assert remove_ansi_codes("\x1b[1m\x1b[4mBold and Underline\x1b[0m") == "Bold and Underline"
        assert remove_ansi_codes("Mixed \x1b[32mgreen\x1b[0m and normal") == "Mixed green and normal"
        assert remove_ansi_codes("") == ""

        # More complex/varied ANSI codes
        assert remove_ansi_codes("text\x1b[1Aup") == "textup"
        assert remove_ansi_codes("\x1b[2Jclearscreen") == "clearscreen"
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def test_string_utils():
    """
    Runs all tests in the TestStringUtils class.

    Test methods are discovered by their "test_" name prefix, so a test added to the
    class is picked up here without further changes.  Methods run in alphabetical order.
    """

    test_instance = TestStringUtils()

    for attribute_name in dir(test_instance):
        if not attribute_name.startswith('test_'):
            continue
        test_method = getattr(test_instance,attribute_name)
        if callable(test_method):
            test_method()
|
|
232
|
+
|
|
233
|
+
# from IPython import embed; embed()
|
|
234
|
+
# test_string_utils()
|