megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (147) hide show
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,509 @@
1
+ """
2
+
3
+ merge_classification_detection_output.py
4
+
5
+ Merges classification results with Batch Detection API outputs.
6
+
7
+ This script takes 2 main files as input:
8
+
9
+ 1) Either a "dataset CSV" (output of create_classification_dataset.py) or a
10
+ "classification results CSV" (output of evaluate_model.py). The CSV is
11
+ expected to have columns listed below. The 'label' and [label names] columns
12
+ are optional, but at least one of them must be provided.
13
+ * 'path': str, path to cropped image
14
+ * if passing in a detections JSON, must match
15
+ <img_file>___cropXX_mdvY.Y.jpg
16
+ * if passing in a queried images JSON, must match
17
+ <dataset>/<img_file>___cropXX_mdvY.Y.jpg or
18
+ <dataset>/<img_file>___cropXX.jpg
19
+ * 'label': str, label assigned to this crop
20
+ * [label names]: float, confidence in each label
21
+
22
+ 2) Either a "detections JSON" (output of MegaDetector) or a "queried images
23
+ JSON" (output of json_validator.py).
24
+
25
+ If the CSV contains [label names] columns (e.g., output of evaluate_model.py),
26
+ then each crop's "classifications" output will have one value per category.
27
+ Categories are sorted decreasing by confidence.
28
+ "classifications": [
29
+ ["3", 0.901],
30
+ ["1", 0.071],
31
+ ["4", 0.025],
32
+ ["2", 0.003],
33
+ ]
34
+
35
+ If the CSV only contains the 'label' column (e.g., output of
36
+ create_classification_dataset.py), then each crop's "classifications" output
37
+ will have only one value, with a confidence of 1.0. The label's classification
38
+ category ID is always greater than 1,000,000, to distinguish it from a predicted
39
+ category ID.
40
+ "classifications": [
41
+ ["1000004", 1.0]
42
+ ]
43
+
44
+ If the CSV contains both [label names] and 'label' columns, then both the
45
+ predicted categories and label category can be included. The label category is
46
+ only added when the --label argument is given: '--label last' appends it, and
47
+ '--label first' places it first in the results.
48
+ "classifications": [
49
+ ["1000004", 1.0], # label put first if '--label first' is given
50
+ ["3", 0.901], # all other results are sorted by confidence
51
+ ["1", 0.071],
52
+ ["4", 0.025],
53
+ ["2", 0.003]
54
+ ]
55
+
56
+ """
57
+
58
+ #%% Imports
59
+
60
+ from __future__ import annotations
61
+
62
+ import argparse
63
+ import datetime
64
+ import json
65
+ import os
66
+
67
+ from collections.abc import Mapping, Sequence
68
+ from typing import Any
69
+
70
+ import pandas as pd
71
+ from tqdm import tqdm
72
+
73
+ from megadetector.utils.ct_utils import round_float
74
+ from megadetector.utils import ct_utils
75
+
76
+
77
+ #%% Example usage
78
+
79
+ """
80
+ python merge_classification_detection_output.py \
81
+ BASE_LOGDIR/LOGDIR/outputs_test.csv.gz \
82
+ BASE_LOGDIR/label_index.json \
83
+ --queried-images-json BASE_LOGDIR/queried_images.json \
84
+ --classifier-name "efficientnet-b3-idfg-moredata" \
85
+ --detector-output-cache-dir $HOME/classifier-training/mdcache \
86
+ --detector-version "4.1" \
87
+ --output-json BASE_LOGDIR/LOGDIR/classifier_results.json \
88
+ --datasets idfg idfg_swwlf_2019
89
+ """
90
+
91
+
92
+ #%% Support functions
93
+
94
def row_to_classification_list(row: Mapping[str, Any],
                               label_names: Sequence[str],
                               contains_preds: bool,
                               label_pos: str | None,
                               threshold: float,
                               relative_conf: bool = False
                               ) -> list[tuple[str, float]]:
    """
    Converts one row of classification results into a list of
    (str(label_id), confidence) tuples suitable for the "classifications"
    field of the Batch API output format.

    Predicted categories (if any) are filtered by [threshold], rounded to 4
    digits, and ordered from highest to lowest confidence.

    When the row has a 'label' entry and label_pos is not None, the tuple
    (str(label_id + 1_000_000), 1.) is added as well: at the front of the
    list if label_pos == 'first', at the end otherwise.

    Args:
        row: mapping from column name to value; may contain 'label' and/or
            one confidence value per entry in label_names
        label_names: label names, position in this sequence is the label ID
        contains_preds: whether row contains a confidence for every label
        label_pos: None to omit the label, otherwise 'first' or 'last'
        threshold: predictions with confidence below this are dropped
        relative_conf: if True, each confidence is replaced by its excess
            over the confidence of the true label (floored at zero); requires
            both the label and the predictions to be present

    Returns: list of (category ID as str, confidence) tuples
    """

    has_label = 'label' in row
    assert has_label or contains_preds
    if relative_conf:
        assert has_label and contains_preds

    classifications = []

    if contains_preds:

        # one (category ID, confidence) pair per known label
        scored = []
        for category_idx, name in enumerate(label_names):
            scored.append((str(category_idx), row[name]))

        if relative_conf:
            true_label_conf = row[row['label']]
            scored = [(category_id, max(conf - true_label_conf, 0))
                      for category_id, conf in scored]

        # drop low-confidence categories, then round what is left
        for category_id, conf in scored:
            if conf >= threshold:
                classifications.append(
                    (category_id, round_float(conf, precision=4)))

        # highest confidence first
        classifications.sort(key=lambda pair: pair[1], reverse=True)

    if not has_label or label_pos is None:
        return classifications

    # the ground-truth label gets an ID offset by one million so it cannot be
    # confused with a predicted category
    true_label_id = label_names.index(row['label'])
    label_item = (str(true_label_id + 1_000_000), 1.)
    if label_pos == 'first':
        classifications = [label_item] + classifications
    else:
        classifications.append(label_item)
    return classifications
