megadetector 5.0.9-py3-none-any.whl → 5.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of megadetector has been flagged as potentially problematic.

Files changed (226)
  1. {megadetector-5.0.9.dist-info → megadetector-5.0.11.dist-info}/LICENSE +0 -0
  2. {megadetector-5.0.9.dist-info → megadetector-5.0.11.dist-info}/METADATA +12 -11
  3. megadetector-5.0.11.dist-info/RECORD +5 -0
  4. megadetector-5.0.11.dist-info/top_level.txt +1 -0
  5. api/__init__.py +0 -0
  6. api/batch_processing/__init__.py +0 -0
  7. api/batch_processing/api_core/__init__.py +0 -0
  8. api/batch_processing/api_core/batch_service/__init__.py +0 -0
  9. api/batch_processing/api_core/batch_service/score.py +0 -439
  10. api/batch_processing/api_core/server.py +0 -294
  11. api/batch_processing/api_core/server_api_config.py +0 -98
  12. api/batch_processing/api_core/server_app_config.py +0 -55
  13. api/batch_processing/api_core/server_batch_job_manager.py +0 -220
  14. api/batch_processing/api_core/server_job_status_table.py +0 -152
  15. api/batch_processing/api_core/server_orchestration.py +0 -360
  16. api/batch_processing/api_core/server_utils.py +0 -92
  17. api/batch_processing/api_core_support/__init__.py +0 -0
  18. api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
  19. api/batch_processing/api_support/__init__.py +0 -0
  20. api/batch_processing/api_support/summarize_daily_activity.py +0 -152
  21. api/batch_processing/data_preparation/__init__.py +0 -0
  22. api/batch_processing/data_preparation/manage_local_batch.py +0 -2391
  23. api/batch_processing/data_preparation/manage_video_batch.py +0 -327
  24. api/batch_processing/integration/digiKam/setup.py +0 -6
  25. api/batch_processing/integration/digiKam/xmp_integration.py +0 -465
  26. api/batch_processing/integration/eMammal/test_scripts/config_template.py +0 -5
  27. api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +0 -126
  28. api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +0 -55
  29. api/batch_processing/postprocessing/__init__.py +0 -0
  30. api/batch_processing/postprocessing/add_max_conf.py +0 -64
  31. api/batch_processing/postprocessing/categorize_detections_by_size.py +0 -163
  32. api/batch_processing/postprocessing/combine_api_outputs.py +0 -249
  33. api/batch_processing/postprocessing/compare_batch_results.py +0 -958
  34. api/batch_processing/postprocessing/convert_output_format.py +0 -397
  35. api/batch_processing/postprocessing/load_api_results.py +0 -195
  36. api/batch_processing/postprocessing/md_to_coco.py +0 -310
  37. api/batch_processing/postprocessing/md_to_labelme.py +0 -330
  38. api/batch_processing/postprocessing/merge_detections.py +0 -401
  39. api/batch_processing/postprocessing/postprocess_batch_results.py +0 -1904
  40. api/batch_processing/postprocessing/remap_detection_categories.py +0 -170
  41. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +0 -661
  42. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +0 -211
  43. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +0 -82
  44. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +0 -1631
  45. api/batch_processing/postprocessing/separate_detections_into_folders.py +0 -731
  46. api/batch_processing/postprocessing/subset_json_detector_output.py +0 -696
  47. api/batch_processing/postprocessing/top_folders_to_bottom.py +0 -223
  48. api/synchronous/__init__.py +0 -0
  49. api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  50. api/synchronous/api_core/animal_detection_api/api_backend.py +0 -152
  51. api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -266
  52. api/synchronous/api_core/animal_detection_api/config.py +0 -35
  53. api/synchronous/api_core/animal_detection_api/data_management/annotations/annotation_constants.py +0 -47
  54. api/synchronous/api_core/animal_detection_api/detection/detector_training/copy_checkpoints.py +0 -43
  55. api/synchronous/api_core/animal_detection_api/detection/detector_training/model_main_tf2.py +0 -114
  56. api/synchronous/api_core/animal_detection_api/detection/process_video.py +0 -543
  57. api/synchronous/api_core/animal_detection_api/detection/pytorch_detector.py +0 -304
  58. api/synchronous/api_core/animal_detection_api/detection/run_detector.py +0 -627
  59. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +0 -1029
  60. api/synchronous/api_core/animal_detection_api/detection/run_inference_with_yolov5_val.py +0 -581
  61. api/synchronous/api_core/animal_detection_api/detection/run_tiled_inference.py +0 -754
  62. api/synchronous/api_core/animal_detection_api/detection/tf_detector.py +0 -165
  63. api/synchronous/api_core/animal_detection_api/detection/video_utils.py +0 -495
  64. api/synchronous/api_core/animal_detection_api/md_utils/azure_utils.py +0 -174
  65. api/synchronous/api_core/animal_detection_api/md_utils/ct_utils.py +0 -262
  66. api/synchronous/api_core/animal_detection_api/md_utils/directory_listing.py +0 -251
  67. api/synchronous/api_core/animal_detection_api/md_utils/matlab_porting_tools.py +0 -97
  68. api/synchronous/api_core/animal_detection_api/md_utils/path_utils.py +0 -416
  69. api/synchronous/api_core/animal_detection_api/md_utils/process_utils.py +0 -110
  70. api/synchronous/api_core/animal_detection_api/md_utils/sas_blob_utils.py +0 -509
  71. api/synchronous/api_core/animal_detection_api/md_utils/string_utils.py +0 -59
  72. api/synchronous/api_core/animal_detection_api/md_utils/url_utils.py +0 -144
  73. api/synchronous/api_core/animal_detection_api/md_utils/write_html_image_list.py +0 -226
  74. api/synchronous/api_core/animal_detection_api/md_visualization/visualization_utils.py +0 -841
  75. api/synchronous/api_core/tests/__init__.py +0 -0
  76. api/synchronous/api_core/tests/load_test.py +0 -110
  77. classification/__init__.py +0 -0
  78. classification/aggregate_classifier_probs.py +0 -108
  79. classification/analyze_failed_images.py +0 -227
  80. classification/cache_batchapi_outputs.py +0 -198
  81. classification/create_classification_dataset.py +0 -627
  82. classification/crop_detections.py +0 -516
  83. classification/csv_to_json.py +0 -226
  84. classification/detect_and_crop.py +0 -855
  85. classification/efficientnet/__init__.py +0 -9
  86. classification/efficientnet/model.py +0 -415
  87. classification/efficientnet/utils.py +0 -610
  88. classification/evaluate_model.py +0 -520
  89. classification/identify_mislabeled_candidates.py +0 -152
  90. classification/json_to_azcopy_list.py +0 -63
  91. classification/json_validator.py +0 -695
  92. classification/map_classification_categories.py +0 -276
  93. classification/merge_classification_detection_output.py +0 -506
  94. classification/prepare_classification_script.py +0 -194
  95. classification/prepare_classification_script_mc.py +0 -228
  96. classification/run_classifier.py +0 -286
  97. classification/save_mislabeled.py +0 -110
  98. classification/train_classifier.py +0 -825
  99. classification/train_classifier_tf.py +0 -724
  100. classification/train_utils.py +0 -322
  101. data_management/__init__.py +0 -0
  102. data_management/annotations/__init__.py +0 -0
  103. data_management/annotations/annotation_constants.py +0 -34
  104. data_management/camtrap_dp_to_coco.py +0 -238
  105. data_management/cct_json_utils.py +0 -395
  106. data_management/cct_to_md.py +0 -176
  107. data_management/cct_to_wi.py +0 -289
  108. data_management/coco_to_labelme.py +0 -272
  109. data_management/coco_to_yolo.py +0 -662
  110. data_management/databases/__init__.py +0 -0
  111. data_management/databases/add_width_and_height_to_db.py +0 -33
  112. data_management/databases/combine_coco_camera_traps_files.py +0 -206
  113. data_management/databases/integrity_check_json_db.py +0 -477
  114. data_management/databases/subset_json_db.py +0 -115
  115. data_management/generate_crops_from_cct.py +0 -149
  116. data_management/get_image_sizes.py +0 -188
  117. data_management/importers/add_nacti_sizes.py +0 -52
  118. data_management/importers/add_timestamps_to_icct.py +0 -79
  119. data_management/importers/animl_results_to_md_results.py +0 -158
  120. data_management/importers/auckland_doc_test_to_json.py +0 -372
  121. data_management/importers/auckland_doc_to_json.py +0 -200
  122. data_management/importers/awc_to_json.py +0 -189
  123. data_management/importers/bellevue_to_json.py +0 -273
  124. data_management/importers/cacophony-thermal-importer.py +0 -796
  125. data_management/importers/carrizo_shrubfree_2018.py +0 -268
  126. data_management/importers/carrizo_trail_cam_2017.py +0 -287
  127. data_management/importers/cct_field_adjustments.py +0 -57
  128. data_management/importers/channel_islands_to_cct.py +0 -913
  129. data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
  130. data_management/importers/eMammal/eMammal_helpers.py +0 -249
  131. data_management/importers/eMammal/make_eMammal_json.py +0 -223
  132. data_management/importers/ena24_to_json.py +0 -275
  133. data_management/importers/filenames_to_json.py +0 -385
  134. data_management/importers/helena_to_cct.py +0 -282
  135. data_management/importers/idaho-camera-traps.py +0 -1407
  136. data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
  137. data_management/importers/jb_csv_to_json.py +0 -150
  138. data_management/importers/mcgill_to_json.py +0 -250
  139. data_management/importers/missouri_to_json.py +0 -489
  140. data_management/importers/nacti_fieldname_adjustments.py +0 -79
  141. data_management/importers/noaa_seals_2019.py +0 -181
  142. data_management/importers/pc_to_json.py +0 -365
  143. data_management/importers/plot_wni_giraffes.py +0 -123
  144. data_management/importers/prepare-noaa-fish-data-for-lila.py +0 -359
  145. data_management/importers/prepare_zsl_imerit.py +0 -131
  146. data_management/importers/rspb_to_json.py +0 -356
  147. data_management/importers/save_the_elephants_survey_A.py +0 -320
  148. data_management/importers/save_the_elephants_survey_B.py +0 -332
  149. data_management/importers/snapshot_safari_importer.py +0 -758
  150. data_management/importers/snapshot_safari_importer_reprise.py +0 -665
  151. data_management/importers/snapshot_serengeti_lila.py +0 -1067
  152. data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
  153. data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
  154. data_management/importers/sulross_get_exif.py +0 -65
  155. data_management/importers/timelapse_csv_set_to_json.py +0 -490
  156. data_management/importers/ubc_to_json.py +0 -399
  157. data_management/importers/umn_to_json.py +0 -507
  158. data_management/importers/wellington_to_json.py +0 -263
  159. data_management/importers/wi_to_json.py +0 -441
  160. data_management/importers/zamba_results_to_md_results.py +0 -181
  161. data_management/labelme_to_coco.py +0 -548
  162. data_management/labelme_to_yolo.py +0 -272
  163. data_management/lila/__init__.py +0 -0
  164. data_management/lila/add_locations_to_island_camera_traps.py +0 -97
  165. data_management/lila/add_locations_to_nacti.py +0 -147
  166. data_management/lila/create_lila_blank_set.py +0 -557
  167. data_management/lila/create_lila_test_set.py +0 -151
  168. data_management/lila/create_links_to_md_results_files.py +0 -106
  169. data_management/lila/download_lila_subset.py +0 -177
  170. data_management/lila/generate_lila_per_image_labels.py +0 -515
  171. data_management/lila/get_lila_annotation_counts.py +0 -170
  172. data_management/lila/get_lila_image_counts.py +0 -111
  173. data_management/lila/lila_common.py +0 -300
  174. data_management/lila/test_lila_metadata_urls.py +0 -132
  175. data_management/ocr_tools.py +0 -874
  176. data_management/read_exif.py +0 -681
  177. data_management/remap_coco_categories.py +0 -84
  178. data_management/remove_exif.py +0 -66
  179. data_management/resize_coco_dataset.py +0 -189
  180. data_management/wi_download_csv_to_coco.py +0 -246
  181. data_management/yolo_output_to_md_output.py +0 -441
  182. data_management/yolo_to_coco.py +0 -676
  183. detection/__init__.py +0 -0
  184. detection/detector_training/__init__.py +0 -0
  185. detection/detector_training/model_main_tf2.py +0 -114
  186. detection/process_video.py +0 -703
  187. detection/pytorch_detector.py +0 -337
  188. detection/run_detector.py +0 -779
  189. detection/run_detector_batch.py +0 -1219
  190. detection/run_inference_with_yolov5_val.py +0 -917
  191. detection/run_tiled_inference.py +0 -935
  192. detection/tf_detector.py +0 -188
  193. detection/video_utils.py +0 -606
  194. docs/source/conf.py +0 -43
  195. md_utils/__init__.py +0 -0
  196. md_utils/azure_utils.py +0 -174
  197. md_utils/ct_utils.py +0 -612
  198. md_utils/directory_listing.py +0 -246
  199. md_utils/md_tests.py +0 -968
  200. md_utils/path_utils.py +0 -1044
  201. md_utils/process_utils.py +0 -157
  202. md_utils/sas_blob_utils.py +0 -509
  203. md_utils/split_locations_into_train_val.py +0 -228
  204. md_utils/string_utils.py +0 -92
  205. md_utils/url_utils.py +0 -323
  206. md_utils/write_html_image_list.py +0 -225
  207. md_visualization/__init__.py +0 -0
  208. md_visualization/plot_utils.py +0 -293
  209. md_visualization/render_images_with_thumbnails.py +0 -275
  210. md_visualization/visualization_utils.py +0 -1537
  211. md_visualization/visualize_db.py +0 -551
  212. md_visualization/visualize_detector_output.py +0 -406
  213. megadetector-5.0.9.dist-info/RECORD +0 -224
  214. megadetector-5.0.9.dist-info/top_level.txt +0 -8
  215. taxonomy_mapping/__init__.py +0 -0
  216. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +0 -491
  217. taxonomy_mapping/map_new_lila_datasets.py +0 -154
  218. taxonomy_mapping/prepare_lila_taxonomy_release.py +0 -142
  219. taxonomy_mapping/preview_lila_taxonomy.py +0 -591
  220. taxonomy_mapping/retrieve_sample_image.py +0 -71
  221. taxonomy_mapping/simple_image_download.py +0 -218
  222. taxonomy_mapping/species_lookup.py +0 -834
  223. taxonomy_mapping/taxonomy_csv_checker.py +0 -159
  224. taxonomy_mapping/taxonomy_graph.py +0 -346
  225. taxonomy_mapping/validate_lila_category_mappings.py +0 -83
  226. {megadetector-5.0.9.dist-info → megadetector-5.0.11.dist-info}/WHEEL +0 -0
@@ -1,627 +0,0 @@
- """
-
- create_classification_dataset.py
-
- Creates a classification dataset CSV with a corresponding JSON file determining
- the train/val/test split.
-
- This script takes as input a "queried images" JSON file whose keys are paths to
- images and values are dictionaries containing information relevant for training
- a classifier, including labels and (optionally) ground-truth bounding boxes.
- The image paths are in the format `<dataset-name>/<blob-name>` where we assume
- that the dataset name does not contain '/'.
-
- {
-     "caltech/cct_images/59f79901-23d2-11e8-a6a3-ec086b02610b.jpg": {
-         "dataset": "caltech",
-         "location": 13,
-         "class": "mountain_lion",  # class from dataset
-         "bbox": [{"category": "animal",
-                   "bbox": [0, 0.347, 0.237, 0.257]}],  # ground-truth bbox
-         "label": ["mountain_lion"]  # labels to use in classifier
-     },
-     "caltech/cct_images/59f5fe2b-23d2-11e8-a6a3-ec086b02610b.jpg": {
-         "dataset": "caltech",
-         "location": 13,
-         "class": "mountain_lion",  # class from dataset
-         "label": ["mountain_lion"]  # labels to use in classifier
-     },
-     ...
- }
-
- We assume that the tuple (dataset, location) identifies a unique location. In
- other words, we assume that no two datasets have overlapping locations. This
- probably isn't 100% true, but it's pretty much the best we can do in terms of
- avoiding overlapping locations between the train/val/test splits.
-
- This script outputs 3 files to <output_dir>:
-
- 1) classification_ds.csv, contains columns:
-
-    - 'path': str, path to cropped images
-    - 'dataset': str, name of dataset
-    - 'location': str, location that image was taken, as saved in MegaDB
-    - 'dataset_class': str, original class assigned to image, as saved in MegaDB
-    - 'confidence': float, confidence that this crop is of an actual animal,
-      1.0 if the crop is a "ground truth bounding box" (i.e., from MegaDB),
-      <= 1.0 if the bounding box was detected by MegaDetector
-    - 'label': str, comma-separated list of label(s) assigned to this crop for
-      the sake of classification
-
- 2) label_index.json: maps integer to label name
-
-    - keys are string representations of Python integers (JSON requires keys to
-      be strings), numbered from 0 to num_labels-1
-    - values are strings, label names
-
- 3) splits.json: serialization of a Python dict that maps each split
-    ['train', 'val', 'test'] to a list of length-2 lists, where each inner list
-    is [<dataset>, <location>]
-
- """
-
- #%% Imports and constants
-
- from __future__ import annotations
-
- import argparse
- from collections.abc import Container, MutableMapping
- import json
- import os
- from typing import Optional
-
- import numpy as np
- import pandas as pd
- from tqdm import tqdm
-
- from classification import detect_and_crop
-
-
- #%% Example usage
-
- """
- python create_classification_dataset.py \
-     run_idfg2 \
-     --queried-images-json run_idfg2/queried_images.json \
-     --cropped-images-dir /ssd/crops_sq \
-     -d $HOME/classifier-training/mdcache -v "4.1" -t 0.8
- """
-
-
- DATASET_FILENAME = 'classification_ds.csv'
- LABEL_INDEX_FILENAME = 'label_index.json'
- SPLITS_FILENAME = 'splits.json'
-
-
- #%% Main function
-
- def main(output_dir: str,
-          mode: list[str],
-          match_test: Optional[list[str]],
-          queried_images_json_path: Optional[str],
-          cropped_images_dir: Optional[str],
-          detector_version: Optional[str],
-          detector_output_cache_base_dir: Optional[str],
-          confidence_threshold: Optional[float],
-          min_locs: Optional[int],
-          val_frac: Optional[float],
-          test_frac: Optional[float],
-          splits_method: Optional[str],
-          label_spec_json_path: Optional[str]) -> None:
-
-     # input validation
-     assert set(mode) <= {'csv', 'splits'}
-     if label_spec_json_path is not None:
-         assert splits_method == 'smallest_first'
-
-     test_set_locs = None  # set of (dataset, location) tuples
-     test_set_df = None
-     if match_test is not None:
-         match_test_csv_path, match_test_splits_path = match_test
-         match_df = pd.read_csv(match_test_csv_path, index_col=False,
-                                float_precision='high')
-         with open(match_test_splits_path, 'r') as f:
-             match_splits = json.load(f)
-         test_set_locs = set((loc[0], loc[1]) for loc in match_splits['test'])
-         ds_locs = pd.Series(zip(match_df['dataset'], match_df['location']))
-         test_set_df = match_df[ds_locs.isin(test_set_locs)]
-
-     dataset_path = os.path.join(output_dir, DATASET_FILENAME)
-
-     if 'csv' in mode:
-         assert queried_images_json_path is not None
-         assert cropped_images_dir is not None
-         assert detector_version is not None
-         assert detector_output_cache_base_dir is not None
-         assert confidence_threshold is not None
-
-         if not os.path.exists(output_dir):
-             os.makedirs(output_dir)
-             print(f'Created {output_dir}')
-
-         df, log = create_classification_csv(
-             queried_images_json_path=queried_images_json_path,
-             detector_output_cache_base_dir=detector_output_cache_base_dir,
-             detector_version=detector_version,
-             cropped_images_dir=cropped_images_dir,
-             confidence_threshold=confidence_threshold,
-             min_locs=min_locs,
-             append_df=test_set_df,
-             exclude_locs=test_set_locs)
-         print('Saving classification dataset CSV')
-         df.to_csv(dataset_path, index=False)
-         for msg, img_list in log.items():
-             print(f'{msg}:', len(img_list))
-
-         # create label index JSON
-         labels = df['label']
-         if any(labels.str.contains(',')):
-             print('multi-label!')
-             labels = labels.map(lambda x: x.split(',')).explode()
-             # look into sklearn.preprocessing.MultiLabelBinarizer
-         label_names = sorted(labels.unique())
-         with open(os.path.join(output_dir, LABEL_INDEX_FILENAME), 'w') as f:
-             # Note: JSON always saves keys as strings!
-             json.dump(dict(enumerate(label_names)), f, indent=1)
-
-     if 'splits' in mode:
-         assert splits_method is not None
-         assert val_frac is not None
-         assert (match_test is None) != (test_frac is None)
-         if test_frac is None:
-             test_frac = 0
-
-         print(f'Creating splits using "{splits_method}" method...')
-         df = pd.read_csv(dataset_path, index_col=False, float_precision='high')
-
-         if splits_method == 'random':
-             split_to_locs = create_splits_random(
-                 df, val_frac, test_frac, test_split=test_set_locs)
-         else:
-             split_to_locs = create_splits_smallest_label_first(
-                 df, val_frac, test_frac, test_split=test_set_locs,
-                 label_spec_json_path=label_spec_json_path)
-         with open(os.path.join(output_dir, SPLITS_FILENAME), 'w') as f:
-             json.dump(split_to_locs, f, indent=1)
-
-
- #%% Support functions
-
- def create_classification_csv(
-         queried_images_json_path: str,
-         detector_output_cache_base_dir: str,
-         detector_version: str,
-         cropped_images_dir: str,
-         confidence_threshold: float,
-         min_locs: Optional[int] = None,
-         append_df: Optional[pd.DataFrame] = None,
-         exclude_locs: Optional[Container[tuple[str, str]]] = None
-         ) -> tuple[pd.DataFrame, dict[str, list]]:
-     """
-     Creates a classification dataset.
-
-     The classification dataset is a pd.DataFrame with columns:
-     - path: str, <dataset>/<crop-filename>
-     - dataset: str, name of camera trap dataset
-     - location: str, location of image, provided by the camera trap dataset
-     - dataset_class: image class, as provided by the camera trap dataset
-     - confidence: float, confidence of bounding box, 1 if ground truth
-     - label: str, comma-separated list of classification labels
-
-     Args:
-         queried_images_json_path: str, path to output of json_validator.py
-         detector_version: str, detector version string, e.g., '4.1',
-             see {batch_detection_api_url}/supported_model_versions,
-             determines the subfolder of detector_output_cache_base_dir in
-             which to find and save detector outputs
-         detector_output_cache_base_dir: str, path to local directory
-             where detector outputs are cached, 1 JSON file per dataset
-         cropped_images_dir: str, path to local directory for saving crops of
-             bounding boxes
-         confidence_threshold: float, only crop bounding boxes above this value
-         min_locs: optional int, minimum # of locations that each label must
-             have in order to be included
-         append_df: optional pd.DataFrame, existing DataFrame that is appended to
-             the classification CSV
-         exclude_locs: optional set of (dataset, location) tuples, crops from
-             these locations are excluded (does not affect append_df)
-
-     Returns:
-         df: pd.DataFrame, the classification dataset
-         log: dict, with the following keys
-             'images missing detections': list of str, images without ground
-                 truth bboxes and not in detection cache
-             'images without confident detections': list of str, images in
-                 detection cache with no bboxes above the confidence threshold
-             'missing crops': list of tuple (img_path, i), where i is the
-                 i-th crop index
-     """
-
-     assert 0 <= confidence_threshold <= 1
-
-     columns = [
-         'path', 'dataset', 'location', 'dataset_class', 'confidence', 'label']
-     if append_df is not None:
-         assert (append_df.columns == columns).all()
-
-     with open(queried_images_json_path, 'r') as f:
-         js = json.load(f)
-
-     print('loading detection cache...', end='')
-     detector_output_cache_dir = os.path.join(
-         detector_output_cache_base_dir, f'v{detector_version}')
-     datasets = set(img_path[:img_path.find('/')] for img_path in js)
-     detection_cache, cat_id_to_name = detect_and_crop.load_detection_cache(
-         detector_output_cache_dir=detector_output_cache_dir, datasets=datasets)
-     print('done!')
-
-     missing_detections = []  # no cached detections or ground truth bboxes
-     images_no_confident_detections = []  # cached detections contain 0 bboxes
-     images_missing_crop = []  # tuples: (img_path, crop_index)
-     all_rows = []
-
-     # True for ground truth, False for MegaDetector
-     # always save as .jpg for consistency
-     crop_path_template = {
-         True: '{img_path}___crop{n:>02d}.jpg',
-         False: '{img_path}___crop{n:>02d}_' + f'mdv{detector_version}.jpg'
-     }
-
-     for img_path, img_info in tqdm(js.items()):
-         ds, img_file = img_path.split('/', maxsplit=1)
-
-         # get bounding boxes
-         if 'bbox' in img_info:  # ground-truth bounding boxes
-             bbox_dicts = img_info['bbox']
-             is_ground_truth = True
-         else:  # get bounding boxes from detector cache
-             if img_file in detection_cache[ds]:
-                 bbox_dicts = detection_cache[ds][img_file]['detections']
-                 # convert from category ID to category name
-                 for d in bbox_dicts:
-                     d['category'] = cat_id_to_name[d['category']]
-             else:
-                 missing_detections.append(img_path)
-                 continue
-             is_ground_truth = False
-
-         # check if crops are already downloaded, and ignore bboxes below the
-         # confidence threshold
-         rows = []
-         for i, bbox_dict in enumerate(bbox_dicts):
-             conf = 1.0 if is_ground_truth else bbox_dict['conf']
-             if conf < confidence_threshold:
-                 continue
-             if bbox_dict['category'] != 'animal':
-                 tqdm.write(f'Bbox {i} of {img_path} is non-animal. Skipping.')
-                 continue
-             crop_path = crop_path_template[is_ground_truth].format(
-                 img_path=img_path, n=i)
-             full_crop_path = os.path.join(cropped_images_dir, crop_path)
-             if not os.path.exists(full_crop_path):
-                 images_missing_crop.append((img_path, i))
-             else:
-                 # assign all images without location info to 'unknown_location'
-                 img_loc = img_info.get('location', 'unknown_location')
-                 row = [crop_path, ds, img_loc, img_info['class'],
-                        conf, ','.join(img_info['label'])]
-                 rows.append(row)
-         if len(rows) == 0:
-             images_no_confident_detections.append(img_path)
-             continue
-         all_rows += rows
-
-     df = pd.DataFrame(data=all_rows, columns=columns)
-
-     # remove images from labels that have fewer than min_locs locations
-     if min_locs is not None:
-         nlocs_per_label = df.groupby('label').apply(
-             lambda xdf: len(xdf[['dataset', 'location']].drop_duplicates()))
-         valid_labels_mask = (nlocs_per_label >= min_locs)
-         valid_labels = nlocs_per_label.index[valid_labels_mask]
-         invalid_labels = nlocs_per_label.index[~valid_labels_mask]
-         orig = len(df)
-         df = df[df['label'].isin(valid_labels)]
-         print(f'Excluding {orig - len(df)} crops from {len(invalid_labels)} '
-               'labels:', invalid_labels.tolist())
-
-     if exclude_locs is not None:
-         mask = ~pd.Series(zip(df['dataset'], df['location'])).isin(exclude_locs)
-         print(f'Excluding {(~mask).sum()} crops from CSV')
-         df = df[mask]
-     if append_df is not None:
-         print(f'Appending {len(append_df)} rows to CSV')
-         df = pd.concat([df, append_df])
-
-     log = {
-         'images missing detections': missing_detections,
-         'images without confident detections': images_no_confident_detections,
-         'missing crops': images_missing_crop
-     }
-     return df, log
-
-
- def create_splits_random(df: pd.DataFrame, val_frac: float,
-                          test_frac: float = 0.,
-                          test_split: Optional[set[tuple[str, str]]] = None,
-                          ) -> dict[str, list[tuple[str, str]]]:
-     """
-     Args:
-         df: pd.DataFrame, contains columns ['dataset', 'location', 'label']
-             each row is a single image
-             assumes each image is assigned exactly 1 label
-         val_frac: float, desired fraction of dataset to use for val set
-         test_frac: float, desired fraction of dataset to use for test set,
-             must be 0 if test_split is given
-         test_split: optional set of (dataset, location) tuples to use as test
-             split
-
-     Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
-         where each loc is a tuple (dataset, location)
-     """
-
-     if test_split is not None:
-         assert test_frac == 0
-     train_frac = 1. - val_frac - test_frac
-     targets = {'train': train_frac, 'val': val_frac, 'test': test_frac}
-
-     # merge dataset and location into a single string '<dataset>/<location>'
-     df['dataset_location'] = df['dataset'] + '/' + df['location']
-
-     # create DataFrame of counts. rows = locations, columns = labels
-     loc_label_counts = (df.groupby(['label', 'dataset_location']).size()
-                         .unstack('label', fill_value=0))
-     num_locs = len(loc_label_counts)
-
-     # label_count: label => number of examples
-     # loc_count: label => number of locs containing that label
-     label_count = loc_label_counts.sum()
-     loc_count = (loc_label_counts > 0).sum()
-
-     best_score = np.inf  # lower is better
-     best_splits = None
-     for _ in tqdm(range(10_000)):
-
-         # generate a new split
-         num_train = int(num_locs * (train_frac + np.random.uniform(-.03, .03)))
-         if test_frac > 0:
-             num_val = int(num_locs * (val_frac + np.random.uniform(-.03, .03)))
-         else:
-             num_val = num_locs - num_train
-         permuted_locs = loc_label_counts.index[np.random.permutation(num_locs)]
-         split_to_locs = {'train': permuted_locs[:num_train],
-                          'val': permuted_locs[num_train:num_train + num_val]}
-         if test_frac > 0:
-             split_to_locs['test'] = permuted_locs[num_train + num_val:]
-
-         # score the split
-         score = 0.
-         for split, locs in split_to_locs.items():
-             split_df = loc_label_counts.loc[locs]
-             target = targets[split]
-
-             # SSE for # of images per label (with 2x weight)
-             crop_frac = split_df.sum() / label_count
-             score += 2 * ((crop_frac - target) ** 2).sum()
-
-             # SSE for # of locs per label
-             loc_frac = (split_df > 0).sum() / loc_count
-             score += ((loc_frac - target) ** 2).sum()
-
-         if score < best_score:
-             tqdm.write(f'New lowest score: {score}')
-             best_score = score
-             best_splits = split_to_locs
-
-     assert best_splits is not None
-     split_to_locs = {
-         s: sorted(locs.map(lambda x: tuple(x.split('/', maxsplit=1))))
-         for s, locs in best_splits.items()
-     }
-     if test_split is not None:
-         split_to_locs['test'] = sorted(test_split)
-     return split_to_locs
-
-
- def create_splits_smallest_label_first(
-         df: pd.DataFrame,
-         val_frac: float,
-         test_frac: float = 0.,
-         label_spec_json_path: Optional[str] = None,
-         test_split: Optional[set[tuple[str, str]]] = None,
-         ) -> dict[str, list[tuple[str, str]]]:
-     """
-     Args:
-         df: pd.DataFrame, contains columns ['dataset', 'location', 'label']
-             each row is a single image
-             assumes each image is assigned exactly 1 label
-         val_frac: float, desired fraction of dataset to use for val set
-         test_frac: float, desired fraction of dataset to use for test set,
-             must be 0 if test_split is given
-         label_spec_json_path: optional str, path to label spec JSON
-         test_split: optional set of (dataset, location) tuples to use as test
-             split
-
-     Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
-         where each loc is a tuple (dataset, location)
-     """
-
-     # label => list of datasets to prioritize for test and validation sets
-     prioritize = {}
-     if label_spec_json_path is not None:
-         with open(label_spec_json_path, 'r') as f:
-             label_spec_js = json.load(f)
-         for label, label_spec in label_spec_js.items():
-             if 'prioritize' in label_spec:
-                 datasets = []
-                 for level in label_spec['prioritize']:
-                     datasets += level
-                 prioritize[label] = datasets
-
-     # merge dataset and location into a tuple (dataset, location)
-     df['dataset_location'] = list(zip(df['dataset'], df['location']))
-     loc_to_label_sizes = df.groupby(['dataset_location', 'label']).size()
-
-     seen_locs = set()
-     split_to_locs: dict[str, list[tuple[str, str]]] = dict(
-         train=[], val=[], test=[])
-     label_sizes_by_split = {
-         label: dict(train=0, val=0, test=0)
-         for label in df['label'].unique()
-     }
-     if test_split is not None:
-         assert test_frac == 0
-         split_to_locs['test'] = list(test_split)
-         seen_locs.update(test_split)
-
-     def add_loc_to_split(loc: tuple[str, str], split: str) -> None:
-         split_to_locs[split].append(loc)
-         for label, label_size in loc_to_label_sizes[loc].items():
-             label_sizes_by_split[label][split] += label_size
-
-     # sorted smallest to largest
-     ordered_labels = df.groupby('label').size().sort_values()
-     for label, label_size in tqdm(ordered_labels.items()):
-
-         split_sizes = label_sizes_by_split[label]
-         test_thresh = test_frac * label_size
-         val_thresh = val_frac * label_size
-
-         mask = df['label'] == label
-         ordered_locs = sort_locs_by_size(
-             loc_to_size=df[mask].groupby('dataset_location').size().to_dict(),
-             prioritize=prioritize.get(label, None))
-         ordered_locs = [loc for loc in ordered_locs if loc not in seen_locs]
-
-         for loc in ordered_locs:
-             seen_locs.add(loc)
-             # greedily add to the test set until it holds test_frac of this label's images
-             if split_sizes['test'] < test_thresh:
-                 split = 'test'
-             elif split_sizes['val'] < val_thresh:
-                 split = 'val'
-             else:
-                 split = 'train'
-             add_loc_to_split(loc, split)
-         seen_locs.update(ordered_locs)
-
-     # sort the resulting locs
-     split_to_locs = {s: sorted(locs) for s, locs in split_to_locs.items()}
-     return split_to_locs
-
-
- def sort_locs_by_size(loc_to_size: MutableMapping[tuple[str, str], int],
-                       prioritize: Optional[Container[str]] = None
-                       ) -> list[tuple[str, str]]:
-     """
-     Sorts locations by size, optionally prioritizing locations from certain
-     datasets first.
-
-     Args:
-         loc_to_size: dict, maps each (dataset, location) tuple to its size,
-             modified in-place
-         prioritize: optional list of str, datasets to prioritize
-
-     Returns: list of (dataset, location) tuples, ordered from smallest size to
-         largest. Locations from prioritized datasets come first.
-     """
-
-     result = []
-     if prioritize is not None:
-         # modify loc_to_size in place, so copy its keys before iterating
-         prioritized_loc_to_size = {
-             loc: loc_to_size.pop(loc) for loc in list(loc_to_size.keys())
-             if loc[0] in prioritize
-         }
-         result = sort_locs_by_size(prioritized_loc_to_size)
-
-     result += sorted(loc_to_size, key=loc_to_size.__getitem__)
-     return result
-
-
- #%% Command-line driver
-
- def _parse_args() -> argparse.Namespace:
-
-     parser = argparse.ArgumentParser(
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-         description='Creates classification dataset.')
-
-     # arguments relevant to both creating the dataset CSV and splits.json
-     parser.add_argument(
-         'output_dir',
-         help='path to directory where the 3 output files should be '
-              'saved: 1) dataset CSV, 2) label index JSON, 3) splits JSON')
-     parser.add_argument(
-         '--mode', nargs='+', choices=['csv', 'splits'],
-         default=['csv', 'splits'],
-         help='whether to generate only a CSV, only a splits.json file (based '
-              'on an existing classification_ds.csv), or both')
-     parser.add_argument(
-         '--match-test', nargs=2, metavar=('CLASSIFICATION_CSV', 'SPLITS_JSON'),
-         help='path to an existing classification CSV and path to an existing '
-              'splits JSON file from which to match test set')
-
-     # arguments only relevant for creating the dataset CSV
-     csv_group = parser.add_argument_group(
-         'arguments for creating classification CSV')
-     csv_group.add_argument(
-         '-q', '--queried-images-json',
-         help='path to JSON file containing image paths and classification info')
-     csv_group.add_argument(
-         '-c', '--cropped-images-dir',
-         help='path to local directory for saving crops of bounding boxes')
-     csv_group.add_argument(
-         '-d', '--detector-output-cache-dir',
-         help='(required) path to directory where detector outputs are cached')
-     csv_group.add_argument(
-         '-v', '--detector-version',
-         help='(required) detector version string, e.g., "4.1"')
-     csv_group.add_argument(
-         '-t', '--threshold', type=float, default=0.8,
-         help='confidence threshold above which to crop bounding boxes')
-     csv_group.add_argument(
-         '--min-locs', type=int,
-         help='minimum number of locations that each label must have in order '
-              'to be included (does not apply to match-test-splits)')
-
-     # arguments only relevant for creating the splits JSON
-     splits_group = parser.add_argument_group(
-         'arguments for creating train/val/test splits')
-     splits_group.add_argument(
-         '--val-frac', type=float,
-         help='(required) fraction of data to use for validation split')
-     splits_group.add_argument(
-         '--test-frac', type=float,
-         help='fraction of data to use for test split, must be provided if '
-              '--match-test is not given')
-     splits_group.add_argument(
-         '--method', choices=['random', 'smallest_first'], default='random',
-         help='"random": randomly tries up to 10,000 different train/val/test '
-              'splits and chooses the one that best meets the scoring criteria, '
-              'does not support --label-spec. '
-              '"smallest_first": greedily divides locations into splits '
-              'starting with the smallest class first. Supports --label-spec.')
-     splits_group.add_argument(
-         '--label-spec',
-         help='optional path to label specification JSON file, if specifying '
-              'dataset priority. Requires --method=smallest_first.')
-     return parser.parse_args()
-
-
- if __name__ == '__main__':
-
-     args = _parse_args()
-     main(output_dir=args.output_dir,
-          mode=args.mode,
-          match_test=args.match_test,
-          queried_images_json_path=args.queried_images_json,
-          cropped_images_dir=args.cropped_images_dir,
-          detector_version=args.detector_version,
-          detector_output_cache_base_dir=args.detector_output_cache_dir,
-          confidence_threshold=args.threshold,
-          min_locs=args.min_locs,
-          val_frac=args.val_frac,
-          test_frac=args.test_frac,
-          splits_method=args.method,
-          label_spec_json_path=args.label_spec)
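
Besides the CSV-plus-splits invocation shown in the module's example-usage block, the command-line driver above supports regenerating only splits.json against an existing dataset CSV. A hypothetical invocation under those flags (all paths are placeholders, and output_dir must already contain classification_ds.csv in this mode) might look like:

python create_classification_dataset.py run_idfg2 \
    --mode splits \
    --match-test old_run/classification_ds.csv old_run/splits.json \
    --val-frac 0.2 \
    --method smallest_first --label-spec label_spec.json

Note that --test-frac is deliberately omitted here: main() asserts that exactly one of --match-test and --test-frac is supplied, and --label-spec requires --method=smallest_first.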
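
The module docstring fully specifies the contract of the three output files. The following is a minimal consumer sketch (not part of the package; output_dir is a placeholder matching the example usage) that loads them and verifies the location-level split invariant described in the docstring:

import json
import os

import pandas as pd

output_dir = 'run_idfg2'  # hypothetical output directory

# crop-level dataset: one row per cropped image
df = pd.read_csv(os.path.join(output_dir, 'classification_ds.csv'),
                 index_col=False, float_precision='high')

# label_index.json: keys are stringified ints '0'..'num_labels-1'
with open(os.path.join(output_dir, 'label_index.json'), 'r') as f:
    label_to_idx = {name: int(idx) for idx, name in json.load(f).items()}

# splits.json: each split maps to a list of [dataset, location] pairs;
# normalize both sides to str, since JSON may store numeric locations as ints
with open(os.path.join(output_dir, 'splits.json'), 'r') as f:
    split_to_locs = {s: {(str(ds), str(loc)) for ds, loc in locs}
                     for s, locs in json.load(f).items()}

# splits are location-level, so no (dataset, location) may appear in two splits
assert not (split_to_locs['train'] & split_to_locs['val'])
assert not (split_to_locs['train'] & split_to_locs['test'])
assert not (split_to_locs['val'] & split_to_locs['test'])

# recover each crop's split via its (dataset, location) tuple
crop_locs = pd.Series(
    zip(df['dataset'].astype(str), df['location'].astype(str)))
for split, locs in split_to_locs.items():
    print(split, int(crop_locs.isin(locs).sum()), 'crops')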