megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic; see the package registry page for more details.

Files changed (147)
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,172 @@
1
+ """
2
+
3
+ process_utils.py
4
+
5
+ Run something at the command line and capture the output, based on:
6
+
7
+ https://stackoverflow.com/questions/4417546/constantly-print-subprocess-output-while-process-is-running
8
+
9
+ Includes handy example code for doing this on multiple processes/threads.
10
+
11
+ """
12
+
13
+ #%% Constants, imports, and environment
14
+
15
+ import os
16
+ import subprocess
17
+
18
def execute(cmd,encoding=None,errors=None,env=None,verbose=False):
    """
    Run [cmd] (a single string) in a shell, yielding each line of output to the caller.

    stderr is merged into stdout, so the yielded lines contain both streams in the order
    the child wrote them.

    The "encoding" and "errors" parameters are passed directly to subprocess.Popen().
    The child's environment is a copy of [env] (or of the current environment if [env]
    is None) with PYTHONUNBUFFERED=1 added, so Python children flush their output
    promptly. The calling process's own os.environ is not modified.

    "verbose" only impacts output about process management, it is not related to printing
    output from the child process.

    If the caller stops iterating before the output is exhausted, the child process is
    killed and reaped rather than left running. This happens, for example, when it breaks
    out of a loop and the generator is closed or garbage-collected. Because the command
    runs via shell=True, the killed process is the shell; processes the shell itself
    spawned are not killed.

    Args:
        cmd (str): command to run
        encoding (str, optional): stdout encoding, see Popen() documentation
        errors (str, optional): error handling, see Popen() documentation
        env (dict, optional): environment variables, see Popen() documentation
        verbose (bool, optional): enable additional debug console output

    Returns:
        int: the command's return code (delivered as the generator's return value); always
        zero, otherwise a CalledProcessError is raised

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero return code
    """

    # Build the child's environment explicitly instead of mutating os.environ. Mutating it
    # would leak into the calling process and race with other threads calling execute().
    child_env = dict(os.environ if env is None else env)
    child_env['PYTHONUNBUFFERED'] = '1'

    if verbose:
        if encoding is not None:
            print('Launching child process with non-default encoding {}'.format(encoding))
        if errors is not None:
            print('Launching child process with non-default text error handling {}'.format(errors))
        if env is not None:
            print('Launching child process with non-default environment {}'.format(str(env)))

    # https://stackoverflow.com/questions/4417546/constantly-print-subprocess-output-while-process-is-running
    popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                             shell=True, universal_newlines=True, encoding=encoding,
                             errors=errors, env=child_env)

    finished_reading = False
    try:
        for stdout_line in iter(popen.stdout.readline, ""):
            yield stdout_line
        finished_reading = True
    finally:
        popen.stdout.close()
        if (not finished_reading) and (popen.poll() is None):
            # The consumer abandoned the generator (or reading failed). Nobody will read
            # the rest of the output, so don't leave the child running unreaped.
            popen.kill()
        return_code = popen.wait()

    if return_code:
        raise subprocess.CalledProcessError(return_code, cmd)

    return return_code
60
+
61
+
62
def execute_and_print(cmd,
                      print_output=True,
                      encoding=None,
                      errors=None,
                      env=None,
                      verbose=False,
                      catch_exceptions=True,
                      echo_command=False):
    """
    Run [cmd] (a single string) in a shell, capturing and printing output. Returns
    a dictionary with fields "status" and "output".

    The "encoding", "errors", and "env" parameters are passed through to execute(), which
    hands them to subprocess.Popen().

    "verbose" only impacts output about process management, it is not related to printing
    output from the child process.

    Args:
        cmd (str): command to run
        print_output (bool, optional): whether to print output from [cmd] (output is
            captured regardless of the value of print_output)
        encoding (str, optional): stdout encoding, see Popen() documentation
        errors (str, optional): error handling, see Popen() documentation
        env (dict, optional): environment variables, see Popen() documentation
        verbose (bool, optional): enable additional debug console output
        catch_exceptions (bool, optional): if True, a non-zero exit code is reported via the
            "status" field. If False, the subprocess.CalledProcessError is re-raised, with the
            output captured so far attached as its .output attribute. Only
            CalledProcessError is handled this way; other exceptions always propagate.
        echo_command (bool, optional): print the command before executing

    Returns:
        dict: a dictionary with fields "status" (the process return code) and "output"
        (a list of the lines the command wrote, each including its trailing newline)
    """

    if echo_command:
        print('Running command:\n{}\n'.format(cmd))

    to_return = {'status':'unknown','output':''}
    output = []
    try:
        for s in execute(cmd,encoding=encoding,errors=errors,env=env,verbose=verbose):
            output.append(s)
            if print_output:
                print(s,end='',flush=True)
        to_return['status'] = 0
    except subprocess.CalledProcessError as cpe:
        # The CalledProcessError raised by execute() carries no output. Attach what was
        # captured, so a caller catching the re-raised exception doesn't lose it.
        if cpe.output is None:
            cpe.output = ''.join(output)
        if not catch_exceptions:
            raise
        print('execute_and_print caught error: {}'.format(str(cpe)))
        to_return['status'] = cpe.returncode
    to_return['output'] = output

    return to_return