144
+
145
+
146
def process_queried_images(
        df: pd.DataFrame,
        queried_images_json_path: str,
        detector_output_cache_base_dir: str,
        detector_version: str,
        datasets: Sequence[str] | None = None,
        samples_per_label: int | None = None,
        seed: int = 123
        ) -> dict[str, Any]:
    """
    Creates a detection JSON object roughly in the Batch API detection
    format.

    Detections are either ground-truth (from the queried images JSON) or
    retrieved from the detector output cache. Only images corresponding to crop
    paths from the given pd.DataFrame are included in the detection JSON.

    The input DataFrame is not modified; dataset bookkeeping is done on a copy.

    Args:
        df: pd.DataFrame, either a "classification dataset CSV" or a
            "classification results CSV", indexed by crop path with format
            <dataset>/<img_file>___cropXX[...].jpg
        queried_images_json_path: str, path to queried images JSON
        detector_output_cache_base_dir: str, directory containing one
            'v{detector_version}' subfolder with a <dataset>.json file per
            dataset
        detector_version: str
        datasets: optional list of str, only crops from these datasets will be
            included in the output, set to None to include all datasets
        samples_per_label: optional int, if not None, then randomly sample this
            many bounding boxes per label (each label must have at least this
            many examples); requires a 'label' column in df
        seed: int, used for random sampling if samples_per_label is not None

    Returns: dict, detections JSON file, except that the 'images' field is a
        dict (img_path => dict) instead of a list
    """

    # input validation
    assert os.path.exists(queried_images_json_path), \
        f'queried images JSON not found: {queried_images_json_path}'
    detection_cache_dir = os.path.join(
        detector_output_cache_base_dir, f'v{detector_version}')
    assert os.path.isdir(detection_cache_dir), \
        f'detector output cache folder not found: {detection_cache_dir}'

    # Extract dataset name from crop path so we can process 1 dataset at a
    # time.  assign() returns a new DataFrame, so the caller's frame does not
    # grow a 'dataset' column; partition() keeps the whole string (rather than
    # chopping its last character) if a path has no '/'.
    df = df.assign(dataset=df.index.map(lambda p: p.partition('/')[0]))
    unique_datasets = df['dataset'].unique()

    if datasets is not None:
        known_datasets = set(unique_datasets)
        for ds in datasets:
            assert ds in known_datasets, \
                f'dataset {ds} does not appear in any crop path'
        df = df[df['dataset'].isin(datasets)]  # filter by dataset
    else:
        datasets = unique_datasets

    # randomly sample images for each class
    if samples_per_label is not None:
        print(f'Sampling {samples_per_label} bounding boxes per label')
        df = df.groupby('label').sample(samples_per_label, random_state=seed)

    # load queried images JSON, needed for ground-truth bbox info
    with open(queried_images_json_path, 'r') as f:
        queried_images_js = json.load(f)

    merged_js: dict[str, Any] = {
        'images': {},  # start as dict, will convert to list later
        'info': {}
    }
    images = merged_js['images']

    for ds in datasets:
        print('processing dataset:', ds)
        ds_df = df[df['dataset'] == ds]

        with open(os.path.join(detection_cache_dir, f'{ds}.json'), 'r') as f:
            detection_js = json.load(f)
        img_file_to_index = {
            im['file']: idx
            for idx, im in enumerate(detection_js['images'])
        }

        # every dataset must have been processed by the same detector
        class_info = merged_js['info']
        detection_info = detection_js['info']
        key = 'detector'
        if key not in class_info:
            class_info[key] = detection_info[key]
        assert class_info[key] == detection_info[key]

        # ...and must use the same detection categories
        key = 'detection_categories'
        if key not in merged_js:
            merged_js[key] = detection_js[key]
        assert merged_js[key] == detection_js[key]
        cat_to_catid = {v: k for k, v in detection_js[key].items()}

        for crop_path in tqdm(ds_df.index):
            # crop_path: <dataset>/<img_file>___cropXX_mdvY.Y.jpg
            #            [----<img_path>----]       [-<suffix>--]
            #
            # Split on the last marker only, so an image file name that
            # happens to contain '___crop' still unpacks into two parts.
            img_path, suffix = crop_path.rsplit('___crop', 1)

            # each image is added once, by the first of its crops we see
            if img_path in images:
                continue

            if '_mdv' in suffix:
                # file has detection entry in the detector output cache
                img_file = img_path[img_path.find('/') + 1:]
                img_idx = img_file_to_index[img_file]
                images[img_path] = detection_js['images'][img_idx]
                images[img_path]['file'] = img_path
            else:
                # bounding box is from ground truth
                images[img_path] = {
                    'file': img_path,
                    'detections': [
                        {
                            'category': cat_to_catid[bbox_dict['category']],
                            'conf': 1.0,
                            'bbox': bbox_dict['bbox']
                        }
                        for bbox_dict in queried_images_js[img_path]['bbox']
                    ]
                }

    return merged_js
