megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,509 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
merge_classification_detection_output.py
|
|
4
|
+
|
|
5
|
+
Merges classification results with Batch Detection API outputs.
|
|
6
|
+
|
|
7
|
+
This script takes 2 main files as input:
|
|
8
|
+
|
|
9
|
+
1) Either a "dataset CSV" (output of create_classification_dataset.py) or a
|
|
10
|
+
"classification results CSV" (output of evaluate_model.py). The CSV is
|
|
11
|
+
expected to have columns listed below. The 'label' and [label names] columns
|
|
12
|
+
are optional, but at least one of them must be provided.
|
|
13
|
+
* 'path': str, path to cropped image
|
|
14
|
+
* if passing in a detections JSON, must match
|
|
15
|
+
<img_file>___cropXX_mdvY.Y.jpg
|
|
16
|
+
* if passing in a queried images JSON, must match
|
|
17
|
+
<dataset>/<img_file>___cropXX_mdvY.Y.jpg or
|
|
18
|
+
<dataset>/<img_file>___cropXX.jpg
|
|
19
|
+
* 'label': str, label assigned to this crop
|
|
20
|
+
* [label names]: float, confidence in each label
|
|
21
|
+
|
|
22
|
+
2) Either a "detections JSON" (output of MegaDetector) or a "queried images
|
|
23
|
+
JSON" (output of json_validator.py).
|
|
24
|
+
|
|
25
|
+
If the CSV contains [label names] columns (e.g., output of evaluate_model.py),
|
|
26
|
+
then each crop's "classifications" output will have one value per category.
|
|
27
|
+
Categories are sorted decreasing by confidence.
|
|
28
|
+
"classifications": [
|
|
29
|
+
["3", 0.901],
|
|
30
|
+
["1", 0.071],
|
|
31
|
+
["4", 0.025],
|
|
32
|
+
["2", 0.003],
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
If the CSV only contains the 'label' column (e.g., output of
|
|
36
|
+
create_classification_dataset.py), then each crop's "classifications" output
|
|
37
|
+
will have only one value, with a confidence of 1.0. The label's classification
|
|
38
|
+
category ID is always greater than 1,000,000, to distinguish it from a predicted
|
|
39
|
+
category ID.
|
|
40
|
+
"classifications": [
|
|
41
|
+
["1000004", 1.0]
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
If the CSV contains both [label names] and 'label' columns, then both the
|
|
45
|
+
predicted categories and the label category are included if --label is given. With "--label last", the
|
|
46
|
+
label category is placed last; with "--label first", the
|
|
47
|
+
label category is placed first in the results.
|
|
48
|
+
"classifications": [
|
|
49
|
+
["1000004", 1.0], # label put first if "--label first" is given
|
|
50
|
+
["3", 0.901], # all other results are sorted by confidence
|
|
51
|
+
["1", 0.071],
|
|
52
|
+
["4", 0.025],
|
|
53
|
+
["2", 0.003]
|
|
54
|
+
]
|
|
55
|
+
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
#%% Imports
|
|
59
|
+
|
|
60
|
+
from __future__ import annotations
|
|
61
|
+
|
|
62
|
+
import argparse
|
|
63
|
+
import datetime
|
|
64
|
+
import json
|
|
65
|
+
import os
|
|
66
|
+
|
|
67
|
+
from collections.abc import Mapping, Sequence
|
|
68
|
+
from typing import Any
|
|
69
|
+
|
|
70
|
+
import pandas as pd
|
|
71
|
+
from tqdm import tqdm
|
|
72
|
+
|
|
73
|
+
from megadetector.utils.ct_utils import round_float
|
|
74
|
+
from megadetector.utils import ct_utils
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
#%% Example usage
|
|
78
|
+
|
|
79
|
+
"""
|
|
80
|
+
python merge_classification_detection_output.py \
|
|
81
|
+
BASE_LOGDIR/LOGDIR/outputs_test.csv.gz \
|
|
82
|
+
BASE_LOGDIR/label_index.json \
|
|
83
|
+
BASE_LOGDIR/queried_images.json \
|
|
84
|
+
--classifier-name "efficientnet-b3-idfg-moredata" \
|
|
85
|
+
--detector-output-cache-dir $HOME/classifier-training/mdcache \
|
|
86
|
+
--detector-version "4.1" \
|
|
87
|
+
--output-json BASE_LOGDIR/LOGDIR/classifier_results.json \
|
|
88
|
+
--datasets idfg idfg_swwlf_2019
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
#%% Support functions
|
|
93
|
+
|
|
94
|
+
def row_to_classification_list(row: Mapping[str, Any],
                               label_names: Sequence[str],
                               contains_preds: bool,
                               label_pos: str | None,
                               threshold: float,
                               relative_conf: bool = False
                               ) -> list[tuple[str, float]]:
    """
    Converts one row of classification results into the list of
    (str(category_id), confidence) tuples used in the Batch API output format.

    When [contains_preds] is True, the row is expected to hold one confidence
    per name in [label_names]; the category ID of a label is its position in
    [label_names]. Predictions below [threshold] are dropped, the rest are
    rounded to 4 places and ordered from highest to lowest confidence.

    When the row has a 'label' entry and [label_pos] is not None, an extra
    item (str(label_id + 1_000_000), 1.) is added for the true label: at the
    front of the list if label_pos == 'first', otherwise at the end.

    Args:
        row: mapping from column name to value for a single crop
        label_names: label names, ordered by label ID
        contains_preds: whether [row] has a confidence for every label name
        label_pos: None, 'first', or 'last'; see above
        threshold: minimum confidence for a prediction to be kept
        relative_conf: if True, each confidence is replaced by its margin over
            the confidence of the true label (clipped at 0) before
            thresholding; requires both a 'label' entry and predictions

    Returns:
        list of (str(category_id), confidence) tuples
    """

    has_label = 'label' in row
    assert has_label or contains_preds
    if relative_conf:
        assert has_label and contains_preds

    classifications = []

    if contains_preds:

        scored = [
            (str(label_id), row[name])
            for label_id, name in enumerate(label_names)
        ]

        if relative_conf:
            # confidence the classifier assigned to the true label
            true_label_conf = row[row['label']]
            scored = [
                (cat_id, max(conf - true_label_conf, 0))
                for cat_id, conf in scored
            ]

        # drop low-confidence categories, keep 4 places of precision
        kept = []
        for cat_id, conf in scored:
            if conf >= threshold:
                kept.append((cat_id, round_float(conf, precision=4)))

        # highest confidence first (stable for ties)
        kept.sort(key=lambda pair: pair[1], reverse=True)
        classifications = kept

    if has_label and label_pos is not None:
        true_label = row['label']
        true_label_id = label_names.index(true_label)
        label_entry = (str(true_label_id + 1_000_000), 1.)
        if label_pos == 'first':
            classifications.insert(0, label_entry)
        else:
            classifications.append(label_entry)

    return classifications
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def process_queried_images(
        df: pd.DataFrame,
        queried_images_json_path: str,
        detector_output_cache_base_dir: str,
        detector_version: str,
        datasets: Sequence[str] | None = None,
        samples_per_label: int | None = None,
        seed: int = 123
        ) -> dict[str, Any]:
    """
    Creates a detection JSON object roughly in the Batch API detection
    format.

    Detections are either ground-truth (from the queried images JSON) or
    retrieved from the detector output cache. Only images corresponding to crop
    paths from the given pd.DataFrame are included in the detection JSON.

    Cached detector output for a dataset is read from
    <detector_output_cache_base_dir>/v<detector_version>/<dataset>.json.

    The caller's DataFrame is not modified; the per-crop dataset name is
    computed on an internal copy.

    Args:
        df: pd.DataFrame, either a "classification dataset CSV" or a
            "classification results CSV", indexed by crop path, which has
            format <dataset>/<img_file>___cropXX[...].jpg
        queried_images_json_path: str, path to queried images JSON
        detector_output_cache_base_dir: str
        detector_version: str
        datasets: optional list of str, only crops from these datasets will
            be included in the output, set to None to include all datasets
        samples_per_label: optional int, if not None, then randomly sample this
            many bounding boxes per label (each label must have at least this
            many examples)
        seed: int, used for random sampling if samples_per_label is not None

    Returns: dict, detections JSON file, except that the 'images' field is a
        dict (img_path => dict) instead of a list
    """

    # input validation
    assert os.path.exists(queried_images_json_path)
    detection_cache_dir = os.path.join(
        detector_output_cache_base_dir, f'v{detector_version}')
    assert os.path.isdir(detection_cache_dir)

    # Extract dataset name from crop path so we can process 1 dataset at a time.
    #
    # assign() returns a new frame, so the caller's DataFrame does not gain a
    # 'dataset' column as a side effect. partition() takes the text before the
    # first '/', and (unlike slicing with find()) does not chop the last
    # character off a path that has no '/'.
    df = df.assign(dataset=df.index.map(lambda x: x.partition('/')[0]))
    unique_datasets = df['dataset'].unique()

    if datasets is not None:
        for ds in datasets:
            assert ds in unique_datasets
        df = df[df['dataset'].isin(datasets)]  # filter by dataset
    else:
        datasets = unique_datasets

    # randomly sample images for each class
    if samples_per_label is not None:
        print(f'Sampling {samples_per_label} bounding boxes per label')
        df = df.groupby('label').sample(samples_per_label, random_state=seed)

    # load queried images JSON, needed for ground-truth bbox info
    with open(queried_images_json_path, 'r') as f:
        queried_images_js = json.load(f)

    merged_js: dict[str, Any] = {
        'images': {},  # start as dict, will convert to list later
        'info': {}
    }
    images = merged_js['images']

    for ds in datasets:

        print('processing dataset:', ds)
        ds_df = df[df['dataset'] == ds]

        with open(os.path.join(detection_cache_dir, f'{ds}.json'), 'r') as f:
            detection_js = json.load(f)
        img_file_to_index = {
            im['file']: idx
            for idx, im in enumerate(detection_js['images'])
        }

        # every dataset must have been processed by the same detector
        class_info = merged_js['info']
        detection_info = detection_js['info']
        key = 'detector'
        if key not in class_info:
            class_info[key] = detection_info[key]
        assert class_info[key] == detection_info[key]

        # every dataset must use the same detection categories
        key = 'detection_categories'
        if key not in merged_js:
            merged_js[key] = detection_js[key]
        assert merged_js[key] == detection_js[key]

        # category name => category ID, for converting ground-truth boxes
        cat_to_catid = {v: k for k, v in detection_js[key].items()}

        for crop_path in tqdm(ds_df.index):

            # crop_path: <dataset>/<img_file>___cropXX_mdvY.Y.jpg
            #            [----<img_path>----]       [-<suffix>--]
            img_path, suffix = crop_path.split('___crop')

            # an image is added once, by the first of its crops we encounter
            if img_path in images:
                continue

            if '_mdv' in suffix:
                # crop came from the detector: copy the cached image entry,
                # re-keyed by the dataset-qualified path
                img_file = img_path[img_path.find('/') + 1:]
                img_idx = img_file_to_index[img_file]
                images[img_path] = detection_js['images'][img_idx]
                images[img_path]['file'] = img_path

            else:
                # bounding box is from ground truth
                images[img_path] = {
                    'file': img_path,
                    'detections': [
                        {
                            'category': cat_to_catid[bbox_dict['category']],
                            'conf': 1.0,
                            'bbox': bbox_dict['bbox']
                        }
                        for bbox_dict in queried_images_js[img_path]['bbox']
                    ]
                }

    return merged_js
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _crop_index_from_suffix(suffix: str) -> int:
    """
    Parses the crop index from the run of digits at the start of a crop-path
    suffix (the text following '___crop'), e.g. '07_mdv4.1.jpg' => 7.

    Raises ValueError if the suffix does not start with a digit.
    """

    n_digits = len(suffix) - len(suffix.lstrip('0123456789'))
    return int(suffix[:n_digits])


def combine_classification_with_detection(
        detection_js: dict[str, Any],
        df: pd.DataFrame,
        idx_to_label: Mapping[str, str],
        label_names: Sequence[str],
        classifier_name: str,
        classifier_timestamp: str,
        threshold: float,
        label_pos: str | None = None,
        relative_conf: bool = False,
        typical_confidence_threshold: float | None = None
        ) -> dict[str, Any]:
    """
    Adds classification information to a detection JSON. Classification
    information may include the true label and/or the predicted confidences
    of each label.

    [detection_js] is modified in place and also returned. Crops in [df] whose
    parent image is not present in detection_js['images'] (for example because
    the images were restricted to a subset of datasets) are skipped, and the
    number of skipped crops is printed.

    Args:
        detection_js: dict, detections JSON file, except that the 'images'
            field is a dict (img_path => dict) instead of a list
        df: pd.DataFrame, classification results, indexed by crop path
        idx_to_label: dict, str(label_id) => label name, may also include
            str(label_id + 1e6) => 'label: {label_name}'
        label_names: list of str, label names
        classifier_name: str, name of classifier to include in output JSON
        classifier_timestamp: str, timestamp to include in output JSON
        threshold: float, for each crop, omit classification results for
            categories whose confidence is below this threshold
        label_pos: one of [None, 'first', 'last']
            None: do not include labels in the output JSON
            'first' / 'last': position in classification list to put the label
        relative_conf: bool, if True then for each class, outputs its relative
            confidence over the confidence of the true label, requires 'label'
            to be in CSV
        typical_confidence_threshold: float, useful default confidence
            threshold; not used directly, just passed along to the output file

    Returns: dict, detections JSON file updated with classification results,
        with 'images' converted back to a list
    """

    classification_metadata = {
        'classifier': classifier_name,
        'classification_completion_time': classifier_timestamp
    }

    if typical_confidence_threshold is not None:
        classification_metadata['classifier_metadata'] = \
            {'typical_classification_threshold': typical_confidence_threshold}

    detection_js['info'].update(classification_metadata)
    detection_js['classification_categories'] = idx_to_label

    # predictions are present only if every label name is a column
    contains_preds = (set(label_names) <= set(df.columns))
    if not contains_preds:
        print('CSV does not contain predictions. Outputting labels only.')

    images = detection_js['images']
    n_skipped = 0

    for crop_path in tqdm(df.index):

        # crop_path: <dataset>/<img_file>___cropXX_mdvY.Y.jpg
        #            [----<img_path>----]       [-<suffix>--]
        img_path, suffix = crop_path.rsplit('___crop', 1)

        # The image set may have been filtered (by dataset or by sampling)
        # independently of the classification table.
        if img_path not in images:
            n_skipped += 1
            continue

        # Read all leading digits rather than a fixed two characters, so an
        # index wider than two digits is not truncated.
        crop_index = _crop_index_from_suffix(suffix)

        detection_dict = images[img_path]['detections'][crop_index]
        detection_dict['classifications'] = row_to_classification_list(
            row=df.loc[crop_path], label_names=label_names,
            contains_preds=contains_preds, label_pos=label_pos,
            threshold=threshold, relative_conf=relative_conf)

    if n_skipped > 0:
        print(f'Skipped {n_skipped} crops whose image is not in the '
              'detection results')

    detection_js['images'] = list(images.values())
    return detection_js
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
#%% Main function
|
|
342
|
+
|
|
343
|
+
def main(classification_csv_path: str,
         label_names_json_path: str,
         output_json_path: str,
         classifier_name: str,
         threshold: float,
         datasets: Sequence[str] | None,
         detection_json_path: str | None,
         queried_images_json_path: str | None,
         detector_output_cache_base_dir: str | None,
         detector_version: str | None,
         samples_per_label: int | None,
         seed: int,
         label_pos: str | None,
         relative_conf: bool,
         typical_confidence_threshold: float | None) -> None:
    """
    Merges a classification CSV with detection results and writes the
    combined results to a JSON file.

    Detection results come from exactly one of two sources. If
    [queried_images_json_path] is given it takes precedence, and
    [detector_output_cache_base_dir] and [detector_version] are then required.
    Otherwise [detection_json_path] is used.

    Args:
        classification_csv_path: path to the classification CSV; must have a
            'path' column, which is used as the index
        label_names_json_path: path to a JSON file mapping str(label index)
            to label name
        output_json_path: path to write the merged JSON file to
        classifier_name: name of the classifier, recorded in the output
        threshold: confidence threshold in [0, 1]; predictions below it are
            omitted from the output
        datasets: optional list of dataset names to restrict the output to
        detection_json_path: optional path to a detections JSON file
        queried_images_json_path: optional path to a queried images JSON file
        detector_output_cache_base_dir: directory where detector outputs are
            cached, only used with [queried_images_json_path]
        detector_version: detector version string, only used with
            [queried_images_json_path]
        samples_per_label: optional number of bounding boxes to sample per
            label, only used with [queried_images_json_path]
        seed: random seed for sampling
        label_pos: one of [None, 'first', 'last']; where to put the true label
            in each classification list, None to omit it
        relative_conf: whether to output confidences relative to the
            confidence of the true label
        typical_confidence_threshold: optional threshold that is passed along
            to the output file

    Raises:
        ValueError: if neither [detection_json_path] nor
            [queried_images_json_path] is provided
    """

    # input validation
    assert os.path.exists(classification_csv_path)
    assert os.path.exists(label_names_json_path)
    assert 0 <= threshold <= 1
    for x in [detection_json_path, queried_images_json_path]:
        if x is not None:
            assert os.path.exists(x)
    assert label_pos in [None, 'first', 'last']
    if queried_images_json_path is None and detection_json_path is None:
        raise ValueError(
            'Either a detections JSON or a queried images JSON is required')

    # load classification CSV
    print('Loading classification CSV...')
    df = pd.read_csv(classification_csv_path, float_precision='high',
                     index_col='path')
    if relative_conf or label_pos is not None:
        assert 'label' in df.columns

    # load label names; keys are expected to be '0' .. str(n - 1)
    with open(label_names_json_path, 'r') as f:
        idx_to_label = json.load(f)
    label_names = [idx_to_label[str(i)] for i in range(len(idx_to_label))]
    if 'label' in df.columns:
        # true labels get their own category IDs, offset by 1,000,000
        for i, label in enumerate(label_names):
            idx_to_label[str(i + 1_000_000)] = f'label: {label}'

    if queried_images_json_path is not None:
        assert detector_output_cache_base_dir is not None
        assert detector_version is not None
        detection_js = process_queried_images(
            df=df, queried_images_json_path=queried_images_json_path,
            detector_output_cache_base_dir=detector_output_cache_base_dir,
            detector_version=detector_version, datasets=datasets,
            samples_per_label=samples_per_label, seed=seed)
    else:
        with open(detection_json_path, 'r') as f:
            detection_js = json.load(f)
        # re-key the image list by file path, optionally filtering by dataset
        images = {}
        for img in detection_js['images']:
            path = img['file']
            if datasets is None or path.partition('/')[0] in datasets:
                images[path] = img
        detection_js['images'] = images

    # Use the CSV's modification time as the classification time. This has to
    # be a datetime (not a date), otherwise the time of day in the formatted
    # string is always 00:00:00.
    classification_time = datetime.datetime.fromtimestamp(
        os.path.getmtime(classification_csv_path))
    classifier_timestamp = classification_time.strftime('%Y-%m-%d %H:%M:%S')

    classification_js = combine_classification_with_detection(
        detection_js=detection_js, df=df, idx_to_label=idx_to_label,
        label_names=label_names, classifier_name=classifier_name,
        classifier_timestamp=classifier_timestamp, threshold=threshold,
        label_pos=label_pos, relative_conf=relative_conf,
        typical_confidence_threshold=typical_confidence_threshold)

    # os.makedirs('') raises, so only create the folder when the output path
    # actually has a directory component
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    ct_utils.write_json(output_json_path, classification_js)

    print('Wrote merged classification/detection results to {}'.format(output_json_path))
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
#%% Command-line driver
|
|
422
|
+
|
|
423
|
+
def _parse_args() -> argparse.Namespace:
|
|
424
|
+
|
|
425
|
+
parser = argparse.ArgumentParser(
|
|
426
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
427
|
+
description='Merges classification results with Batch Detection API '
|
|
428
|
+
'outputs.')
|
|
429
|
+
parser.add_argument(
|
|
430
|
+
'classification_csv',
|
|
431
|
+
help='path to classification CSV')
|
|
432
|
+
parser.add_argument(
|
|
433
|
+
'label_names_json',
|
|
434
|
+
help='path to JSON file mapping label index to label name')
|
|
435
|
+
parser.add_argument(
|
|
436
|
+
'-o', '--output-json', required=True,
|
|
437
|
+
help='(required) path to save output JSON with both detection and '
|
|
438
|
+
'classification results')
|
|
439
|
+
parser.add_argument(
|
|
440
|
+
'-n', '--classifier-name', required=True,
|
|
441
|
+
help='(required) name of classifier')
|
|
442
|
+
parser.add_argument(
|
|
443
|
+
'-t', '--threshold', type=float, default=0.1,
|
|
444
|
+
help='Confidence threshold between 0 and 1. In the output file, omit '
|
|
445
|
+
'classifier results on classes whose confidence is below this '
|
|
446
|
+
'threshold.')
|
|
447
|
+
parser.add_argument(
|
|
448
|
+
'-d', '--datasets', nargs='*',
|
|
449
|
+
help='optionally limit output to images from certain datasets. Assumes '
|
|
450
|
+
'that image paths are given as <dataset>/<img_file>.')
|
|
451
|
+
parser.add_argument(
|
|
452
|
+
'--typical-confidence-threshold', type=float, default=None,
|
|
453
|
+
help='useful default confidence threshold; not used directly, just '
|
|
454
|
+
'passed along to the output file')
|
|
455
|
+
|
|
456
|
+
detection_json_group = parser.add_argument_group(
|
|
457
|
+
'arguments for passing in a detections JSON file')
|
|
458
|
+
detection_json_group.add_argument(
|
|
459
|
+
'-j', '--detection-json',
|
|
460
|
+
help='path to detections JSON file')
|
|
461
|
+
|
|
462
|
+
queried_images_group = parser.add_argument_group(
|
|
463
|
+
'arguments for passing in a queried images JSON file')
|
|
464
|
+
queried_images_group.add_argument(
|
|
465
|
+
'-q', '--queried-images-json',
|
|
466
|
+
help='path to queried images JSON file')
|
|
467
|
+
queried_images_group.add_argument(
|
|
468
|
+
'-c', '--detector-output-cache-dir',
|
|
469
|
+
help='(required) path to directory where detector outputs are cached')
|
|
470
|
+
queried_images_group.add_argument(
|
|
471
|
+
'-v', '--detector-version',
|
|
472
|
+
help='(required) detector version string, e.g., "4.1"')
|
|
473
|
+
queried_images_group.add_argument(
|
|
474
|
+
'-s', '--samples-per-label', type=int,
|
|
475
|
+
help='randomly sample this many bounding boxes per label (each label '
|
|
476
|
+
'must have at least this many examples)')
|
|
477
|
+
queried_images_group.add_argument(
|
|
478
|
+
'--seed', type=int, default=123,
|
|
479
|
+
help='random seed, only used if --samples-per-label is given')
|
|
480
|
+
queried_images_group.add_argument(
|
|
481
|
+
'--label', choices=['first', 'last'], default=None,
|
|
482
|
+
help='Whether to put the label first or last in the list of '
|
|
483
|
+
'classifications. If this argument is omitted, then no labels are '
|
|
484
|
+
'included in the output.')
|
|
485
|
+
queried_images_group.add_argument(
|
|
486
|
+
'--relative-conf', action='store_true',
|
|
487
|
+
help='for each class, outputs its relative confidence over the '
|
|
488
|
+
'confidence of the true label, requires "label" to be in CSV')
|
|
489
|
+
return parser.parse_args()
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
if __name__ == '__main__':

    args = _parse_args()

    # Map each parsed command-line option onto the matching keyword
    # argument of main()
    main_kwargs = {
        'classification_csv_path': args.classification_csv,
        'label_names_json_path': args.label_names_json,
        'output_json_path': args.output_json,
        'classifier_name': args.classifier_name,
        'threshold': args.threshold,
        'datasets': args.datasets,
        'detection_json_path': args.detection_json,
        'queried_images_json_path': args.queried_images_json,
        'detector_output_cache_base_dir': args.detector_output_cache_dir,
        'detector_version': args.detector_version,
        'samples_per_label': args.samples_per_label,
        'seed': args.seed,
        'label_pos': args.label,
        'relative_conf': args.relative_conf,
        'typical_confidence_threshold': args.typical_confidence_threshold,
    }
    main(**main_kwargs)
|