megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of megadetector might be problematic.

Files changed (147)
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
megadetector/classification/create_classification_dataset.py
@@ -0,0 +1,626 @@
"""

create_classification_dataset.py

Creates a classification dataset CSV with a corresponding JSON file determining
the train/val/test split.

This script takes as input a "queried images" JSON file whose keys are paths to
images and values are dictionaries containing information relevant for training
a classifier, including labels and (optionally) ground-truth bounding boxes.
The image paths are in the format `<dataset-name>/<blob-name>`, where we assume
that the dataset name does not contain '/'.

{
    "caltech/cct_images/59f79901-23d2-11e8-a6a3-ec086b02610b.jpg": {
        "dataset": "caltech",
        "location": 13,
        "class": "mountain_lion",  # class from dataset
        "bbox": [{"category": "animal",
                  "bbox": [0, 0.347, 0.237, 0.257]}],  # ground-truth bbox
        "label": ["mountain_lion"]  # labels to use in classifier
    },
    "caltech/cct_images/59f5fe2b-23d2-11e8-a6a3-ec086b02610b.jpg": {
        "dataset": "caltech",
        "location": 13,
        "class": "mountain_lion",  # class from dataset
        "label": ["mountain_lion"]  # labels to use in classifier
    },
    ...
}

We assume that the tuple (dataset, location) identifies a unique location. In
other words, we assume that no two datasets have overlapping locations. This
probably isn't 100% true, but it's pretty much the best we can do in terms of
avoiding overlapping locations between the train/val/test splits.

This script outputs 3 files to <output_dir>:

1) classification_ds.csv, which contains columns:

- 'path': str, path to the cropped image
- 'dataset': str, name of dataset
- 'location': str, location where the image was taken, as saved in MegaDB
- 'dataset_class': str, original class assigned to the image, as saved in
  MegaDB
- 'confidence': float, confidence that this crop is of an actual animal,
  1.0 if the crop is a "ground truth bounding box" (i.e., from MegaDB),
  <= 1.0 if the bounding box was detected by MegaDetector
- 'label': str, comma-separated list of label(s) assigned to this crop for
  the sake of classification

2) label_index.json: maps integers to label names

- keys are string representations of Python integers (JSON requires keys to
  be strings), numbered from 0 to num_labels-1
- values are strings, label names

3) splits.json: serialization of a Python dict that maps each split in
   ['train', 'val', 'test'] to a list of length-2 lists, where each inner
   list is [<dataset>, <location>]

"""

#%% Imports and constants

from __future__ import annotations

import argparse
from collections.abc import Container, MutableMapping
import json
import os
from typing import Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

from megadetector.classification import detect_and_crop
from megadetector.utils import ct_utils


#%% Example usage

"""
python create_classification_dataset.py \
    run_idfg2 \
    --queried-images-json run_idfg2/queried_images.json \
    --cropped-images-dir /ssd/crops_sq \
    -d $HOME/classifier-training/mdcache -v "4.1" -t 0.8
"""


DATASET_FILENAME = 'classification_ds.csv'
LABEL_INDEX_FILENAME = 'label_index.json'
SPLITS_FILENAME = 'splits.json'


#%% Main function

def main(output_dir: str,
         mode: list[str],
         match_test: Optional[list[str]],
         queried_images_json_path: Optional[str],
         cropped_images_dir: Optional[str],
         detector_version: Optional[str],
         detector_output_cache_base_dir: Optional[str],
         confidence_threshold: Optional[float],
         min_locs: Optional[int],
         val_frac: Optional[float],
         test_frac: Optional[float],
         splits_method: Optional[str],
         label_spec_json_path: Optional[str]) -> None:

    # input validation
    assert set(mode) <= {'csv', 'splits'}
    if label_spec_json_path is not None:
        assert splits_method == 'smallest_first'

    test_set_locs = None  # set of (dataset, location) tuples
    test_set_df = None
    if match_test is not None:
        match_test_csv_path, match_test_splits_path = match_test
        match_df = pd.read_csv(match_test_csv_path, index_col=False,
                               float_precision='high')
        with open(match_test_splits_path, 'r') as f:
            match_splits = json.load(f)
        test_set_locs = set((loc[0], loc[1]) for loc in match_splits['test'])
        ds_locs = pd.Series(zip(match_df['dataset'], match_df['location']))
        test_set_df = match_df[ds_locs.isin(test_set_locs)]

    dataset_path = os.path.join(output_dir, DATASET_FILENAME)

    if 'csv' in mode:
        assert queried_images_json_path is not None
        assert cropped_images_dir is not None
        assert detector_version is not None
        assert detector_output_cache_base_dir is not None
        assert confidence_threshold is not None

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f'Created {output_dir}')

        df, log = create_classification_csv(
            queried_images_json_path=queried_images_json_path,
            detector_output_cache_base_dir=detector_output_cache_base_dir,
            detector_version=detector_version,
            cropped_images_dir=cropped_images_dir,
            confidence_threshold=confidence_threshold,
            min_locs=min_locs,
            append_df=test_set_df,
            exclude_locs=test_set_locs)
        print('Saving classification dataset CSV')
        df.to_csv(dataset_path, index=False)
        for msg, img_list in log.items():
            print(f'{msg}:', len(img_list))

        # create label index JSON
        labels = df['label']
        if any(labels.str.contains(',')):
            print('multi-label!')
            labels = labels.map(lambda x: x.split(',')).explode()
            # look into sklearn.preprocessing.MultiLabelBinarizer
        label_names = sorted(labels.unique())
        # Note: JSON always saves keys as strings!
        ct_utils.write_json(os.path.join(output_dir, LABEL_INDEX_FILENAME),
                            dict(enumerate(label_names)))

    if 'splits' in mode:
        assert splits_method is not None
        assert val_frac is not None
        assert (match_test is None) != (test_frac is None)
        if test_frac is None:
            test_frac = 0

        print(f'Creating splits using "{splits_method}" method...')
        df = pd.read_csv(dataset_path, index_col=False, float_precision='high')

        if splits_method == 'random':
            split_to_locs = create_splits_random(
                df, val_frac, test_frac, test_split=test_set_locs)
        else:
            split_to_locs = create_splits_smallest_label_first(
                df, val_frac, test_frac, test_split=test_set_locs,
                label_spec_json_path=label_spec_json_path)
        ct_utils.write_json(os.path.join(output_dir, SPLITS_FILENAME),
                            split_to_locs)


#%% Support functions

def create_classification_csv(
        queried_images_json_path: str,
        detector_output_cache_base_dir: str,
        detector_version: str,
        cropped_images_dir: str,
        confidence_threshold: float,
        min_locs: Optional[int] = None,
        append_df: Optional[pd.DataFrame] = None,
        exclude_locs: Optional[Container[tuple[str, str]]] = None
        ) -> tuple[pd.DataFrame, dict[str, list]]:
    """
    Creates a classification dataset.

    The classification dataset is a pd.DataFrame with columns:
    - path: str, <dataset>/<crop-filename>
    - dataset: str, name of camera trap dataset
    - location: str, location of image, provided by the camera trap dataset
    - dataset_class: image class, as provided by the camera trap dataset
    - confidence: float, confidence of bounding box, 1 if ground truth
    - label: str, comma-separated list of classification labels

    Args:
        queried_images_json_path: str, path to output of json_validator.py
        detector_output_cache_base_dir: str, path to local directory
            where detector outputs are cached, 1 JSON file per dataset
        detector_version: str, detector version string, e.g., '4.1',
            see {batch_detection_api_url}/supported_model_versions;
            determines the subfolder of detector_output_cache_base_dir in
            which to find and save detector outputs
        cropped_images_dir: str, path to local directory for saving crops of
            bounding boxes
        confidence_threshold: float, only crop bounding boxes above this value
        min_locs: optional int, minimum # of locations that each label must
            have in order to be included
        append_df: optional pd.DataFrame, existing DataFrame that is appended
            to the classification CSV
        exclude_locs: optional set of (dataset, location) tuples, crops from
            these locations are excluded (does not affect append_df)

    Returns:
        df: pd.DataFrame, the classification dataset
        log: dict, with the following keys
            'images missing detections': list of str, images without ground
                truth bboxes and not in detection cache
            'images without confident detections': list of str, images in
                detection cache with no bboxes above the confidence threshold
            'missing crops': list of tuple (img_path, i), where i is the
                i-th crop index
    """

    assert 0 <= confidence_threshold <= 1

    columns = [
        'path', 'dataset', 'location', 'dataset_class', 'confidence', 'label']
    if append_df is not None:
        assert (append_df.columns == columns).all()

    with open(queried_images_json_path, 'r') as f:
        js = json.load(f)

    print('loading detection cache...', end='')
    detector_output_cache_dir = os.path.join(
        detector_output_cache_base_dir, f'v{detector_version}')
    datasets = set(img_path[:img_path.find('/')] for img_path in js)
    detection_cache, cat_id_to_name = detect_and_crop.load_detection_cache(
        detector_output_cache_dir=detector_output_cache_dir, datasets=datasets)
    print('done!')
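
    # detection_cache maps dataset name -> image filename -> cached detector
    # output for that image, and cat_id_to_name maps detector category IDs to
    # names. A sketch of the assumed structure, with hypothetical values:
    #
    # detection_cache = {'caltech': {'cct_images/xyz.jpg': {'detections': [
    #     {'category': '1', 'conf': 0.94, 'bbox': [0.0, 0.35, 0.24, 0.26]}]}}}
    # cat_id_to_name = {'1': 'animal', '2': 'person', '3': 'vehicle'}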

    missing_detections = []  # no cached detections or ground truth bboxes
    images_no_confident_detections = []  # cached detections contain 0 bboxes
    images_missing_crop = []  # tuples: (img_path, crop_index)
    all_rows = []

    # True for ground truth, False for MegaDetector
    # always save as .jpg for consistency
    crop_path_template = {
        True: '{img_path}___crop{n:>02d}.jpg',
        False: '{img_path}___crop{n:>02d}_' + f'mdv{detector_version}.jpg'
    }
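
    # e.g., for a hypothetical image 'caltech/cct_images/xyz.jpg' and
    # detector_version '4.1', crop 0 is saved as
    # 'caltech/cct_images/xyz.jpg___crop00.jpg' (ground truth) or
    # 'caltech/cct_images/xyz.jpg___crop00_mdv4.1.jpg' (MegaDetector)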

    for img_path, img_info in tqdm(js.items()):
        ds, img_file = img_path.split('/', maxsplit=1)

        # get bounding boxes
        if 'bbox' in img_info:  # ground-truth bounding boxes
            bbox_dicts = img_info['bbox']
            is_ground_truth = True
        else:  # get bounding boxes from detector cache
            if img_file in detection_cache[ds]:
                bbox_dicts = detection_cache[ds][img_file]['detections']
                # convert from category ID to category name
                for d in bbox_dicts:
                    d['category'] = cat_id_to_name[d['category']]
            else:
                missing_detections.append(img_path)
                continue
            is_ground_truth = False

        # check if crops are already downloaded, and ignore bboxes below the
        # confidence threshold
        rows = []
        for i, bbox_dict in enumerate(bbox_dicts):
            conf = 1.0 if is_ground_truth else bbox_dict['conf']
            if conf < confidence_threshold:
                continue
            if bbox_dict['category'] != 'animal':
                tqdm.write(f'Bbox {i} of {img_path} is non-animal. Skipping.')
                continue
            crop_path = crop_path_template[is_ground_truth].format(
                img_path=img_path, n=i)
            full_crop_path = os.path.join(cropped_images_dir, crop_path)
            if not os.path.exists(full_crop_path):
                images_missing_crop.append((img_path, i))
            else:
                # assign all images without location info to 'unknown_location'
                img_loc = img_info.get('location', 'unknown_location')
                row = [crop_path, ds, img_loc, img_info['class'],
                       conf, ','.join(img_info['label'])]
                rows.append(row)
        if len(rows) == 0:
            images_no_confident_detections.append(img_path)
            continue
        all_rows += rows

    df = pd.DataFrame(data=all_rows, columns=columns)

    # remove images from labels that have fewer than min_locs locations
    if min_locs is not None:
        nlocs_per_label = df.groupby('label').apply(
            lambda xdf: len(xdf[['dataset', 'location']].drop_duplicates()))
        valid_labels_mask = (nlocs_per_label >= min_locs)
        valid_labels = nlocs_per_label.index[valid_labels_mask]
        invalid_labels = nlocs_per_label.index[~valid_labels_mask]
        orig = len(df)
        df = df[df['label'].isin(valid_labels)]
        print(f'Excluding {orig - len(df)} crops from {len(invalid_labels)} '
              'labels:', invalid_labels.tolist())
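
    # e.g., with a hypothetical min_locs=2, a label that appears at only one
    # (dataset, location) pair is dropped from the CSV entirely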

    if exclude_locs is not None:
        mask = ~pd.Series(zip(df['dataset'], df['location'])).isin(exclude_locs)
        print(f'Excluding {(~mask).sum()} crops from CSV')
        df = df[mask]
    if append_df is not None:
        print(f'Appending {len(append_df)} rows to CSV')
        # pd.concat, since DataFrame.append() was removed in pandas 2.0
        df = pd.concat([df, append_df])

    log = {
        'images missing detections': missing_detections,
        'images without confident detections': images_no_confident_detections,
        'missing crops': images_missing_crop
    }
    return df, log


def create_splits_random(df: pd.DataFrame, val_frac: float,
                         test_frac: float = 0.,
                         test_split: Optional[set[tuple[str, str]]] = None,
                         ) -> dict[str, list[tuple[str, str]]]:
    """
    Randomly splits locations into train/val/test sets.

    Args:
        df: pd.DataFrame, contains columns ['dataset', 'location', 'label'];
            each row is a single image;
            assumes each image is assigned exactly 1 label
        val_frac: float, desired fraction of dataset to use for val set
        test_frac: float, desired fraction of dataset to use for test set,
            must be 0 if test_split is given
        test_split: optional set of (dataset, location) tuples to use as the
            test split

    Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
        where each loc is a tuple (dataset, location)
    """

    if test_split is not None:
        assert test_frac == 0
    train_frac = 1. - val_frac - test_frac
    targets = {'train': train_frac, 'val': val_frac, 'test': test_frac}

    # merge dataset and location into a single string '<dataset>/<location>'
    df['dataset_location'] = df['dataset'] + '/' + df['location']

    # create DataFrame of counts. rows = locations, columns = labels
    loc_label_counts = (df.groupby(['label', 'dataset_location']).size()
                        .unstack('label', fill_value=0))
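
    # loc_label_counts might look like this (hypothetical values):
    #
    # label             bird  mountain_lion
    # dataset_location
    # caltech/13           4             10
    # idfg/loc_a           0              7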
    num_locs = len(loc_label_counts)

    # label_count: label => number of examples
    # loc_count: label => number of locs containing that label
    label_count = loc_label_counts.sum()
    loc_count = (loc_label_counts > 0).sum()

    best_score = np.inf  # lower is better
    best_splits = None
    for _ in tqdm(range(10_000)):

        # generate a new split
        num_train = int(num_locs * (train_frac + np.random.uniform(-.03, .03)))
        if test_frac > 0:
            num_val = int(num_locs * (val_frac + np.random.uniform(-.03, .03)))
        else:
            num_val = num_locs - num_train
        permuted_locs = loc_label_counts.index[np.random.permutation(num_locs)]
        split_to_locs = {'train': permuted_locs[:num_train],
                         'val': permuted_locs[num_train:num_train + num_val]}
        if test_frac > 0:
            split_to_locs['test'] = permuted_locs[num_train + num_val:]

        # score the split
        score = 0.
        for split, locs in split_to_locs.items():
            split_df = loc_label_counts.loc[locs]
            target = targets[split]

            # SSE for # of images per label (with 2x weight)
            crop_frac = split_df.sum() / label_count
            score += 2 * ((crop_frac - target) ** 2).sum()

            # SSE for # of locs per label
            loc_frac = (split_df > 0).sum() / loc_count
            score += ((loc_frac - target) ** 2).sum()

        if score < best_score:
            tqdm.write(f'New lowest score: {score}')
            best_score = score
            best_splits = split_to_locs

    assert best_splits is not None
    split_to_locs = {
        s: sorted(locs.map(lambda x: tuple(x.split('/', maxsplit=1))))
        for s, locs in best_splits.items()
    }
    if test_split is not None:
        # use a sorted list, so the result stays JSON-serializable
        split_to_locs['test'] = sorted(test_split)
    return split_to_locs
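
# Minimal usage sketch for create_splits_random (hypothetical data):
#
#   toy = pd.DataFrame({'dataset':  ['a', 'a', 'b', 'b'],
#                       'location': ['1', '2', '1', '2'],
#                       'label':    ['cat', 'dog', 'cat', 'dog']})
#   splits = create_splits_random(toy, val_frac=0.25, test_frac=0.25)
#   # -> {'train': [...], 'val': [...], 'test': [...]}, where each value is a
#   #    sorted list of (dataset, location) tuples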


def create_splits_smallest_label_first(
        df: pd.DataFrame,
        val_frac: float,
        test_frac: float = 0.,
        label_spec_json_path: Optional[str] = None,
        test_split: Optional[set[tuple[str, str]]] = None,
        ) -> dict[str, list[tuple[str, str]]]:
    """
    Greedily splits locations into train/val/test sets, starting with the
    smallest label first.

    Args:
        df: pd.DataFrame, contains columns ['dataset', 'location', 'label'];
            each row is a single image;
            assumes each image is assigned exactly 1 label
        val_frac: float, desired fraction of dataset to use for val set
        test_frac: float, desired fraction of dataset to use for test set,
            must be 0 if test_split is given
        label_spec_json_path: optional str, path to label spec JSON
        test_split: optional set of (dataset, location) tuples to use as the
            test split

    Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
        where each loc is a tuple (dataset, location)
    """

    # label => list of datasets to prioritize for test and validation sets
    prioritize = {}
    if label_spec_json_path is not None:
        with open(label_spec_json_path, 'r') as f:
            label_spec_js = json.load(f)
        for label, label_spec in label_spec_js.items():
            if 'prioritize' in label_spec:
                datasets = []
                for level in label_spec['prioritize']:
                    datasets += level
                prioritize[label] = datasets

    # merge dataset and location into a tuple (dataset, location)
    df['dataset_location'] = list(zip(df['dataset'], df['location']))
    loc_to_label_sizes = df.groupby(['dataset_location', 'label']).size()

    seen_locs = set()
    split_to_locs: dict[str, list[tuple[str, str]]] = dict(
        train=[], val=[], test=[])
    label_sizes_by_split = {
        label: dict(train=0, val=0, test=0)
        for label in df['label'].unique()
    }
    if test_split is not None:
        assert test_frac == 0
        split_to_locs['test'] = list(test_split)
        seen_locs.update(test_split)

    def add_loc_to_split(loc: tuple[str, str], split: str) -> None:
        split_to_locs[split].append(loc)
        for label, label_size in loc_to_label_sizes[loc].items():
            label_sizes_by_split[label][split] += label_size

    # sorted smallest to largest
    ordered_labels = df.groupby('label').size().sort_values()
    for label, label_size in tqdm(ordered_labels.items()):

        split_sizes = label_sizes_by_split[label]
        test_thresh = test_frac * label_size
        val_thresh = val_frac * label_size

        mask = df['label'] == label
        ordered_locs = sort_locs_by_size(
            loc_to_size=df[mask].groupby('dataset_location').size().to_dict(),
            prioritize=prioritize.get(label, None))
        ordered_locs = [loc for loc in ordered_locs if loc not in seen_locs]

        for loc in ordered_locs:
            seen_locs.add(loc)
            # greedily add locations to the test set until it has at least
            # test_frac of this label's images, then fill the val set, then
            # send everything else to train
            if split_sizes['test'] < test_thresh:
                split = 'test'
            elif split_sizes['val'] < val_thresh:
                split = 'val'
            else:
                split = 'train'
            add_loc_to_split(loc, split)
        seen_locs.update(ordered_locs)

    # sort the resulting locs
    split_to_locs = {s: sorted(locs) for s, locs in split_to_locs.items()}
    return split_to_locs
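
# Sketch of the greedy behavior above (hypothetical numbers): for a label with
# 100 crops spread over locations of sizes [10, 20, 30, 40], test_frac=0.15,
# and val_frac=0.15, the thresholds are 15 crops each. Visiting locations
# smallest-first, the size-10 location goes to test (0 < 15), the size-20
# location also goes to test (10 < 15 at assignment time), the size-30
# location goes to val, and the size-40 location goes to train.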


def sort_locs_by_size(loc_to_size: MutableMapping[tuple[str, str], int],
                      prioritize: Optional[Container[str]] = None
                      ) -> list[tuple[str, str]]:
    """
    Sorts locations by size, optionally prioritizing locations from certain
    datasets first.

    Args:
        loc_to_size: dict, maps each (dataset, location) tuple to its size,
            modified in-place
        prioritize: optional list of str, datasets to prioritize

    Returns: list of (dataset, location) tuples, ordered from smallest size to
        largest. Locations from prioritized datasets come first.
    """

    result = []
    if prioritize is not None:
        # we modify loc_to_size in place, so copy its keys before iterating
        prioritized_loc_to_size = {
            loc: loc_to_size.pop(loc) for loc in list(loc_to_size.keys())
            if loc[0] in prioritize
        }
        result = sort_locs_by_size(prioritized_loc_to_size)

    result += sorted(loc_to_size, key=loc_to_size.__getitem__)
    return result
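
# For example (hypothetical sizes):
#
#   sort_locs_by_size({('a', '1'): 5, ('b', '2'): 3, ('a', '3'): 1},
#                     prioritize=['b'])
#   -> [('b', '2'), ('a', '3'), ('a', '1')]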


#%% Command-line driver

def _parse_args() -> argparse.Namespace:

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Creates classification dataset.')

    # arguments relevant to both creating the dataset CSV and splits.json
    parser.add_argument(
        'output_dir',
        help='path to directory where the 3 output files should be '
             'saved: 1) dataset CSV, 2) label index JSON, 3) splits JSON')
    parser.add_argument(
        '--mode', nargs='+', choices=['csv', 'splits'],
        default=['csv', 'splits'],
        help='whether to generate only a CSV, only a splits.json file (based '
             'on an existing classification_ds.csv), or both')
    parser.add_argument(
        '--match-test', nargs=2, metavar=('CLASSIFICATION_CSV', 'SPLITS_JSON'),
        help='path to an existing classification CSV and path to an existing '
             'splits JSON file from which to match the test set')

    # arguments only relevant for creating the dataset CSV
    csv_group = parser.add_argument_group(
        'arguments for creating classification CSV')
    csv_group.add_argument(
        '-q', '--queried-images-json',
        help='path to JSON file containing image paths and classification info')
    csv_group.add_argument(
        '-c', '--cropped-images-dir',
        help='path to local directory for saving crops of bounding boxes')
    csv_group.add_argument(
        '-d', '--detector-output-cache-dir',
        help='(required) path to directory where detector outputs are cached')
    csv_group.add_argument(
        '-v', '--detector-version',
        help='(required) detector version string, e.g., "4.1"')
    csv_group.add_argument(
        '-t', '--threshold', type=float, default=0.8,
        help='confidence threshold above which to crop bounding boxes')
    csv_group.add_argument(
        '--min-locs', type=int,
        help='minimum number of locations that each label must have in order '
             'to be included (does not apply to match-test splits)')

    # arguments only relevant for creating the splits JSON
    splits_group = parser.add_argument_group(
        'arguments for creating train/val/test splits')
    splits_group.add_argument(
        '--val-frac', type=float,
        help='(required) fraction of data to use for validation split')
    splits_group.add_argument(
        '--test-frac', type=float,
        help='fraction of data to use for test split, must be provided if '
             '--match-test is not given')
    splits_group.add_argument(
        '--method', choices=['random', 'smallest_first'], default='random',
        help='"random": randomly tries up to 10,000 different train/val/test '
             'splits and chooses the one that best meets the scoring criteria, '
             'does not support --label-spec. '
             '"smallest_first": greedily divides locations into splits, '
             'starting with the smallest class first. Supports --label-spec.')
    splits_group.add_argument(
        '--label-spec',
        help='optional path to label specification JSON file, if specifying '
             'dataset priority. Requires --method=smallest_first.')
    return parser.parse_args()


if __name__ == '__main__':

    args = _parse_args()
    main(output_dir=args.output_dir,
         mode=args.mode,
         match_test=args.match_test,
         queried_images_json_path=args.queried_images_json,
         cropped_images_dir=args.cropped_images_dir,
         detector_version=args.detector_version,
         detector_output_cache_base_dir=args.detector_output_cache_dir,
         confidence_threshold=args.threshold,
         min_locs=args.min_locs,
         val_frac=args.val_frac,
         test_frac=args.test_frac,
         splits_method=args.method,
         label_spec_json_path=args.label_spec)