265
+
266
+
267
def combine_classification_with_detection(
        detection_js: dict[str, Any],
        df: pd.DataFrame,
        idx_to_label: Mapping[str, str],
        label_names: Sequence[str],
        classifier_name: str,
        classifier_timestamp: str,
        threshold: float,
        label_pos: str | None = None,
        relative_conf: bool = False,
        typical_confidence_threshold: float | None = None
        ) -> dict[str, Any]:
    """
    Adds classification information to a detection JSON. Classification
    information may include the true label and/or the predicted confidences
    of each label.

    Crops whose image is not present in detection_js['images'] (for example
    because the detections were restricted to a subset of datasets) are
    skipped, and the number of skipped crops is printed.

    Args:
        detection_js: dict, detections JSON file, except that the 'images'
            field is a dict (img_path => dict) instead of a list; modified
            in place
        df: pd.DataFrame, classification results, indexed by crop path
        idx_to_label: dict, str(label_id) => label name, may also include
            str(label_id + 1e6) => 'label: {label_name}'
        label_names: list of str, label names
        classifier_name: str, name of classifier to include in output JSON
        classifier_timestamp: str, timestamp to include in output JSON
        threshold: float, for each crop, omit classification results for
            categories whose confidence is below this threshold
        label_pos: one of [None, 'first', 'last']
            None: do not include labels in the output JSON
            'first' / 'last': position in classification list to put the label
        relative_conf: bool, if True then for each class, outputs its relative
            confidence over the confidence of the true label, requires 'label'
            to be in CSV
        typical_confidence_threshold: float, useful default confidence
            threshold; not used directly, just passed along to the output file

    Returns: dict, detections JSON file updated with classification results,
        with 'images' converted back to a list
    """

    classification_metadata = {
        'classifier': classifier_name,
        'classification_completion_time': classifier_timestamp
    }

    if typical_confidence_threshold is not None:
        classification_metadata['classifier_metadata'] = \
            {'typical_classification_threshold': typical_confidence_threshold}

    detection_js['info'].update(classification_metadata)
    detection_js['classification_categories'] = idx_to_label

    # predictions are present only if every label has its own column
    contains_preds = (set(label_names) <= set(df.columns))
    if not contains_preds:
        print('CSV does not contain predictions. Outputting labels only.')

    images = detection_js['images']
    num_skipped = 0

    for crop_path in tqdm(df.index):
        # crop_path: <dataset>/<img_file>___cropXX_mdvY.Y.jpg
        #            [----<img_path>----]       [-<suffix>--]
        img_path, suffix = crop_path.rsplit('___crop', 1)

        # The detections may cover fewer images than the CSV (dataset
        # filtering, sampling), so a crop without a matching image is not an
        # error.
        if img_path not in images:
            num_skipped += 1
            continue

        # The crop index is the run of digits that starts the suffix; reading
        # all of them (rather than a fixed two characters) keeps indices of
        # 100 and above intact.
        num_digits = len(suffix) - len(suffix.lstrip('0123456789'))
        crop_index = int(suffix[:num_digits])

        detection_dict = images[img_path]['detections'][crop_index]
        detection_dict['classifications'] = row_to_classification_list(
            row=df.loc[crop_path], label_names=label_names,
            contains_preds=contains_preds, label_pos=label_pos,
            threshold=threshold, relative_conf=relative_conf)

    if num_skipped > 0:
        print(f'Skipped {num_skipped} crops whose images are not in the '
              'detection results')

    detection_js['images'] = list(images.values())
    return detection_js