114
+
115
+
116
#%% Single-threaded test driver for execute_and_print

if False:

    pass

    #%%

    # Windows has no 'sleep', so repeated pings to localhost provide the pause there
    windows_command = 'echo hello && ping -n 5 127.0.0.1 && echo goodbye'
    posix_command = 'echo hello && sleep 1 && echo goodbye'
    execute_and_print(windows_command if os.name == 'nt' else posix_command)
128
+
129
+
130
#%% Parallel test driver for execute_and_print

if False:

    pass

    #%%

    from multiprocessing.pool import ThreadPool
    from multiprocessing.pool import Pool

    n_workers = 10

    # Should we use threads (vs. processes) for parallelization?
    use_threads = True

    test_data = ['a','b','c','d']

    # NOTE: with use_threads=False, worker processes must be able to pickle this function.
    # A function defined in an interactive session may not be picklable under the 'spawn'
    # start method.
    def _process_sample(s):
        return execute_and_print('echo ' + s,True)

    if n_workers == 1:

        results = [_process_sample(sample) for sample in test_data]

    else:

        n_threads = min(n_workers,len(test_data))

        if use_threads:
            print('Starting parallel thread pool with {} workers'.format(n_threads))
            pool = ThreadPool(n_threads)
        else:
            print('Starting parallel process pool with {} workers'.format(n_threads))
            pool = Pool(n_threads)

        # Always release the workers, even if a task raises
        try:
            results = pool.map(_process_sample,test_data)
        finally:
            pool.close()
            pool.join()

    for r in results:
        print(r)
