megadetector-10.0.15-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
megadetector/classification/evaluate_model.py

@@ -0,0 +1,520 @@
"""

evaluate_model.py

Evaluate a species classifier.

Currently the implementation of multi-label multi-class classification is
non-functional.

Outputs the following files:

1) outputs_{split}.csv, one file per split, contains columns:
- 'path': str, path to cropped image
- 'label': str
- 'weight': float
- [label names]: float, confidence in each label

2) overall_metrics.csv, contains columns:
- 'split': str
- 'loss': float, mean per-example loss over entire epoch
- 'acc_top{k}': float, accuracy@k over the entire epoch
- 'loss_weighted' and 'acc_weighted_top{k}': float, weighted versions

3) confusion_matrices.npz
- keys ['train', 'val', 'test']
- values are np.ndarray, confusion matrices

4) label_stats.csv, per-label statistics, columns
- 'split': str
- 'label': str
- 'precision': float
- 'recall': float

"""

#%% Imports and constants

from __future__ import annotations

import argparse
from collections.abc import Mapping, Sequence
import json
import os
from pprint import pprint
from typing import Any, Optional

import numpy as np
import pandas as pd
import sklearn.metrics
import torch
import torchvision
import tqdm

from megadetector.classification import efficientnet, train_classifier
from megadetector.utils import ct_utils


#%% Example usage

"""
python evaluate_model.py \
    $BASE_LOGDIR/$LOGDIR/params.json \
    $BASE_LOGDIR/$LOGDIR/ckpt_XX.pt \
    --output-dir $BASE_LOGDIR/$LOGDIR \
    --splits train val test \
    --batch-size 256
"""

SPLITS = ['train', 'val', 'test']


#%% Support functions

def check_override(params: Mapping[str, Any], key: str,
                   override: Optional[Any]) -> Any:
    """
    Return desired value, with optional override.
    """

    if override is None:
        return params[key]
    saved = params.get(key, None)
    print(f'Overriding saved {key}. Saved: {saved}. Override with: {override}.')
    return override


def trace_model(model_name: str, ckpt_path: str, num_classes: int,
                img_size: int) -> str:
    """
    Use TorchScript tracing to compile trained model into standalone file.

    For now, we have to use tracing instead of scripting. See
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/89
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/218

    Args:
        model_name: str
        ckpt_path: str, path to checkpoint file
        num_labels: int, number of classification classes
        img_size: int, size of input image, used for tracing

    Returns: str, name of file for compiled model. For example, if ckpt_path is
        '/path/to/ckpt_16.pt', then the returned path is
        '/path/to/ckpt_16_compiled.pt'.
    """

    root, ext = os.path.splitext(ckpt_path)
    compiled_path = root + '_compiled' + ext
    if os.path.exists(compiled_path):
        return compiled_path

    model = train_classifier.build_model(model_name, num_classes=num_classes,
                                         pretrained=ckpt_path, finetune=False)
    if 'efficientnet' in model_name:
        model.set_swish(memory_efficient=False)
    model.eval()

    ex_img = torch.rand(1, 3, img_size, img_size)
    scripted_model = torch.jit.trace(model, (ex_img,))

    scripted_model.save(compiled_path)
    print('Saved TorchScript compiled model to', compiled_path)
    return compiled_path


def calc_per_label_stats(cm: np.ndarray, label_names: Sequence[str]
                         ) -> pd.DataFrame:
    """
    Args:
        cm: np.ndarray, type int, confusion matrix C such that C[i,j] is the #
            of observations from group i that are predicted to be in group j
        label_names: list of str, label names in order of label id

    Returns: pd.DataFrame, index 'label', columns ['precision', 'recall']
        precision values are in [0, 1]
        recall values are in [0, 1], or np.nan if that label had 0 ground-truth
        observations
    """

    tp = np.diag(cm)  # true positives

    predicted_positives = cm.sum(axis=0, dtype=np.float64)  # tp + fp
    predicted_positives[predicted_positives == 0] += 1e-8

    all_positives = cm.sum(axis=1, dtype=np.float64)  # tp + fn
    all_positives[all_positives == 0] = np.nan

    df = pd.DataFrame()
    df['label'] = label_names
    df['precision'] = tp / predicted_positives
    df['recall'] = tp / all_positives
    df.set_index('label', inplace=True)
    return df


def test_epoch(model: torch.nn.Module,
               loader: torch.utils.data.DataLoader,
               weighted: bool,
               device: torch.device,
               label_names: Sequence[str],
               top: Sequence[int] = (1, 3),
               loss_fn: Optional[torch.nn.Module] = None,
               target_mapping: Mapping[int, Sequence[int]] = None
               ) -> tuple[pd.DataFrame, pd.Series, np.ndarray]:
    """
    Runs for 1 epoch.

    Args:
        model: torch.nn.Module
        loader: torch.utils.data.DataLoader
        weighted: bool, whether to calculate weighted accuracy statistics
        device: torch.device
        label_names: list of str, label names in order of label id
        top: tuple of int, list of values of k for calculating top-K accuracy
        loss_fn: optional loss function, calculates per-example loss
        target_mapping: optional dict, label_id => list of ids from classifier
            that should map to the label_id

    Returns:
        df: pd.DataFrame, columns ['img_file', 'label', 'weight', label_names]
        metrics: pd.Series, type float, index includes:
            'loss': mean per-example loss over entire epoch,
                only included if loss_fn is not None
            'acc_top{k}': accuracy@k over the entire epoch
            'loss_weighted' and 'acc_weighted_top{k}': weighted versions, only
                included if weighted=True
        cm: np.ndarray, confusion matrix C such that C[i,j] is the # of
            observations known to be in group i and predicted to be in group j
    """

    # set dropout and BN layers to eval mode
    model.eval()

    if loss_fn is not None:
        losses = train_classifier.AverageMeter()
    accuracies_topk = {k: train_classifier.AverageMeter() for k in top}  # acc@k
    if weighted:
        accs_weighted = {k: train_classifier.AverageMeter() for k in top}
        losses_weighted = train_classifier.AverageMeter()

    num_examples = len(loader.dataset)
    num_labels = len(label_names)

    all_img_files = []
    all_probs = np.zeros([num_examples, len(label_names)], dtype=np.float32)
    all_labels = np.zeros(num_examples, dtype=np.int32)
    if weighted:
        all_weights = np.zeros(num_examples, dtype=np.float32)

    batch_slice = slice(0, 0)
    tqdm_loader = tqdm.tqdm(loader)
    with torch.no_grad():
        for batch in tqdm_loader:
            if weighted:
                inputs, labels, img_files, weights = batch
            else:
                # even if batch contains sample weights, don't use them
                inputs, labels, img_files = batch[0:3]
                weights = None

            all_img_files.append(img_files)

            batch_size = labels.size(0)
            batch_slice = slice(batch_slice.stop, batch_slice.stop + batch_size)
            all_labels[batch_slice] = labels
            if weighted:
                all_weights[batch_slice] = weights
                weights = weights.to(device, non_blocking=True)

            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)

            # Do target mapping on the outputs (unnormalized logits) instead of
            # the normalized (softmax) probabilities, because the loss function
            # uses unnormalized logits. Summing probabilities is equivalent to
            # log-sum-exp of unnormalized logits.
            if target_mapping is not None:
                outputs_mapped = torch.zeros(
                    [batch_size, num_labels], dtype=outputs.dtype,
                    device=outputs.device)
                for target, cols in target_mapping.items():
                    outputs_mapped[:, target] = torch.logsumexp(
                        outputs[:, cols], dim=1)
                outputs = outputs_mapped

            probs = torch.nn.functional.softmax(outputs, dim=1).cpu()
            all_probs[batch_slice] = probs

            desc = []
            if loss_fn is not None:
                loss = loss_fn(outputs, labels)
                losses.update(loss.mean().item(), n=batch_size)
                desc.append(f'Loss {losses.val:.3f} ({losses.avg:.3f})')
                if weights is not None:
                    loss_weighted = (loss * weights).mean()
                    losses_weighted.update(loss_weighted.item(), n=batch_size)

            top_correct = train_classifier.correct(
                outputs, labels, weights=None, top=top)
            for k, acc in accuracies_topk.items():
                acc.update(top_correct[k] * (100. / batch_size), n=batch_size)
                desc.append(f'Acc@{k} {acc.val:.2f} ({acc.avg:.2f})')

            if weighted:
                top_correct = train_classifier.correct(
                    outputs, labels, weights=weights, top=top)
                for k, acc in accs_weighted.items():
                    acc.update(top_correct[k] * (100. / batch_size),
                               n=batch_size)
                    desc.append(f'Acc_w@{k} {acc.val:.2f} ({acc.avg:.2f})')

            tqdm_loader.set_description(' '.join(desc))

    # a confusion matrix C is such that C[i,j] is the # of observations known to
    # be in group i and predicted to be in group j.
    all_preds = all_probs.argmax(axis=1)
    cm = sklearn.metrics.confusion_matrix(
        y_true=all_labels, y_pred=all_preds, labels=np.arange(num_labels))

    df = pd.DataFrame()
    df['path'] = np.concatenate(all_img_files)
    df['label'] = list(map(label_names.__getitem__, all_labels))
    df['weight'] = all_weights
    df[label_names] = all_probs

    metrics = {}
    if loss_fn is not None:
        metrics['loss'] = losses.avg
        if weighted:
            metrics['loss_weighted'] = losses_weighted.avg
    for k, acc in accuracies_topk.items():
        metrics[f'acc_top{k}'] = acc.avg
    if weighted:
        for k, acc in accs_weighted.items():
            metrics[f'acc_weighted_top{k}'] = acc.avg
    return df, pd.Series(metrics), cm


#%% Main function

def main(params_json_path: str, ckpt_path: str, output_dir: str,
         splits: Sequence[str], target_mapping_json_path: Optional[str] = None,
         label_index_json_path: Optional[str] = None,
         **kwargs: Any) -> None:

    # input validation
    assert os.path.exists(params_json_path)
    assert os.path.exists(ckpt_path)
    assert (target_mapping_json_path is None) == (label_index_json_path is None)
    if target_mapping_json_path is not None:
        assert label_index_json_path is not None
        assert os.path.exists(target_mapping_json_path)
        assert os.path.exists(label_index_json_path)

    # Evaluating with accimage is much faster than Pillow or Pillow-SIMD, but accimage
    # is Linux-only.
    try:
        import accimage  # noqa
        torchvision.set_image_backend('accimage')
    except:
        print('Warning: could not start accimage backend (ignore this if you\'re not using Linux)')

    # create output directory
    if not os.path.exists(output_dir):
        print('Creating output directory:', output_dir)
        os.makedirs(output_dir, exist_ok=True)

    with open(params_json_path, 'r') as f:
        params = json.load(f)
    pprint(params)

    # override saved params with kwargs
    for key, new in kwargs.items():
        if new is None:
            continue
        if key in params:
            saved = params[key]
            print(f'Overriding saved {key}. Saved: {saved}. '
                  f'Override with: {new}.')
        else:
            print(f'Did not find {key} in saved params. Using value {new}.')
        params[key] = new

    model_name: str = params['model_name']
    dataset_dir: str = params['dataset_dir']

    if 'efficientnet' in model_name:
        img_size = efficientnet.EfficientNet.get_image_size(model_name)
    else:
        img_size = 224

    # For now, we don't weight crops by detection confidence during
    # evaluation. But consider changing this.
    print('Creating dataloaders')
    loaders, label_names = train_classifier.create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=params['cropped_images_dir'],
        img_size=img_size,
        multilabel=params['multilabel'],
        label_weighted=params['label_weighted'],
        weight_by_detection_conf=False,
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
        augment_train=False)
    num_labels = len(label_names)

    # create model, compile with TorchScript if given checkpoint is not compiled
    print('Loading model from checkpoint')
    try:
        model = torch.jit.load(ckpt_path, map_location='cpu')
    except RuntimeError:
        compiled_path = trace_model(model_name, ckpt_path, num_labels, img_size)
        model = torch.jit.load(compiled_path, map_location='cpu')
    model, device = train_classifier.prep_device(model)

    if len(splits) == 0:
        print('No splits given! Exiting.')
        return

    target_cols_map = None
    if target_mapping_json_path is not None:
        assert label_index_json_path is not None

        # verify that target names match the original "label names" from dataset
        with open(target_mapping_json_path, 'r') as f:
            target_names_map = json.load(f)
        target_names = set(target_names_map.keys())

        # if the dataset does not already have an 'other' category, then the
        # 'other' category must come last in label_names to avoid conflicting
        # with an existing label_id
        if target_names != set(label_names):
            assert target_names == set(label_names) | {'other'}
            label_names.append('other')

        ct_utils.write_json(os.path.join(output_dir, 'label_index.json'), dict(enumerate(label_names)), indent=None)

        with open(label_index_json_path, 'r') as f:
            idx_to_label = json.load(f)
        classifier_name_to_idx = {
            idx_to_label[str(k)]: k for k in range(len(idx_to_label))
        }

        target_cols_map = {}
        for i_target, label_name in enumerate(label_names):
            classifier_names = target_names_map[label_name]
            target_cols_map[i_target] = [
                classifier_name_to_idx[classifier_name]
                for classifier_name in classifier_names
            ]

    # define loss function (criterion)
    loss_fn: torch.nn.Module
    if params['multilabel']:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)

    split_metrics = {}
    split_label_stats = {}
    cms = {}
    for split in splits:
        print(f'Evaluating {split}...')
        df, metrics, cm = test_epoch(
            model, loaders[split], weighted=True, device=device,
            label_names=label_names, loss_fn=loss_fn,
            target_mapping=target_cols_map)

        # this file ends up being huge, so we GZIP compress it
        output_csv_path = os.path.join(output_dir, f'outputs_{split}.csv.gz')
        df.to_csv(output_csv_path, index=False, compression='gzip')

        split_metrics[split] = metrics
        cms[split] = cm
        split_label_stats[split] = calc_per_label_stats(cm, label_names)

        # double check that the accuracy metrics are computed properly
        preds = df[label_names].to_numpy().argmax(axis=1)
        preds = np.asarray(label_names)[preds]
        assert np.isclose(metrics['acc_top1'] / 100.,
                          sum(preds == df['label']) / len(df))
        assert np.isclose(metrics['acc_weighted_top1'] / 100.,
                          sum((preds == df['label']) * df['weight']) / len(df))

    metrics_df = pd.concat(split_metrics, names=['split']).unstack(level=1)
    metrics_df.to_csv(os.path.join(output_dir, 'overall_metrics.csv'))

    # save the confusion matrices to .npz
    npz_path = os.path.join(output_dir, 'confusion_matrices.npz')
    np.savez_compressed(npz_path, **cms)

    # save per-label statistics
    label_stats_df = pd.concat(
        split_label_stats, names=['split', 'label']).reset_index()
    label_stats_csv_path = os.path.join(output_dir, 'label_stats.csv')
    label_stats_df.to_csv(label_stats_csv_path, index=False)


#%% Command-line driver

def _parse_args() -> argparse.Namespace:

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Evaluate trained model.')
    parser.add_argument(
        'params_json',
        help='path to params.json')
    parser.add_argument(
        'ckpt_path',
        help='path to checkpoint file (normal or TorchScript-compiled)')
    parser.add_argument(
        '-o', '--output-dir', required=True,
        help='(required) path to output directory')
    parser.add_argument(
        '--splits', nargs='*', choices=SPLITS, default=[],
        help='which splits to evaluate model on. If no splits are given, then '
             'only compiles normal checkpoint file using TorchScript without '
             'actually running the model.')

    other_classifier = parser.add_argument_group(
        'arguments for evaluating a model (e.g., MegaClassifier) on a '
        'different set of labels than what the model was trained on')
    other_classifier.add_argument(
        '--target-mapping',
        help='path to JSON file mapping target categories to classifier labels')
    other_classifier.add_argument(
        '--label-index',
        help='path to label index JSON file for classifier')

    override_group = parser.add_argument_group(
        'optional arguments to override values in params_json')
    override_group.add_argument(
        '--model-name',
        help='which EfficientNet or Resnet model')
    override_group.add_argument(
        '--batch-size', type=int,
        help='batch size for evaluating model')
    override_group.add_argument(
        '--num-workers', type=int,
        help='number of workers for data loading')
    override_group.add_argument(
        '--dataset-dir',
        help='path to directory containing classification_ds.csv, '
             'label_index.json, and splits.json')
    return parser.parse_args()


if __name__ == '__main__':

    args = _parse_args()
    main(params_json_path=args.params_json, ckpt_path=args.ckpt_path,
         output_dir=args.output_dir, splits=args.splits,
         target_mapping_json_path=args.target_mapping,
         label_index_json_path=args.label_index,
         model_name=args.model_name, batch_size=args.batch_size,
         num_workers=args.num_workers, dataset_dir=args.dataset_dir)
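Editor's note: the module docstring above describes the four output files that evaluate_model.py writes. As a rough illustration of that layout, here is a minimal sketch (not part of the package) of how those files might be loaded; the 'logdir' path is a placeholder, and the column/key names are taken from the docstring.

# Minimal sketch for loading evaluate_model.py outputs; paths are placeholders.
import numpy as np
import pandas as pd

logdir = '/path/to/logdir'  # hypothetical output directory

# per-example outputs: 'path', 'label', 'weight', plus one confidence column per label
outputs_val = pd.read_csv(f'{logdir}/outputs_val.csv.gz')

# per-split metrics: 'loss', 'acc_top{k}', and their weighted variants
overall = pd.read_csv(f'{logdir}/overall_metrics.csv', index_col='split')

# confusion matrices keyed by split name; rows = true labels, columns = predictions
cms = np.load(f'{logdir}/confusion_matrices.npz')
cm_val = cms['val']

# per-label precision and recall
label_stats = pd.read_csv(f'{logdir}/label_stats.csv')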
megadetector/classification/identify_mislabeled_candidates.py

@@ -0,0 +1,152 @@
"""

identify_mislabeled_candidates.py

Identify images that may have been mislabeled.

A "mislabeled candidate" is defined as an image meeting both criteria:

* according to the ground-truth label, the model made an incorrect prediction

* the model's prediction confidence exceeds its confidence for the ground-truth
label by at least <margin>

This script outputs for each dataset a text file containing the filenames of
mislabeled candidates, one per line. The text files are saved to:

<logdir>/mislabeled_candidates_{split}_{dataset}.txt

This list of files can then be passed to AzCopy to be downloaded:

""
azcopy cp "http://<url_of_container>?<sas_token>" "/save/files/here" \
    --list-of-files "/path/to/mislabeled_candidates_{split}_{dataset}.txt"
""

To save the filename as <dataset_name>/<blob_name> (instead of just <blob_name>
by default), pass the --include-dataset-in-filename flag. Then, the images can
be downloaded with:

""
python data_management/megadb/download_images.py txt \
    "/path/to/mislabeled_candidates_{split}_{dataset}.txt" \
    /save/files/here \
    --threads 50
""

Assumes the following directory layout:
<base_logdir>/
    label_index.json
    <logdir>/
        outputs_{split}.csv.gz

"""

#%% Imports

from __future__ import annotations

import argparse
from collections import defaultdict
from collections.abc import Iterable, Sequence
import json
import os

import pandas as pd
from tqdm import tqdm


#%% Example usage

"""
python identify_mislabeled_candidates.py <base_logdir>/<logdir> \
    --margin 0.5 --splits val test
"""


#%% Main function

def main(logdir: str, splits: Iterable[str], margin: float,
         include_dataset_in_filename: bool) -> None:

    # load files
    logdir = os.path.normpath(logdir)  # removes any trailing slash
    base_logdir = os.path.dirname(logdir)
    idx_to_label_json_path = os.path.join(base_logdir, 'label_index.json')
    with open(idx_to_label_json_path, 'r') as f:
        idx_to_label = json.load(f)
    label_names = [idx_to_label[str(i)] for i in range(len(idx_to_label))]

    for split in splits:
        outputs_csv_path = os.path.join(logdir, f'outputs_{split}.csv.gz')
        candidates_df = get_candidates_df(outputs_csv_path, label_names, margin)

        # dataset => set of img_file
        candidate_image_files: defaultdict[str, set[str]] = defaultdict(set)

        for crop_path in tqdm(candidates_df['path']):
            # crop_path: <dataset>/<img_file>___cropXX_mdvY.Y.jpg
            #            [----<img_path>----]
            img_path = crop_path.split('___crop')[0]
            ds, img_file = img_path.split('/', maxsplit=1)
            if include_dataset_in_filename:
                candidate_image_files[ds].add(img_path)
            else:
                candidate_image_files[ds].add(img_file)

        for ds in sorted(candidate_image_files.keys()):
            img_files = candidate_image_files[ds]
            print(f'{ds} contains {len(img_files)} mislabeled candidates.')
            save_path = os.path.join(
                logdir, f'mislabeled_candidates_{split}_{ds}.txt')
            with open(save_path, 'w') as f:
                for img_file in sorted(img_files):
                    f.write(img_file + '\n')


#%% Support functions

def get_candidates_df(outputs_csv_path: str, label_names: Sequence[str],
                      margin: float) -> pd.DataFrame:
    """
    Returns a DataFrame containing crops only from mislabeled candidate
    images.
    """

    df = pd.read_csv(outputs_csv_path, float_precision='high')
    all_rows = range(len(df))
    df['pred'] = df[label_names].idxmax(axis=1)
    df['pred_conf'] = df.lookup(row_labels=all_rows, col_labels=df['pred'])
    df['label_conf'] = df.lookup(row_labels=all_rows, col_labels=df['label'])
    candidate_mask = df['pred_conf'] >= df['label_conf'] + margin
    candidates_df = df[candidate_mask].copy()
    return candidates_df


#%% Command-line driver

def _parse_args() -> argparse.Namespace:

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Identify mislabeled candidate images.')
    parser.add_argument(
        'logdir',
        help='folder inside <base_logdir> containing `outputs_<split>.csv.gz`')
    parser.add_argument(
        '--margin', type=float, default=0.5,
        help='confidence margin to count as a mislabeled candidate')
    parser.add_argument(
        '--splits', nargs='+', choices=['train', 'val', 'test'],
        help='which splits to identify mislabeled candidates on')
    parser.add_argument(
        '-d', '--include-dataset-in-filename', action='store_true',
        help='save filename as <dataset_name>/<blob_name>')
    return parser.parse_args()


if __name__ == '__main__':

    args = _parse_args()
    main(logdir=args.logdir, splits=args.splits, margin=args.margin,
         include_dataset_in_filename=args.include_dataset_in_filename)
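Editor's note: get_candidates_df above relies on pandas.DataFrame.lookup, which was deprecated in pandas 1.2 and removed in pandas 2.0, so this module as shipped expects an older pandas. Below is a minimal sketch (not part of the package) of the same margin test written with NumPy indexing instead; the function name candidate_mask is hypothetical.

# Minimal sketch of the margin criterion without DataFrame.lookup.
import numpy as np
import pandas as pd

def candidate_mask(df: pd.DataFrame, label_names: list[str], margin: float) -> pd.Series:
    """Boolean mask of rows whose top predicted confidence exceeds the
    ground-truth label's confidence by at least `margin`."""
    probs = df[label_names].to_numpy()
    # confidence assigned to the ground-truth label of each row
    label_idx = np.asarray([label_names.index(lbl) for lbl in df['label']])
    label_conf = probs[np.arange(len(df)), label_idx]
    # confidence of the predicted (argmax) label
    pred_conf = probs.max(axis=1)
    return pd.Series(pred_conf >= label_conf + margin, index=df.index)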