339
+
340
+
341
+ #%% Main function
342
+
343
def main(classification_csv_path: str,
         label_names_json_path: str,
         output_json_path: str,
         classifier_name: str,
         threshold: float,
         datasets: Sequence[str] | None,
         detection_json_path: str | None,
         queried_images_json_path: str | None,
         detector_output_cache_base_dir: str | None,
         detector_version: str | None,
         samples_per_label: int | None,
         seed: int,
         label_pos: str | None,
         relative_conf: bool,
         typical_confidence_threshold: float | None) -> None:
    """
    Merges a classification CSV with detection results and writes the merged
    results to a JSON file.

    Detections come from the queried images JSON (plus the detector output
    cache) if queried_images_json_path is given, otherwise from
    detection_json_path. At least one of the two must be provided.

    Args:
        classification_csv_path: str, path to classification CSV with a 'path'
            column
        label_names_json_path: str, path to JSON file mapping str(label index)
            to label name
        output_json_path: str, where to write the merged JSON
        classifier_name: str, name of classifier to include in output JSON
        threshold: float in [0, 1], classification results below this
            confidence are omitted
        datasets: optional list of str, limit output to these datasets
        detection_json_path: optional str, path to detections JSON
        queried_images_json_path: optional str, path to queried images JSON
        detector_output_cache_base_dir: str, required with
            queried_images_json_path
        detector_version: str, required with queried_images_json_path
        samples_per_label: optional int, only used with
            queried_images_json_path
        seed: int, random seed for sampling
        label_pos: one of [None, 'first', 'last']
        relative_conf: bool, output confidences relative to the true label
        typical_confidence_threshold: optional float, passed along to the
            output file

    Raises:
        ValueError: if neither detection_json_path nor
            queried_images_json_path is given
    """

    # input validation
    assert os.path.exists(classification_csv_path)
    assert os.path.exists(label_names_json_path)
    assert 0 <= threshold <= 1
    for x in [detection_json_path, queried_images_json_path]:
        if x is not None:
            assert os.path.exists(x)
    assert label_pos in [None, 'first', 'last']

    # Fail up front with a clear message; otherwise the missing detections
    # would only surface later as an unbound-variable error.
    if detection_json_path is None and queried_images_json_path is None:
        raise ValueError(
            'Either a detections JSON or a queried images JSON is required')

    # load classification CSV
    print('Loading classification CSV...')
    df = pd.read_csv(classification_csv_path, float_precision='high',
                     index_col='path')
    if relative_conf or label_pos is not None:
        assert 'label' in df.columns

    # load label names
    with open(label_names_json_path, 'r') as f:
        idx_to_label = json.load(f)
    label_names = [idx_to_label[str(i)] for i in range(len(idx_to_label))]
    if 'label' in df.columns:
        # ground-truth labels get their own category IDs, offset by 1e6
        for i, label in enumerate(label_names):
            idx_to_label[str(i + 1_000_000)] = f'label: {label}'

    if queried_images_json_path is not None:
        assert detector_output_cache_base_dir is not None
        assert detector_version is not None
        detection_js = process_queried_images(
            df=df, queried_images_json_path=queried_images_json_path,
            detector_output_cache_base_dir=detector_output_cache_base_dir,
            detector_version=detector_version, datasets=datasets,
            samples_per_label=samples_per_label, seed=seed)
    else:
        with open(detection_json_path, 'r') as f:
            detection_js = json.load(f)

        # re-key the image list by file path, keeping only requested datasets
        images = {}
        for img in detection_js['images']:
            path = img['file']
            if datasets is None or path.partition('/')[0] in datasets:
                images[path] = img
        detection_js['images'] = images

    # Use a datetime (not a date) so the time of day survives formatting.
    classification_time = datetime.datetime.fromtimestamp(
        os.path.getmtime(classification_csv_path))
    classifier_timestamp = classification_time.strftime('%Y-%m-%d %H:%M:%S')

    classification_js = combine_classification_with_detection(
        detection_js=detection_js, df=df, idx_to_label=idx_to_label,
        label_names=label_names, classifier_name=classifier_name,
        classifier_timestamp=classifier_timestamp, threshold=threshold,
        label_pos=label_pos, relative_conf=relative_conf,
        typical_confidence_threshold=typical_confidence_threshold)

    # A bare file name has no directory part, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    ct_utils.write_json(output_json_path, classification_js)

    print('Wrote merged classification/detection results to {}'.format(output_json_path))