@@ -0,0 +1,237 @@
1
+ """
2
+
3
+ split_locations_into_train_val.py
4
+
5
+ Splits a list of location IDs into training and validation, targeting a specific
6
+ train/val split for each category, but allowing some categories to be tighter or looser
7
+ than others. Does nothing particularly clever, just randomly splits locations into
8
+ train/val lots of times using the target val fraction, and picks the one that meets the
9
+ specified constraints and minimizes weighted error, where "error" is defined as the
10
+ sum of each class's absolute divergence from the target val fraction.
11
+
12
+ """
13
+
14
+ #%% Imports/constants
15
+
16
+ import random
17
+ import numpy as np
18
+
19
+ from collections import defaultdict
20
+ from megadetector.utils.ct_utils import sort_dictionary_by_value
21
+ from tqdm import tqdm
22
+
23
+
24
+ #%% Main function
25
+
26
def split_locations_into_train_val(location_to_category_counts,
                                   n_random_seeds=10000,
                                   target_val_fraction=0.15,
                                   category_to_max_allowable_error=None,
                                   category_to_error_weight=None,
                                   default_max_allowable_error=0.1,
                                   require_complete_coverage=True):
    """
    Splits a list of location IDs into training and validation, targeting a specific
    train/val split for each category, but allowing some categories to be tighter or looser
    than others.

    This does nothing particularly clever. It randomly splits locations into train/val many
    times using the target val fraction, and picks the split that meets the specified
    constraints and minimizes error. "Error" is defined as the mean, over categories, of
    each class's weighted absolute divergence from the target val fraction.

    Each candidate split is drawn from its own random.Random instance seeded with the
    candidate's seed, so the global random-module state is left untouched. Categories
    whose total count is zero can't be split; they are ignored, with a printed warning.

    Args:
        location_to_category_counts (dict): a dict mapping location IDs to dicts,
            with each dict mapping a category name to a count. Any categories not present
            in a particular dict are assumed to have a count of zero for that location.

            For example:

            .. code-block:: none

                {'location-000': {'bear':4,'wolf':10},
                 'location-001': {'bear':12,'elk':20}}

        n_random_seeds (int, optional): number of random seeds to try, always starting from zero
        target_val_fraction (float, optional): fraction of images containing each species we'd
            like to put in the val split. The number of val locations is
            int(target_val_fraction * number of locations).
        category_to_max_allowable_error (dict, optional): a dict mapping category names
            to maximum allowable errors. These are hard constraints (i.e., we will error
            if we can't meet them). Does not need to include all categories; categories not
            included will be assigned a maximum error according to [default_max_allowable_error].
            If this is None, every category uses [default_max_allowable_error]. A value of
            None for a specific category disables the error constraint for that category.
        category_to_error_weight (dict, optional): a dict mapping category names to
            error weights. You can specify a subset of categories; categories not included here
            have a weight of 1.0. If None, all categories have the same weight.
        default_max_allowable_error (float, optional): the maximum allowable error for categories not
            present in [category_to_max_allowable_error]. Set to None (or >= 1.0) to disable the
            error constraint for categories not present in [category_to_max_allowable_error].
        require_complete_coverage (bool, optional): require that every category appear in both train
            and val. This is checked for every category, whether or not that category has a
            maximum allowable error.

    Returns:
        tuple: A two-element tuple:
            - list of location IDs in the val split
            - a dict mapping category names to the fraction of images in the val split,
              sorted via sort_dictionary_by_value by total category count with reverse=True
              (presumably largest count first)

    Raises:
        AssertionError: if no random seed satisfies the hard constraints
    """

    location_ids = list(location_to_category_counts.keys())

    n_val_locations = int(target_val_fraction*len(location_ids))

    if category_to_max_allowable_error is None:
        category_to_max_allowable_error = {}

    if category_to_error_weight is None:
        category_to_error_weight = {}

    # Category ID to total count over all locations. Used as the denominator of each
    # category's val fraction, and for printouts.
    category_id_to_count = {}
    for category_counts in location_to_category_counts.values():
        for category_id, count in category_counts.items():
            category_id_to_count[category_id] = \
                category_id_to_count.get(category_id, 0) + count

    # A category with a total count of zero has an undefined (0/0) val fraction
    zero_count_category_ids = \
        [c for c, total in category_id_to_count.items() if total == 0]
    if len(zero_count_category_ids) > 0:
        print('Warning: ignoring {} categories with a total count of zero: {}'.format(
            len(zero_count_category_ids), str(zero_count_category_ids)))

    # Use a list (not a set) so the category order, and therefore the floating-point
    # summation order of each seed's error, is the same from run to run
    category_ids = [c for c, total in category_id_to_count.items() if total != 0]

    print('Splitting {} categories over {} locations'.format(
        len(category_ids),len(location_ids)))

    def _sample_val_locations(random_seed):
        """
        Returns the list of val locations for [random_seed]. A private Random instance
        gives the same sample as random.seed(random_seed) followed by random.sample(),
        without resetting the caller's global RNG state.
        """
        return random.Random(random_seed).sample(location_ids,k=n_val_locations)

    def compute_seed_errors(random_seed):
        """
        Computes the per-category error for a specific random seed.

        Returns a three-element tuple: (mean of the weighted per-category errors,
        dict mapping category to weighted error, dict mapping category to val fraction)
        """

        # Sum counts over the val locations only; the train count is the remainder of
        # each category's total
        category_to_val_count = defaultdict(int)
        for location_id in _sample_val_locations(random_seed):
            for category_id, count in location_to_category_counts[location_id].items():
                category_to_val_count[category_id] += count

        category_to_val_fraction = {}
        weighted_category_errors = {}

        for category_id in category_ids:
            val_fraction = \
                category_to_val_count[category_id] / category_id_to_count[category_id]
            category_to_val_fraction[category_id] = val_fraction

            # Absolute deviation from the target val fraction, scaled by this category's
            # weight (1.0 unless specified)
            category_error = abs(val_fraction-target_val_fraction)
            weighted_category_errors[category_id] = \
                category_error * category_to_error_weight.get(category_id,1.0)

        # This is the plain mean of the weighted errors, not sum(w*e)/sum(w)
        weighted_average_error = np.mean(list(weighted_category_errors.values()))

        return weighted_average_error,weighted_category_errors,category_to_val_fraction

    # ... def compute_seed_errors(...)

    def _satisfies_hard_constraints(category_to_val_fraction):
        """
        Returns True if a split's per-category val fractions satisfy the coverage
        requirement and every applicable maximum-error constraint.
        """

        for category_id, val_fraction in category_to_val_fraction.items():

            # Coverage is its own constraint, so it is checked even for categories with
            # no maximum allowable error
            if require_complete_coverage and \
               ((val_fraction == 0.0) or (val_fraction == 1.0)):
                return False

            max_allowable_error = category_to_max_allowable_error.get(
                category_id, default_max_allowable_error)
            if max_allowable_error is None:
                continue

            if abs(val_fraction - target_val_fraction) > max_allowable_error:
                return False

        return True

    # This will only include random seeds that satisfy the hard constraints
    random_seed_to_weighted_average_error = {}

    for random_seed in tqdm(range(0,n_random_seeds)):
        weighted_average_error,_,category_to_val_fraction = \
            compute_seed_errors(random_seed)
        if _satisfies_hard_constraints(category_to_val_fraction):
            random_seed_to_weighted_average_error[random_seed] = weighted_average_error

    # ...for each random seed

    assert len(random_seed_to_weighted_average_error) > 0, \
        'No random seed met all the hard constraints'

    print('\n{} of {} random seeds satisfied hard constraints'.format(
        len(random_seed_to_weighted_average_error),n_random_seeds))

    # min() keeps the first minimum it encounters, so ties go to the lowest seed
    min_error_seed = min(random_seed_to_weighted_average_error,
                         key=random_seed_to_weighted_average_error.get)

    val_locations = _sample_val_locations(min_error_seed)

    print('\nVal locations:\n')
    for loc in val_locations:
        print('{}'.format(loc))
    print('')

    _,_,category_to_val_fraction = compute_seed_errors(min_error_seed)

    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,
                                                        sort_values=category_id_to_count,
                                                        reverse=True)

    print('Val fractions by category:\n')

    for category in category_to_val_fraction:
        print('{} ({}) {:.2f}'.format(
            category,category_id_to_count[category],
            category_to_val_fraction[category]))

    return val_locations,category_to_val_fraction

