megadetector 5.0.11-py3-none-any.whl → 5.0.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (201)
  1. megadetector/api/__init__.py +0 -0
  2. megadetector/api/batch_processing/__init__.py +0 -0
  3. megadetector/api/batch_processing/api_core/__init__.py +0 -0
  4. megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
  5. megadetector/api/batch_processing/api_core/batch_service/score.py +439 -0
  6. megadetector/api/batch_processing/api_core/server.py +294 -0
  7. megadetector/api/batch_processing/api_core/server_api_config.py +98 -0
  8. megadetector/api/batch_processing/api_core/server_app_config.py +55 -0
  9. megadetector/api/batch_processing/api_core/server_batch_job_manager.py +220 -0
  10. megadetector/api/batch_processing/api_core/server_job_status_table.py +152 -0
  11. megadetector/api/batch_processing/api_core/server_orchestration.py +360 -0
  12. megadetector/api/batch_processing/api_core/server_utils.py +92 -0
  13. megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
  14. megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +46 -0
  15. megadetector/api/batch_processing/api_support/__init__.py +0 -0
  16. megadetector/api/batch_processing/api_support/summarize_daily_activity.py +152 -0
  17. megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
  18. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  19. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  20. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  21. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +126 -0
  22. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  23. megadetector/api/synchronous/__init__.py +0 -0
  24. megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  25. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +152 -0
  26. megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +266 -0
  27. megadetector/api/synchronous/api_core/animal_detection_api/config.py +35 -0
  28. megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
  29. megadetector/api/synchronous/api_core/tests/load_test.py +110 -0
  30. megadetector/classification/__init__.py +0 -0
  31. megadetector/classification/aggregate_classifier_probs.py +108 -0
  32. megadetector/classification/analyze_failed_images.py +227 -0
  33. megadetector/classification/cache_batchapi_outputs.py +198 -0
  34. megadetector/classification/create_classification_dataset.py +627 -0
  35. megadetector/classification/crop_detections.py +516 -0
  36. megadetector/classification/csv_to_json.py +226 -0
  37. megadetector/classification/detect_and_crop.py +855 -0
  38. megadetector/classification/efficientnet/__init__.py +9 -0
  39. megadetector/classification/efficientnet/model.py +415 -0
  40. megadetector/classification/efficientnet/utils.py +610 -0
  41. megadetector/classification/evaluate_model.py +520 -0
  42. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  43. megadetector/classification/json_to_azcopy_list.py +63 -0
  44. megadetector/classification/json_validator.py +699 -0
  45. megadetector/classification/map_classification_categories.py +276 -0
  46. megadetector/classification/merge_classification_detection_output.py +506 -0
  47. megadetector/classification/prepare_classification_script.py +194 -0
  48. megadetector/classification/prepare_classification_script_mc.py +228 -0
  49. megadetector/classification/run_classifier.py +287 -0
  50. megadetector/classification/save_mislabeled.py +110 -0
  51. megadetector/classification/train_classifier.py +827 -0
  52. megadetector/classification/train_classifier_tf.py +725 -0
  53. megadetector/classification/train_utils.py +323 -0
  54. megadetector/data_management/__init__.py +0 -0
  55. megadetector/data_management/annotations/__init__.py +0 -0
  56. megadetector/data_management/annotations/annotation_constants.py +34 -0
  57. megadetector/data_management/camtrap_dp_to_coco.py +239 -0
  58. megadetector/data_management/cct_json_utils.py +395 -0
  59. megadetector/data_management/cct_to_md.py +176 -0
  60. megadetector/data_management/cct_to_wi.py +289 -0
  61. megadetector/data_management/coco_to_labelme.py +272 -0
  62. megadetector/data_management/coco_to_yolo.py +662 -0
  63. megadetector/data_management/databases/__init__.py +0 -0
  64. megadetector/data_management/databases/add_width_and_height_to_db.py +33 -0
  65. megadetector/data_management/databases/combine_coco_camera_traps_files.py +206 -0
  66. megadetector/data_management/databases/integrity_check_json_db.py +477 -0
  67. megadetector/data_management/databases/subset_json_db.py +115 -0
  68. megadetector/data_management/generate_crops_from_cct.py +149 -0
  69. megadetector/data_management/get_image_sizes.py +189 -0
  70. megadetector/data_management/importers/add_nacti_sizes.py +52 -0
  71. megadetector/data_management/importers/add_timestamps_to_icct.py +79 -0
  72. megadetector/data_management/importers/animl_results_to_md_results.py +158 -0
  73. megadetector/data_management/importers/auckland_doc_test_to_json.py +373 -0
  74. megadetector/data_management/importers/auckland_doc_to_json.py +201 -0
  75. megadetector/data_management/importers/awc_to_json.py +191 -0
  76. megadetector/data_management/importers/bellevue_to_json.py +273 -0
  77. megadetector/data_management/importers/cacophony-thermal-importer.py +796 -0
  78. megadetector/data_management/importers/carrizo_shrubfree_2018.py +269 -0
  79. megadetector/data_management/importers/carrizo_trail_cam_2017.py +289 -0
  80. megadetector/data_management/importers/cct_field_adjustments.py +58 -0
  81. megadetector/data_management/importers/channel_islands_to_cct.py +913 -0
  82. megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +180 -0
  83. megadetector/data_management/importers/eMammal/eMammal_helpers.py +249 -0
  84. megadetector/data_management/importers/eMammal/make_eMammal_json.py +223 -0
  85. megadetector/data_management/importers/ena24_to_json.py +276 -0
  86. megadetector/data_management/importers/filenames_to_json.py +386 -0
  87. megadetector/data_management/importers/helena_to_cct.py +283 -0
  88. megadetector/data_management/importers/idaho-camera-traps.py +1407 -0
  89. megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +294 -0
  90. megadetector/data_management/importers/jb_csv_to_json.py +150 -0
  91. megadetector/data_management/importers/mcgill_to_json.py +250 -0
  92. megadetector/data_management/importers/missouri_to_json.py +490 -0
  93. megadetector/data_management/importers/nacti_fieldname_adjustments.py +79 -0
  94. megadetector/data_management/importers/noaa_seals_2019.py +181 -0
  95. megadetector/data_management/importers/pc_to_json.py +365 -0
  96. megadetector/data_management/importers/plot_wni_giraffes.py +123 -0
  97. megadetector/data_management/importers/prepare-noaa-fish-data-for-lila.py +359 -0
  98. megadetector/data_management/importers/prepare_zsl_imerit.py +131 -0
  99. megadetector/data_management/importers/rspb_to_json.py +356 -0
  100. megadetector/data_management/importers/save_the_elephants_survey_A.py +320 -0
  101. megadetector/data_management/importers/save_the_elephants_survey_B.py +329 -0
  102. megadetector/data_management/importers/snapshot_safari_importer.py +758 -0
  103. megadetector/data_management/importers/snapshot_safari_importer_reprise.py +665 -0
  104. megadetector/data_management/importers/snapshot_serengeti_lila.py +1067 -0
  105. megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +150 -0
  106. megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +153 -0
  107. megadetector/data_management/importers/sulross_get_exif.py +65 -0
  108. megadetector/data_management/importers/timelapse_csv_set_to_json.py +490 -0
  109. megadetector/data_management/importers/ubc_to_json.py +399 -0
  110. megadetector/data_management/importers/umn_to_json.py +507 -0
  111. megadetector/data_management/importers/wellington_to_json.py +263 -0
  112. megadetector/data_management/importers/wi_to_json.py +442 -0
  113. megadetector/data_management/importers/zamba_results_to_md_results.py +181 -0
  114. megadetector/data_management/labelme_to_coco.py +547 -0
  115. megadetector/data_management/labelme_to_yolo.py +272 -0
  116. megadetector/data_management/lila/__init__.py +0 -0
  117. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +97 -0
  118. megadetector/data_management/lila/add_locations_to_nacti.py +147 -0
  119. megadetector/data_management/lila/create_lila_blank_set.py +558 -0
  120. megadetector/data_management/lila/create_lila_test_set.py +152 -0
  121. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  122. megadetector/data_management/lila/download_lila_subset.py +178 -0
  123. megadetector/data_management/lila/generate_lila_per_image_labels.py +516 -0
  124. megadetector/data_management/lila/get_lila_annotation_counts.py +170 -0
  125. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  126. megadetector/data_management/lila/lila_common.py +300 -0
  127. megadetector/data_management/lila/test_lila_metadata_urls.py +132 -0
  128. megadetector/data_management/ocr_tools.py +874 -0
  129. megadetector/data_management/read_exif.py +681 -0
  130. megadetector/data_management/remap_coco_categories.py +84 -0
  131. megadetector/data_management/remove_exif.py +66 -0
  132. megadetector/data_management/resize_coco_dataset.py +189 -0
  133. megadetector/data_management/wi_download_csv_to_coco.py +246 -0
  134. megadetector/data_management/yolo_output_to_md_output.py +441 -0
  135. megadetector/data_management/yolo_to_coco.py +676 -0
  136. megadetector/detection/__init__.py +0 -0
  137. megadetector/detection/detector_training/__init__.py +0 -0
  138. megadetector/detection/detector_training/model_main_tf2.py +114 -0
  139. megadetector/detection/process_video.py +702 -0
  140. megadetector/detection/pytorch_detector.py +341 -0
  141. megadetector/detection/run_detector.py +779 -0
  142. megadetector/detection/run_detector_batch.py +1219 -0
  143. megadetector/detection/run_inference_with_yolov5_val.py +917 -0
  144. megadetector/detection/run_tiled_inference.py +934 -0
  145. megadetector/detection/tf_detector.py +189 -0
  146. megadetector/detection/video_utils.py +606 -0
  147. megadetector/postprocessing/__init__.py +0 -0
  148. megadetector/postprocessing/add_max_conf.py +64 -0
  149. megadetector/postprocessing/categorize_detections_by_size.py +163 -0
  150. megadetector/postprocessing/combine_api_outputs.py +249 -0
  151. megadetector/postprocessing/compare_batch_results.py +958 -0
  152. megadetector/postprocessing/convert_output_format.py +396 -0
  153. megadetector/postprocessing/load_api_results.py +195 -0
  154. megadetector/postprocessing/md_to_coco.py +310 -0
  155. megadetector/postprocessing/md_to_labelme.py +330 -0
  156. megadetector/postprocessing/merge_detections.py +401 -0
  157. megadetector/postprocessing/postprocess_batch_results.py +1902 -0
  158. megadetector/postprocessing/remap_detection_categories.py +170 -0
  159. megadetector/postprocessing/render_detection_confusion_matrix.py +660 -0
  160. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +211 -0
  161. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +83 -0
  162. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1631 -0
  163. megadetector/postprocessing/separate_detections_into_folders.py +730 -0
  164. megadetector/postprocessing/subset_json_detector_output.py +696 -0
  165. megadetector/postprocessing/top_folders_to_bottom.py +223 -0
  166. megadetector/taxonomy_mapping/__init__.py +0 -0
  167. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  168. megadetector/taxonomy_mapping/map_new_lila_datasets.py +150 -0
  169. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +142 -0
  170. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +590 -0
  171. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  172. megadetector/taxonomy_mapping/simple_image_download.py +219 -0
  173. megadetector/taxonomy_mapping/species_lookup.py +834 -0
  174. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  175. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  176. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  177. megadetector/utils/__init__.py +0 -0
  178. megadetector/utils/azure_utils.py +178 -0
  179. megadetector/utils/ct_utils.py +612 -0
  180. megadetector/utils/directory_listing.py +246 -0
  181. megadetector/utils/md_tests.py +968 -0
  182. megadetector/utils/path_utils.py +1044 -0
  183. megadetector/utils/process_utils.py +157 -0
  184. megadetector/utils/sas_blob_utils.py +509 -0
  185. megadetector/utils/split_locations_into_train_val.py +228 -0
  186. megadetector/utils/string_utils.py +92 -0
  187. megadetector/utils/url_utils.py +323 -0
  188. megadetector/utils/write_html_image_list.py +225 -0
  189. megadetector/visualization/__init__.py +0 -0
  190. megadetector/visualization/plot_utils.py +293 -0
  191. megadetector/visualization/render_images_with_thumbnails.py +275 -0
  192. megadetector/visualization/visualization_utils.py +1536 -0
  193. megadetector/visualization/visualize_db.py +550 -0
  194. megadetector/visualization/visualize_detector_output.py +405 -0
  195. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/METADATA +1 -1
  196. megadetector-5.0.12.dist-info/RECORD +199 -0
  197. megadetector-5.0.12.dist-info/top_level.txt +1 -0
  198. megadetector-5.0.11.dist-info/RECORD +0 -5
  199. megadetector-5.0.11.dist-info/top_level.txt +0 -1
  200. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/LICENSE +0 -0
  201. {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/WHEEL +0 -0
@@ -0,0 +1,627 @@
+ """
+
+ create_classification_dataset.py
+
+ Creates a classification dataset CSV with a corresponding JSON file determining
+ the train/val/test split.
+
+ This script takes as input a "queried images" JSON file whose keys are paths to
+ images and values are dictionaries containing information relevant for training
+ a classifier, including labels and (optionally) ground-truth bounding boxes.
+ The image paths are in the format `<dataset-name>/<blob-name>`, where we assume
+ that the dataset name does not contain '/'.
+
+ {
+     "caltech/cct_images/59f79901-23d2-11e8-a6a3-ec086b02610b.jpg": {
+         "dataset": "caltech",
+         "location": 13,
+         "class": "mountain_lion",  # class from dataset
+         "bbox": [{"category": "animal",
+                   "bbox": [0, 0.347, 0.237, 0.257]}],  # ground-truth bbox
+         "label": ["mountain_lion"]  # labels to use in classifier
+     },
+     "caltech/cct_images/59f5fe2b-23d2-11e8-a6a3-ec086b02610b.jpg": {
+         "dataset": "caltech",
+         "location": 13,
+         "class": "mountain_lion",  # class from dataset
+         "label": ["mountain_lion"]  # labels to use in classifier
+     },
+     ...
+ }
+
+ We assume that the tuple (dataset, location) identifies a unique location. In
+ other words, we assume that no two datasets have overlapping locations. This
+ probably isn't 100% true, but it's pretty much the best we can do in terms of
+ avoiding overlapping locations between the train/val/test splits.
+
+ This script outputs 3 files to <output_dir>:
+
+ 1) classification_ds.csv, contains columns:
+
+     - 'path': str, path to cropped images
+     - 'dataset': str, name of dataset
+     - 'location': str, location that image was taken, as saved in MegaDB
+     - 'dataset_class': str, original class assigned to image, as saved in MegaDB
+     - 'confidence': float, confidence that this crop is of an actual animal,
+       1.0 if the crop is a "ground truth bounding box" (i.e., from MegaDB),
+       <= 1.0 if the bounding box was detected by MegaDetector
+     - 'label': str, comma-separated list of label(s) assigned to this crop for
+       the sake of classification
+
+ 2) label_index.json: maps integer to label name
+
+     - keys are string representations of Python integers (JSON requires keys to
+       be strings), numbered from 0 to num_labels-1
+     - values are strings, label names
+
+ 3) splits.json: serialization of a Python dict that maps each split
+     ['train', 'val', 'test'] to a list of length-2 lists, where each inner list
+     is [<dataset>, <location>]
+
+ """
+
+ #%% Imports and constants
+
+ from __future__ import annotations
+
+ import argparse
+ from collections.abc import Container, MutableMapping
+ import json
+ import os
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+
+ from megadetector.classification import detect_and_crop
+
+
+ #%% Example usage
+
+ """
+ python create_classification_dataset.py \
+     run_idfg2 \
+     --queried-images-json run_idfg2/queried_images.json \
+     --cropped-images-dir /ssd/crops_sq \
+     -d $HOME/classifier-training/mdcache -v "4.1" -t 0.8
+ """
+
+
+ DATASET_FILENAME = 'classification_ds.csv'
+ LABEL_INDEX_FILENAME = 'label_index.json'
+ SPLITS_FILENAME = 'splits.json'
+
+
+ #%% Main function
+
+ def main(output_dir: str,
+          mode: list[str],
+          match_test: Optional[list[str]],
+          queried_images_json_path: Optional[str],
+          cropped_images_dir: Optional[str],
+          detector_version: Optional[str],
+          detector_output_cache_base_dir: Optional[str],
+          confidence_threshold: Optional[float],
+          min_locs: Optional[int],
+          val_frac: Optional[float],
+          test_frac: Optional[float],
+          splits_method: Optional[str],
+          label_spec_json_path: Optional[str]) -> None:
+
+     # input validation
+     assert set(mode) <= {'csv', 'splits'}
+     if label_spec_json_path is not None:
+         assert splits_method == 'smallest_first'
+
+     test_set_locs = None  # set of (dataset, location) tuples
+     test_set_df = None
+     if match_test is not None:
+         match_test_csv_path, match_test_splits_path = match_test
+         match_df = pd.read_csv(match_test_csv_path, index_col=False,
+                                float_precision='high')
+         with open(match_test_splits_path, 'r') as f:
+             match_splits = json.load(f)
+         test_set_locs = set((loc[0], loc[1]) for loc in match_splits['test'])
+         ds_locs = pd.Series(zip(match_df['dataset'], match_df['location']))
+         test_set_df = match_df[ds_locs.isin(test_set_locs)]
+
+     dataset_path = os.path.join(output_dir, DATASET_FILENAME)
+
+     if 'csv' in mode:
+         assert queried_images_json_path is not None
+         assert cropped_images_dir is not None
+         assert detector_version is not None
+         assert detector_output_cache_base_dir is not None
+         assert confidence_threshold is not None
+
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+             print(f'Created {output_dir}')
+
+         df, log = create_classification_csv(
+             queried_images_json_path=queried_images_json_path,
+             detector_output_cache_base_dir=detector_output_cache_base_dir,
+             detector_version=detector_version,
+             cropped_images_dir=cropped_images_dir,
+             confidence_threshold=confidence_threshold,
+             min_locs=min_locs,
+             append_df=test_set_df,
+             exclude_locs=test_set_locs)
+         print('Saving classification dataset CSV')
+         df.to_csv(dataset_path, index=False)
+         for msg, img_list in log.items():
+             print(f'{msg}:', len(img_list))
+
+         # create label index JSON
+         labels = df['label']
+         if any(labels.str.contains(',')):
+             print('multi-label!')
+             labels = labels.map(lambda x: x.split(',')).explode()
+             # look into sklearn.preprocessing.MultiLabelBinarizer
+         label_names = sorted(labels.unique())
+         with open(os.path.join(output_dir, LABEL_INDEX_FILENAME), 'w') as f:
+             # Note: JSON always saves keys as strings!
+             json.dump(dict(enumerate(label_names)), f, indent=1)
+
+     if 'splits' in mode:
+         assert splits_method is not None
+         assert val_frac is not None
+         assert (match_test is None) != (test_frac is None)
+         if test_frac is None:
+             test_frac = 0
+
+         print(f'Creating splits using "{splits_method}" method...')
+         df = pd.read_csv(dataset_path, index_col=False, float_precision='high')
+
+         if splits_method == 'random':
+             split_to_locs = create_splits_random(
+                 df, val_frac, test_frac, test_split=test_set_locs)
+         else:
+             split_to_locs = create_splits_smallest_label_first(
+                 df, val_frac, test_frac, test_split=test_set_locs,
+                 label_spec_json_path=label_spec_json_path)
+         with open(os.path.join(output_dir, SPLITS_FILENAME), 'w') as f:
+             json.dump(split_to_locs, f, indent=1)
+
+
+ #%% Support functions
+
+ def create_classification_csv(
+         queried_images_json_path: str,
+         detector_output_cache_base_dir: str,
+         detector_version: str,
+         cropped_images_dir: str,
+         confidence_threshold: float,
+         min_locs: Optional[int] = None,
+         append_df: Optional[pd.DataFrame] = None,
+         exclude_locs: Optional[Container[tuple[str, str]]] = None
+         ) -> tuple[pd.DataFrame, dict[str, list]]:
+     """
+     Creates a classification dataset.
+
+     The classification dataset is a pd.DataFrame with columns:
+     - path: str, <dataset>/<crop-filename>
+     - dataset: str, name of camera trap dataset
+     - location: str, location of image, provided by the camera trap dataset
+     - dataset_class: image class, as provided by the camera trap dataset
+     - confidence: float, confidence of bounding box, 1 if ground truth
+     - label: str, comma-separated list of classification labels
+
+     Args:
+         queried_images_json_path: str, path to output of json_validator.py
+         detector_version: str, detector version string, e.g., '4.1',
+             see {batch_detection_api_url}/supported_model_versions,
+             determines the subfolder of detector_output_cache_base_dir in
+             which to find and save detector outputs
+         detector_output_cache_base_dir: str, path to local directory
+             where detector outputs are cached, 1 JSON file per dataset
+         cropped_images_dir: str, path to local directory for saving crops of
+             bounding boxes
+         confidence_threshold: float, only crop bounding boxes above this value
+         min_locs: optional int, minimum # of locations that each label must
+             have in order to be included
+         append_df: optional pd.DataFrame, existing DataFrame that is appended
+             to the classification CSV
+         exclude_locs: optional set of (dataset, location) tuples, crops from
+             these locations are excluded (does not affect append_df)
+
+     Returns:
+         df: pd.DataFrame, the classification dataset
+         log: dict, with the following keys
+             'images missing detections': list of str, images without ground
+                 truth bboxes and not in detection cache
+             'images without confident detections': list of str, images in
+                 detection cache with no bboxes above the confidence threshold
+             'missing crops': list of tuple (img_path, i), where i is the
+                 i-th crop index
+     """
+
+     assert 0 <= confidence_threshold <= 1
+
+     columns = [
+         'path', 'dataset', 'location', 'dataset_class', 'confidence', 'label']
+     if append_df is not None:
+         assert (append_df.columns == columns).all()
+
+     with open(queried_images_json_path, 'r') as f:
+         js = json.load(f)
+
+     print('loading detection cache...', end='')
+     detector_output_cache_dir = os.path.join(
+         detector_output_cache_base_dir, f'v{detector_version}')
+     datasets = set(img_path[:img_path.find('/')] for img_path in js)
+     detection_cache, cat_id_to_name = detect_and_crop.load_detection_cache(
+         detector_output_cache_dir=detector_output_cache_dir, datasets=datasets)
+     print('done!')
+
+     missing_detections = []  # no cached detections or ground truth bboxes
+     images_no_confident_detections = []  # cached detections contain 0 bboxes
+     images_missing_crop = []  # tuples: (img_path, crop_index)
+     all_rows = []
+
+     # True for ground truth, False for MegaDetector
+     # always save as .jpg for consistency
+     crop_path_template = {
+         True: '{img_path}___crop{n:>02d}.jpg',
+         False: '{img_path}___crop{n:>02d}_' + f'mdv{detector_version}.jpg'
+     }
+
+     for img_path, img_info in tqdm(js.items()):
+         ds, img_file = img_path.split('/', maxsplit=1)
+
+         # get bounding boxes
+         if 'bbox' in img_info:  # ground-truth bounding boxes
+             bbox_dicts = img_info['bbox']
+             is_ground_truth = True
+         else:  # get bounding boxes from detector cache
+             if img_file in detection_cache[ds]:
+                 bbox_dicts = detection_cache[ds][img_file]['detections']
+                 # convert from category ID to category name
+                 for d in bbox_dicts:
+                     d['category'] = cat_id_to_name[d['category']]
+             else:
+                 missing_detections.append(img_path)
+                 continue
+             is_ground_truth = False
+
+         # check if crops are already downloaded, and ignore bboxes below the
+         # confidence threshold
+         rows = []
+         for i, bbox_dict in enumerate(bbox_dicts):
+             conf = 1.0 if is_ground_truth else bbox_dict['conf']
+             if conf < confidence_threshold:
+                 continue
+             if bbox_dict['category'] != 'animal':
+                 tqdm.write(f'Bbox {i} of {img_path} is non-animal. Skipping.')
+                 continue
+             crop_path = crop_path_template[is_ground_truth].format(
+                 img_path=img_path, n=i)
+             full_crop_path = os.path.join(cropped_images_dir, crop_path)
+             if not os.path.exists(full_crop_path):
+                 images_missing_crop.append((img_path, i))
+             else:
+                 # assign all images without location info to 'unknown_location'
+                 img_loc = img_info.get('location', 'unknown_location')
+                 row = [crop_path, ds, img_loc, img_info['class'],
+                        conf, ','.join(img_info['label'])]
+                 rows.append(row)
+         if len(rows) == 0:
+             images_no_confident_detections.append(img_path)
+             continue
+         all_rows += rows
+
+     df = pd.DataFrame(data=all_rows, columns=columns)
+
+     # remove images from labels that have fewer than min_locs locations
+     if min_locs is not None:
+         nlocs_per_label = df.groupby('label').apply(
+             lambda xdf: len(xdf[['dataset', 'location']].drop_duplicates()))
+         valid_labels_mask = (nlocs_per_label >= min_locs)
+         valid_labels = nlocs_per_label.index[valid_labels_mask]
+         invalid_labels = nlocs_per_label.index[~valid_labels_mask]
+         orig = len(df)
+         df = df[df['label'].isin(valid_labels)]
+         print(f'Excluding {orig - len(df)} crops from {len(invalid_labels)} '
+               'labels:', invalid_labels.tolist())
+
+     if exclude_locs is not None:
+         # build the mask on df's own index so it stays aligned even after
+         # the min_locs filtering above
+         mask = ~pd.Series(
+             list(zip(df['dataset'], df['location'])),
+             index=df.index).isin(exclude_locs)
+         print(f'Excluding {(~mask).sum()} crops from CSV')
+         df = df[mask]
+     if append_df is not None:
+         print(f'Appending {len(append_df)} rows to CSV')
+         df = pd.concat([df, append_df])
+
+     log = {
+         'images missing detections': missing_detections,
+         'images without confident detections': images_no_confident_detections,
+         'missing crops': images_missing_crop
+     }
+     return df, log
+
+
+ def create_splits_random(df: pd.DataFrame, val_frac: float,
+                          test_frac: float = 0.,
+                          test_split: Optional[set[tuple[str, str]]] = None,
+                          ) -> dict[str, list[tuple[str, str]]]:
+     """
+     Args:
+         df: pd.DataFrame, contains columns ['dataset', 'location', 'label']
+             each row is a single image
+             assumes each image is assigned exactly 1 label
+         val_frac: float, desired fraction of dataset to use for val set
+         test_frac: float, desired fraction of dataset to use for test set,
+             must be 0 if test_split is given
+         test_split: optional set of (dataset, location) tuples to use as test
+             split
+
+     Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
+         where each loc is a tuple (dataset, location)
+     """
+
+     if test_split is not None:
+         assert test_frac == 0
+     train_frac = 1. - val_frac - test_frac
+     targets = {'train': train_frac, 'val': val_frac, 'test': test_frac}
+
+     # merge dataset and location into a single string '<dataset>/<location>'
+     df['dataset_location'] = df['dataset'] + '/' + df['location']
+
+     # create DataFrame of counts. rows = locations, columns = labels
+     loc_label_counts = (df.groupby(['label', 'dataset_location']).size()
+                         .unstack('label', fill_value=0))
+     num_locs = len(loc_label_counts)
+
+     # label_count: label => number of examples
+     # loc_count: label => number of locs containing that label
+     label_count = loc_label_counts.sum()
+     loc_count = (loc_label_counts > 0).sum()
+
+     best_score = np.inf  # lower is better
+     best_splits = None
+     for _ in tqdm(range(10_000)):
+
+         # generate a new split
+         num_train = int(num_locs * (train_frac + np.random.uniform(-.03, .03)))
+         if test_frac > 0:
+             num_val = int(num_locs * (val_frac + np.random.uniform(-.03, .03)))
+         else:
+             num_val = num_locs - num_train
+         permuted_locs = loc_label_counts.index[np.random.permutation(num_locs)]
+         split_to_locs = {'train': permuted_locs[:num_train],
+                          'val': permuted_locs[num_train:num_train + num_val]}
+         if test_frac > 0:
+             split_to_locs['test'] = permuted_locs[num_train + num_val:]
+
+         # score the split
+         score = 0.
+         for split, locs in split_to_locs.items():
+             split_df = loc_label_counts.loc[locs]
+             target = targets[split]
+
+             # SSE for # of images per label (with 2x weight)
+             crop_frac = split_df.sum() / label_count
+             score += 2 * ((crop_frac - target) ** 2).sum()
+
+             # SSE for # of locs per label
+             loc_frac = (split_df > 0).sum() / loc_count
+             score += ((loc_frac - target) ** 2).sum()
+
+         if score < best_score:
+             tqdm.write(f'New lowest score: {score}')
+             best_score = score
+             best_splits = split_to_locs
+
+     assert best_splits is not None
+     split_to_locs = {
+         s: sorted(locs.map(lambda x: tuple(x.split('/', maxsplit=1))))
+         for s, locs in best_splits.items()
+     }
+     if test_split is not None:
+         # sort so that the result is JSON-serializable (a set is not)
+         split_to_locs['test'] = sorted(test_split)
+     return split_to_locs
+
+ def create_splits_smallest_label_first(
+         df: pd.DataFrame,
+         val_frac: float,
+         test_frac: float = 0.,
+         label_spec_json_path: Optional[str] = None,
+         test_split: Optional[set[tuple[str, str]]] = None,
+         ) -> dict[str, list[tuple[str, str]]]:
+     """
+     Args:
+         df: pd.DataFrame, contains columns ['dataset', 'location', 'label']
+             each row is a single image
+             assumes each image is assigned exactly 1 label
+         val_frac: float, desired fraction of dataset to use for val set
+         test_frac: float, desired fraction of dataset to use for test set,
+             must be 0 if test_split is given
+         label_spec_json_path: optional str, path to label spec JSON
+         test_split: optional set of (dataset, location) tuples to use as test
+             split
+
+     Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
+         where each loc is a tuple (dataset, location)
+     """
+
+     # label => list of datasets to prioritize for test and validation sets
+     prioritize = {}
+     if label_spec_json_path is not None:
+         with open(label_spec_json_path, 'r') as f:
+             label_spec_js = json.load(f)
+         for label, label_spec in label_spec_js.items():
+             if 'prioritize' in label_spec:
+                 datasets = []
+                 for level in label_spec['prioritize']:
+                     datasets += level
+                 prioritize[label] = datasets
+
+     # merge dataset and location into a tuple (dataset, location)
+     df['dataset_location'] = list(zip(df['dataset'], df['location']))
+     loc_to_label_sizes = df.groupby(['dataset_location', 'label']).size()
+
+     seen_locs = set()
+     split_to_locs: dict[str, list[tuple[str, str]]] = dict(
+         train=[], val=[], test=[])
+     label_sizes_by_split = {
+         label: dict(train=0, val=0, test=0)
+         for label in df['label'].unique()
+     }
+     if test_split is not None:
+         assert test_frac == 0
+         split_to_locs['test'] = list(test_split)
+         seen_locs.update(test_split)
+
+     def add_loc_to_split(loc: tuple[str, str], split: str) -> None:
+         split_to_locs[split].append(loc)
+         for label, label_size in loc_to_label_sizes[loc].items():
+             label_sizes_by_split[label][split] += label_size
+
+     # sorted smallest to largest
+     ordered_labels = df.groupby('label').size().sort_values()
+     for label, label_size in tqdm(ordered_labels.items()):
+
+         split_sizes = label_sizes_by_split[label]
+         test_thresh = test_frac * label_size
+         val_thresh = val_frac * label_size
+
+         mask = df['label'] == label
+         ordered_locs = sort_locs_by_size(
+             loc_to_size=df[mask].groupby('dataset_location').size().to_dict(),
+             prioritize=prioritize.get(label, None))
+         ordered_locs = [loc for loc in ordered_locs if loc not in seen_locs]
+
+         for loc in ordered_locs:
+             seen_locs.add(loc)
+             # greedily add to test set until it has >= test_frac of images
+             if split_sizes['test'] < test_thresh:
+                 split = 'test'
+             elif split_sizes['val'] < val_thresh:
+                 split = 'val'
+             else:
+                 split = 'train'
+             add_loc_to_split(loc, split)
+         seen_locs.update(ordered_locs)
+
+     # sort the resulting locs
+     split_to_locs = {s: sorted(locs) for s, locs in split_to_locs.items()}
+     return split_to_locs
+
+
+ def sort_locs_by_size(loc_to_size: MutableMapping[tuple[str, str], int],
+                       prioritize: Optional[Container[str]] = None
+                       ) -> list[tuple[str, str]]:
+     """
+     Sorts locations by size, optionally prioritizing locations from certain
+     datasets first.
+
+     Args:
+         loc_to_size: dict, maps each (dataset, location) tuple to its size,
+             modified in-place
+         prioritize: optional list of str, datasets to prioritize
+
+     Returns: list of (dataset, location) tuples, ordered from smallest size to
+         largest. Locations from prioritized datasets come first.
+     """
+
+     result = []
+     if prioritize is not None:
+         # modify loc_to_size in place, so copy its keys before iterating
+         prioritized_loc_to_size = {
+             loc: loc_to_size.pop(loc) for loc in list(loc_to_size.keys())
+             if loc[0] in prioritize
+         }
+         result = sort_locs_by_size(prioritized_loc_to_size)
+
+     result += sorted(loc_to_size, key=loc_to_size.__getitem__)
+     return result
+
+
+ #%% Command-line driver
+
+ def _parse_args() -> argparse.Namespace:
+
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+         description='Creates classification dataset.')
+
+     # arguments relevant to both creating the dataset CSV and splits.json
+     parser.add_argument(
+         'output_dir',
+         help='path to directory where the 3 output files should be '
+              'saved: 1) dataset CSV, 2) label index JSON, 3) splits JSON')
+     parser.add_argument(
+         '--mode', nargs='+', choices=['csv', 'splits'],
+         default=['csv', 'splits'],
+         help='whether to generate only a CSV, only a splits.json file (based '
+              'on an existing classification_ds.csv), or both')
+     parser.add_argument(
+         '--match-test', nargs=2, metavar=('CLASSIFICATION_CSV', 'SPLITS_JSON'),
+         help='path to an existing classification CSV and path to an existing '
+              'splits JSON file from which to match the test set')
+
+     # arguments only relevant for creating the dataset CSV
+     csv_group = parser.add_argument_group(
+         'arguments for creating classification CSV')
+     csv_group.add_argument(
+         '-q', '--queried-images-json',
+         help='path to JSON file containing image paths and classification info')
+     csv_group.add_argument(
+         '-c', '--cropped-images-dir',
+         help='path to local directory for saving crops of bounding boxes')
+     csv_group.add_argument(
+         '-d', '--detector-output-cache-dir',
+         help='(required) path to directory where detector outputs are cached')
+     csv_group.add_argument(
+         '-v', '--detector-version',
+         help='(required) detector version string, e.g., "4.1"')
+     csv_group.add_argument(
+         '-t', '--threshold', type=float, default=0.8,
+         help='confidence threshold above which to crop bounding boxes')
+     csv_group.add_argument(
+         '--min-locs', type=int,
+         help='minimum number of locations that each label must have in order '
+              'to be included (does not apply to rows appended via '
+              '--match-test)')
+
+     # arguments only relevant for creating the splits JSON
+     splits_group = parser.add_argument_group(
+         'arguments for creating train/val/test splits')
+     splits_group.add_argument(
+         '--val-frac', type=float,
+         help='(required) fraction of data to use for validation split')
+     splits_group.add_argument(
+         '--test-frac', type=float,
+         help='fraction of data to use for test split, must be provided if '
+              '--match-test is not given')
+     splits_group.add_argument(
+         '--method', choices=['random', 'smallest_first'], default='random',
+         help='"random": randomly tries up to 10,000 different train/val/test '
+              'splits and chooses the one that best meets the scoring criteria, '
+              'does not support --label-spec. '
+              '"smallest_first": greedily divides locations into splits, '
+              'starting with the smallest class first. Supports --label-spec.')
+     splits_group.add_argument(
+         '--label-spec',
+         help='optional path to label specification JSON file, if specifying '
+              'dataset priority. Requires --method=smallest_first.')
+     return parser.parse_args()
+
+
+ if __name__ == '__main__':
+
+     args = _parse_args()
+     main(output_dir=args.output_dir,
+          mode=args.mode,
+          match_test=args.match_test,
+          queried_images_json_path=args.queried_images_json,
+          cropped_images_dir=args.cropped_images_dir,
+          detector_version=args.detector_version,
+          detector_output_cache_base_dir=args.detector_output_cache_dir,
+          confidence_threshold=args.threshold,
+          min_locs=args.min_locs,
+          val_frac=args.val_frac,
+          test_frac=args.test_frac,
+          splits_method=args.method,
+          label_spec_json_path=args.label_spec)
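
The three output files described in the module docstring are designed to be consumed together: splits.json assigns each (dataset, location) pair to a split, and classification_ds.csv carries those two columns on every crop. As a rough illustration only (this sketch is not part of the package, and the output directory name is hypothetical, borrowed from the example usage above), here is one way to load the outputs and materialize per-split DataFrames:

    import json
    import os

    import pandas as pd

    output_dir = 'run_idfg2'  # hypothetical output directory

    # classification_ds.csv: one row per crop
    df = pd.read_csv(os.path.join(output_dir, 'classification_ds.csv'),
                     index_col=False, float_precision='high')

    # label_index.json: {"0": "label_a", ...}; JSON keys are strings,
    # so convert them back to ints
    with open(os.path.join(output_dir, 'label_index.json'), 'r') as f:
        idx_to_label = {int(k): v for k, v in json.load(f).items()}
    label_to_idx = {label: idx for idx, label in idx_to_label.items()}

    # splits.json: maps each split to a list of [dataset, location] pairs;
    # normalize locations to str, since CSV/JSON round-trips may not
    # preserve numeric location types identically
    with open(os.path.join(output_dir, 'splits.json'), 'r') as f:
        split_to_locs = {
            split: {(ds, str(loc)) for ds, loc in locs}
            for split, locs in json.load(f).items()
        }

    # assign each crop to a split via its (dataset, location) tuple
    loc_of_row = pd.Series(
        list(zip(df['dataset'], df['location'].astype(str))), index=df.index)
    split_dfs = {split: df[loc_of_row.isin(locs)]
                 for split, locs in split_to_locs.items()}

    for split, split_df in split_dfs.items():
        print(split, len(split_df), 'crops')

Note that rows with multi-label values (comma-separated 'label' strings) would need an extra split(',') step before being mapped through label_to_idx.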