419
+
420
+
421
+ #%% Command-line driver
422
+
423
def _parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """
    Builds the command-line parser for this script and parses arguments.

    Args:
        argv: optional list of argument strings to parse instead of the
            process command line; None (the default) parses sys.argv[1:]

    Returns: argparse.Namespace with one attribute per option
    """

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Merges classification results with Batch Detection API '
                    'outputs.')
    parser.add_argument(
        'classification_csv',
        help='path to classification CSV')
    parser.add_argument(
        'label_names_json',
        help='path to JSON file mapping label index to label name')
    parser.add_argument(
        '-o', '--output-json', required=True,
        help='(required) path to save output JSON with both detection and '
             'classification results')
    parser.add_argument(
        '-n', '--classifier-name', required=True,
        help='(required) name of classifier')
    parser.add_argument(
        '-t', '--threshold', type=float, default=0.1,
        help='Confidence threshold between 0 and 1. In the output file, omit '
             'classifier results on classes whose confidence is below this '
             'threshold.')
    parser.add_argument(
        '-d', '--datasets', nargs='*',
        help='optionally limit output to images from certain datasets. Assumes '
             'that image paths are given as <dataset>/<img_file>.')
    parser.add_argument(
        '--typical-confidence-threshold', type=float, default=None,
        help='useful default confidence threshold; not used directly, just '
             'passed along to the output file')

    detection_json_group = parser.add_argument_group(
        'arguments for passing in a detections JSON file')
    detection_json_group.add_argument(
        '-j', '--detection-json',
        help='path to detections JSON file')

    queried_images_group = parser.add_argument_group(
        'arguments for passing in a queried images JSON file')
    queried_images_group.add_argument(
        '-q', '--queried-images-json',
        help='path to queried images JSON file')
    queried_images_group.add_argument(
        '-c', '--detector-output-cache-dir',
        help='(required) path to directory where detector outputs are cached')
    queried_images_group.add_argument(
        '-v', '--detector-version',
        help='(required) detector version string, e.g., "4.1"')
    queried_images_group.add_argument(
        '-s', '--samples-per-label', type=int,
        help='randomly sample this many bounding boxes per label (each label '
             'must have at least this many examples)')
    queried_images_group.add_argument(
        '--seed', type=int, default=123,
        help='random seed, only used if --samples-per-label is given')
    queried_images_group.add_argument(
        '--label', choices=['first', 'last'], default=None,
        help='Whether to put the label first or last in the list of '
             'classifications. If this argument is omitted, then no labels are '
             'included in the output.')
    queried_images_group.add_argument(
        '--relative-conf', action='store_true',
        help='for each class, outputs its relative confidence over the '
             'confidence of the true label, requires "label" to be in CSV')

    # parse_args(None) falls back to sys.argv[1:], so existing callers that
    # pass nothing behave exactly as before
    return parser.parse_args(argv)
490
+
491
+
492
if __name__ == '__main__':

    # Parse the command line once, then hand every option to main() by
    # keyword so the mapping from CLI flag to parameter is explicit.
    cli_args = _parse_args()
    main_kwargs = {
        'classification_csv_path': cli_args.classification_csv,
        'label_names_json_path': cli_args.label_names_json,
        'output_json_path': cli_args.output_json,
        'classifier_name': cli_args.classifier_name,
        'threshold': cli_args.threshold,
        'datasets': cli_args.datasets,
        'detection_json_path': cli_args.detection_json,
        'queried_images_json_path': cli_args.queried_images_json,
        'detector_output_cache_base_dir': cli_args.detector_output_cache_dir,
        'detector_version': cli_args.detector_version,
        'samples_per_label': cli_args.samples_per_label,
        'seed': cli_args.seed,
        'label_pos': cli_args.label,
        'relative_conf': cli_args.relative_conf,
        'typical_confidence_threshold': cli_args.typical_confidence_threshold,
    }
    main(**main_kwargs)