# ...def split_locations_into_train_val(...)
@@ -0,0 +1,234 @@
1
+ """
2
+
3
+ string_utils.py
4
+
5
+ Miscellaneous string utilities.
6
+
7
+ """
8
+
9
+ #%% Imports
10
+
11
+ import re
12
+
13
+
14
+ #%% Functions
15
+
16
def is_float(s):
    """
    Checks whether [s] is an object (typically a string) that can be cast to a float.

    Args:
        s (object): object to evaluate

    Returns:
        bool: True if float(s) succeeds, otherwise False. This includes None, objects
        float() does not accept at all (e.g. lists or dicts), and ints too large to
        represent as a float.
    """

    try:
        _ = float(s)
    except (TypeError, ValueError, OverflowError):
        # ValueError: unparseable string; TypeError: unsupported type (including None);
        # OverflowError: an int too large for a float
        return False
    return True
35
+
36
+
37
def is_int(s):
    """
    Checks whether [s] is an object (typically a string) that can be cast to an int.

    Note that int() truncates floats, so is_int(1.5) is True, while is_int("1.5") is False
    because int() does not parse decimal strings.

    Args:
        s (object): object to evaluate

    Returns:
        bool: True if int(s) succeeds, otherwise False. This includes None, objects int()
        does not accept at all (e.g. lists), and non-finite floats.
    """

    try:
        _ = int(s)
    except (TypeError, ValueError, OverflowError):
        # ValueError: unparseable string or NaN; TypeError: unsupported type (including
        # None); OverflowError: infinity
        return False
    return True
56
+
57
+
58
def human_readable_to_bytes(size):
    """
    Given a human-readable byte string (e.g. 2G, 10GB, 30MB, 20KB),
    returns the number of bytes. Will return 0 if the argument has
    unexpected form.

    Whitespace anywhere in the string is ignored and a single trailing "B" is optional.
    Units are the binary multiples K, M, G and T (powers of 1024) and must be
    upper-case; anything else, including a doubled "B" such as "1BB", yields 0.

    https://gist.github.com/beugley/ccd69945346759eb6142272a6d69b4e0

    Args:
        size (str): string representing a size

    Returns:
        int or float: the corresponding size in bytes. This is an int for plain digit
        strings (e.g. "10" or "10B") and for malformed input (0); otherwise a float.
    """

    size = re.sub(r'\s+', '', size)

    # A single trailing 'B' ("bytes") is optional
    if size.endswith('B'):
        size = size[:-1]

    # Covers both the empty string and a bare "B"
    if not size:
        return 0

    # A plain number of bytes. isdecimal() (not isdigit()) accepts exactly the digits
    # int() can parse, so e.g. superscript digits fall through instead of raising.
    if size.isdecimal():
        return int(size)
    try:
        return float(size)
    except ValueError:
        pass

    # Otherwise split into a numeric prefix and a trailing run of letters (the unit),
    # e.g. "2.5K" -> ("2.5", "K")
    split_index = len(size)
    while split_index > 0 and size[split_index - 1].isalpha():
        split_index -= 1
    numeric_part = size[:split_index]
    unit_part = size[split_index:]

    if not numeric_part or not unit_part:
        return 0

    unit_to_multiplier = {'K': 1024, 'M': 1024**2, 'G': 1024**3, 'T': 1024**4}

    if unit_part not in unit_to_multiplier:
        return 0

    try:
        return float(numeric_part) * unit_to_multiplier[unit_part]
    except ValueError:
        return 0
129
+
130
+
131
def remove_ansi_codes(s):
    """
    Returns a copy of [s] with ANSI escape sequences stripped out.

    Two kinds of sequence are removed:
    - ESC followed by one character in '@'..'_' other than '['
    - CSI sequences: ESC '[', then any parameter characters ('0'..'?'), then any
      intermediate characters (' '..'/'), then one final character ('@'..'~')

    https://stackoverflow.com/questions/14693701/how-can-i-remove-the-ansi-escape-sequences-from-a-string-in-python#14693789

    Args:
        s (str): the string to de-ANSI-i-fy

    Returns:
        str: A copy of [s] without ANSI codes
    """

    return re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', s)
146
+
147
+
148
+ #%% Tests
149
+
150
+
151
class TestStringUtils:
    """
    Tests for string_utils.py
    """


    def test_is_float(self):
        """
        Test the is_float function.
        """

        assert is_float("1.23")
        assert is_float("-0.5")
        assert is_float("0")
        assert is_float(1.23)
        assert is_float(0)
        assert not is_float("abc")
        assert not is_float("1.2.3")
        assert not is_float("")
        assert not is_float(None)
        assert not is_float("1,23")


    def test_is_int(self):
        """
        Test the is_int function.
        """

        assert is_int("12")
        assert is_int("-3")
        assert is_int(7)
        assert not is_int("1.5")
        assert not is_int("abc")
        assert not is_int("")
        assert not is_int(None)


    def test_human_readable_to_bytes(self):
        """
        Test the human_readable_to_bytes function.
        """

        assert human_readable_to_bytes("10B") == 10
        assert human_readable_to_bytes("10") == 10
        assert human_readable_to_bytes("1K") == 1024
        assert human_readable_to_bytes("1KB") == 1024
        assert human_readable_to_bytes("1M") == 1024*1024
        assert human_readable_to_bytes("1MB") == 1024*1024
        assert human_readable_to_bytes("1G") == 1024*1024*1024
        assert human_readable_to_bytes("1GB") == 1024*1024*1024
        assert human_readable_to_bytes("1T") == 1024*1024*1024*1024
        assert human_readable_to_bytes("1TB") == 1024*1024*1024*1024

        assert human_readable_to_bytes("2.5K") == 2.5 * 1024
        assert human_readable_to_bytes("0.5MB") == 0.5 * 1024 * 1024

        # Test with spaces
        assert human_readable_to_bytes(" 2 G ") == 2 * 1024*1024*1024
        assert human_readable_to_bytes("500 KB") == 500 * 1024

        # Invalid inputs
        assert human_readable_to_bytes("abc") == 0
        assert human_readable_to_bytes("1X") == 0
        assert human_readable_to_bytes("1KBB") == 0
        assert human_readable_to_bytes("K1") == 0
        assert human_readable_to_bytes("") == 0
        assert human_readable_to_bytes("1.2.3K") == 0
        assert human_readable_to_bytes("B") == 0


    def test_remove_ansi_codes(self):
        """
        Test the remove_ansi_codes function.
        """

        assert remove_ansi_codes("text without codes") == "text without codes"
        assert remove_ansi_codes("\x1b[31mRed text\x1b[0m") == "Red text"
        assert remove_ansi_codes("\x1b[1m\x1b[4mBold and Underline\x1b[0m") == "Bold and Underline"
        assert remove_ansi_codes("Mixed \x1b[32mgreen\x1b[0m and normal") == "Mixed green and normal"
        assert remove_ansi_codes("") == ""

        # More complex/varied ANSI codes
        assert remove_ansi_codes("text\x1b[1Aup") == "textup"
        assert remove_ansi_codes("\x1b[2Jclearscreen") == "clearscreen"
221
+
222
+
223
def test_string_utils():
    """
    Runs the is_float, human_readable_to_bytes, and remove_ansi_codes tests from
    TestStringUtils, in that order.
    """

    suite = TestStringUtils()
    for run_test in (suite.test_is_float,
                     suite.test_human_readable_to_bytes,
                     suite.test_remove_ansi_codes):
        run_test()
232
+
233
+ # from IPython import embed; embed()
234
+ # test_string_utils()