megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (147) hide show
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,520 @@
1
+ """
2
+
3
+ evaluate_model.py
4
+
5
+ Evaluate a species classifier.
6
+
7
+ Currently the implementation of multi-label multi-class classification is
8
+ non-functional.
9
+
10
+ Outputs the following files:
11
+
12
+ 1) outputs_{split}.csv.gz, one gzip-compressed CSV file per split, contains columns:
13
+ - 'path': str, path to cropped image
14
+ - 'label': str
15
+ - 'weight': float
16
+ - [label names]: float, confidence in each label
17
+
18
+ 2) overall_metrics.csv, contains columns:
19
+ - 'split': str
20
+ - 'loss': float, mean per-example loss over entire epoch
21
+ - 'acc_top{k}': float, accuracy@k over the entire epoch
22
+ - 'loss_weighted' and 'acc_weighted_top{k}': float, weighted versions
23
+
24
+ 3) confusion_matrices.npz
25
+ - keys ['train', 'val', 'test']
26
+ - values are np.ndarray, confusion matrices
27
+
28
+ 4) label_stats.csv, per-label statistics, columns
29
+ - 'split': str
30
+ - 'label': str
31
+ - 'precision': float
32
+ - 'recall': float
33
+
34
+ """
35
+
36
+ #%% Imports and constants
37
+
38
+ from __future__ import annotations
39
+
40
+ import argparse
41
+ from collections.abc import Mapping, Sequence
42
+ import json
43
+ import os
44
+ from pprint import pprint
45
+ from typing import Any, Optional
46
+
47
+ import numpy as np
48
+ import pandas as pd
49
+ import sklearn.metrics
50
+ import torch
51
+ import torchvision
52
+ import tqdm
53
+
54
+ from megadetector.classification import efficientnet, train_classifier
55
+ from megadetector.utils import ct_utils
56
+
57
+
58
+ #%% Example usage
59
+
60
+ """
61
+ python evaluate_model.py \
62
+ $BASE_LOGDIR/$LOGDIR/params.json \
63
+ $BASE_LOGDIR/$LOGDIR/ckpt_XX.pt \
64
+ --output-dir $BASE_LOGDIR/$LOGDIR \
65
+ --splits train val test \
66
+ --batch-size 256
67
+ """
68
+
69
+ SPLITS = ['train', 'val', 'test']
70
+
71
+
72
+ #%% Support functions
73
+
74
def check_override(params: Mapping[str, Any], key: str,
                   override: Optional[Any]) -> Any:
    """
    Pick between a saved parameter and an explicit override.

    Args:
        params: mapping of saved parameter names to values
        key: str, name of the parameter to look up
        override: replacement value, or None to keep the saved value

    Returns: override if it is not None (after printing a notice that shows
        the saved value, if any), otherwise params[key]
    """

    if override is not None:
        previous = params.get(key, None)
        print(f'Overriding saved {key}. Saved: {previous}. Override with: {override}.')
        return override

    # no override requested: the saved value must exist
    return params[key]
85
+
86
+
87
def trace_model(model_name: str, ckpt_path: str, num_classes: int,
                img_size: int) -> str:
    """
    Compile a trained model into a standalone TorchScript file via tracing.

    Tracing is used rather than scripting; see
        https://github.com/lukemelas/EfficientNet-PyTorch/issues/89
        https://github.com/lukemelas/EfficientNet-PyTorch/issues/218

    If the compiled file already exists, nothing is rebuilt and its path is
    returned immediately.

    Args:
        model_name: str
        ckpt_path: str, path to checkpoint file
        num_classes: int, number of classification classes
        img_size: int, size of input image, used for tracing

    Returns: str, name of file for compiled model. For example, if ckpt_path is
        '/path/to/ckpt_16.pt', then the returned path is
        '/path/to/ckpt_16_compiled.pt'.
    """

    stem, suffix = os.path.splitext(ckpt_path)
    compiled_path = stem + '_compiled' + suffix

    if not os.path.exists(compiled_path):
        net = train_classifier.build_model(
            model_name, num_classes=num_classes, pretrained=ckpt_path,
            finetune=False)
        if 'efficientnet' in model_name:
            net.set_swish(memory_efficient=False)
        net.eval()

        # a random image of the right shape is enough to drive the trace
        example_input = torch.rand(1, 3, img_size, img_size)
        traced = torch.jit.trace(net, (example_input,))

        traced.save(compiled_path)
        print('Saved TorchScript compiled model to', compiled_path)

    return compiled_path
124
+
125
+
126
def calc_per_label_stats(cm: np.ndarray, label_names: Sequence[str]
                         ) -> pd.DataFrame:
    """
    Derive per-label precision and recall from a confusion matrix.

    Args:
        cm: np.ndarray, type int, confusion matrix C such that C[i,j] is the #
            of observations from group i that are predicted to be in group j
        label_names: list of str, label names in order of label id

    Returns: pd.DataFrame, index 'label', columns ['precision', 'recall']
        precision values are in [0, 1]
        recall values are in [0, 1], or np.nan if that label had 0 ground-truth
            observations
    """

    true_pos = np.diag(cm)

    # column sums = tp + fp; a tiny epsilon keeps never-predicted labels from
    # dividing by zero (their precision comes out as 0)
    col_totals = cm.sum(axis=0, dtype=np.float64)
    col_totals[col_totals == 0] += 1e-8

    # row sums = tp + fn; labels with no ground-truth examples get NaN recall
    row_totals = cm.sum(axis=1, dtype=np.float64)
    row_totals[row_totals == 0] = np.nan

    stats = pd.DataFrame({
        'label': label_names,
        'precision': true_pos / col_totals,
        'recall': true_pos / row_totals,
    })
    return stats.set_index('label')
154
+
155
+
156
def test_epoch(model: torch.nn.Module,
               loader: torch.utils.data.DataLoader,
               weighted: bool,
               device: torch.device,
               label_names: Sequence[str],
               top: Sequence[int] = (1, 3),
               loss_fn: Optional[torch.nn.Module] = None,
               target_mapping: Optional[Mapping[int, Sequence[int]]] = None
               ) -> tuple[pd.DataFrame, pd.Series, np.ndarray]:
    """
    Runs for 1 epoch.

    Args:
        model: torch.nn.Module
        loader: torch.utils.data.DataLoader
        weighted: bool, whether to calculate weighted accuracy statistics
        device: torch.device
        label_names: list of str, label names in order of label id
        top: tuple of int, list of values of k for calculating top-K accuracy
        loss_fn: optional loss function, calculates per-example loss
        target_mapping: optional dict, label_id => list of ids from classifier
            that should map to the label_id

    Returns:
        df: pd.DataFrame, columns ['path', 'label', 'weight', label_names];
            the 'weight' column is only present if weighted=True
        metrics: pd.Series, type float, index includes:
            'loss': mean per-example loss over entire epoch,
                only included if loss_fn is not None
            'acc_top{k}': accuracy@k over the entire epoch
            'loss_weighted' and 'acc_weighted_top{k}': weighted versions, only
                included if weighted=True
        cm: np.ndarray, confusion matrix C such that C[i,j] is the # of
            observations known to be in group i and predicted to be in group j
    """

    # set dropout and BN layers to eval mode
    model.eval()

    label_names = list(label_names)  # a list is needed for column selection

    if loss_fn is not None:
        losses = train_classifier.AverageMeter()
    accuracies_topk = {k: train_classifier.AverageMeter() for k in top}  # acc@k
    if weighted:
        accs_weighted = {k: train_classifier.AverageMeter() for k in top}
        losses_weighted = train_classifier.AverageMeter()

    num_examples = len(loader.dataset)
    num_labels = len(label_names)

    all_img_files = []
    all_probs = np.zeros([num_examples, num_labels], dtype=np.float32)
    all_labels = np.zeros(num_examples, dtype=np.int32)
    if weighted:
        all_weights = np.zeros(num_examples, dtype=np.float32)

    # running window into the all_* arrays covering the current batch
    batch_slice = slice(0, 0)
    tqdm_loader = tqdm.tqdm(loader)
    with torch.no_grad():
        for batch in tqdm_loader:
            if weighted:
                inputs, labels, img_files, weights = batch
            else:
                # even if batch contains sample weights, don't use them
                inputs, labels, img_files = batch[0:3]
                weights = None

            all_img_files.append(img_files)

            batch_size = labels.size(0)
            batch_slice = slice(batch_slice.stop, batch_slice.stop + batch_size)
            all_labels[batch_slice] = labels
            if weighted:
                all_weights[batch_slice] = weights
                weights = weights.to(device, non_blocking=True)

            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)

            # Do target mapping on the outputs (unnormalized logits) instead of
            # the normalized (softmax) probabilities, because the loss function
            # uses unnormalized logits. Summing probabilities is equivalent to
            # log-sum-exp of unnormalized logits.
            if target_mapping is not None:
                outputs_mapped = torch.zeros(
                    [batch_size, num_labels], dtype=outputs.dtype,
                    device=outputs.device)
                for target, cols in target_mapping.items():
                    outputs_mapped[:, target] = torch.logsumexp(
                        outputs[:, cols], dim=1)
                outputs = outputs_mapped

            probs = torch.nn.functional.softmax(outputs, dim=1).cpu()
            all_probs[batch_slice] = probs

            desc = []
            if loss_fn is not None:
                loss = loss_fn(outputs, labels)
                losses.update(loss.mean().item(), n=batch_size)
                desc.append(f'Loss {losses.val:.3f} ({losses.avg:.3f})')
                if weights is not None:
                    loss_weighted = (loss * weights).mean()
                    losses_weighted.update(loss_weighted.item(), n=batch_size)

            top_correct = train_classifier.correct(
                outputs, labels, weights=None, top=top)
            for k, acc in accuracies_topk.items():
                acc.update(top_correct[k] * (100. / batch_size), n=batch_size)
                desc.append(f'Acc@{k} {acc.val:.2f} ({acc.avg:.2f})')

            if weighted:
                top_correct = train_classifier.correct(
                    outputs, labels, weights=weights, top=top)
                for k, acc in accs_weighted.items():
                    acc.update(top_correct[k] * (100. / batch_size),
                               n=batch_size)
                    desc.append(f'Acc_w@{k} {acc.val:.2f} ({acc.avg:.2f})')

            tqdm_loader.set_description(' '.join(desc))

    # a confusion matrix C is such that C[i,j] is the # of observations known to
    # be in group i and predicted to be in group j.
    all_preds = all_probs.argmax(axis=1)
    cm = sklearn.metrics.confusion_matrix(
        y_true=all_labels, y_pred=all_preds, labels=np.arange(num_labels))

    df = pd.DataFrame()
    df['path'] = np.concatenate(all_img_files)
    df['label'] = list(map(label_names.__getitem__, all_labels))
    if weighted:
        # all_weights only exists when weighted=True; writing it
        # unconditionally raised NameError for unweighted evaluation
        df['weight'] = all_weights
    df[label_names] = all_probs

    metrics = {}
    if loss_fn is not None:
        metrics['loss'] = losses.avg
        if weighted:
            metrics['loss_weighted'] = losses_weighted.avg
    for k, acc in accuracies_topk.items():
        metrics[f'acc_top{k}'] = acc.avg
    if weighted:
        for k, acc in accs_weighted.items():
            metrics[f'acc_weighted_top{k}'] = acc.avg
    return df, pd.Series(metrics), cm
298
+
299
+
300
+ #%% Main function
301
+
302
def main(params_json_path: str, ckpt_path: str, output_dir: str,
         splits: Sequence[str], target_mapping_json_path: Optional[str] = None,
         label_index_json_path: Optional[str] = None,
         **kwargs: Any) -> None:
    """
    Evaluate a trained classifier checkpoint on the requested splits.

    Loads the saved training parameters, builds dataloaders, loads the model
    (tracing it to TorchScript first if the checkpoint is not already
    compiled), then for each split writes outputs_{split}.csv.gz. After all
    splits, writes overall_metrics.csv, confusion_matrices.npz and
    label_stats.csv to output_dir.

    Args:
        params_json_path: str, path to params.json saved at training time
        ckpt_path: str, path to checkpoint file (normal or TorchScript-compiled)
        output_dir: str, directory for output files, created if missing
        splits: list of str, splits to evaluate; if empty, the function
            returns after loading (and, if needed, compiling) the model
        target_mapping_json_path: optional str, path to JSON file mapping
            target categories to classifier labels; must be given together
            with label_index_json_path
        label_index_json_path: optional str, path to the classifier's label
            index JSON file
        kwargs: non-None values override the matching keys of the saved params
    """

    # input validation
    assert os.path.exists(params_json_path), \
        f'params file not found: {params_json_path}'
    assert os.path.exists(ckpt_path), f'checkpoint not found: {ckpt_path}'
    assert (target_mapping_json_path is None) == (label_index_json_path is None), \
        'target mapping and label index must be given together'
    if target_mapping_json_path is not None:
        assert label_index_json_path is not None
        assert os.path.exists(target_mapping_json_path), \
            f'target mapping file not found: {target_mapping_json_path}'
        assert os.path.exists(label_index_json_path), \
            f'label index file not found: {label_index_json_path}'

    # Evaluating with accimage is much faster than Pillow or Pillow-SIMD, but accimage
    # is Linux-only.
    try:
        import accimage  # noqa
        torchvision.set_image_backend('accimage')
    except Exception:
        # best-effort: fall back to the default backend, but do not swallow
        # KeyboardInterrupt / SystemExit as a bare except would
        print('Warning: could not start accimage backend (ignore this if you\'re not using Linux)')

    # create output directory
    if not os.path.exists(output_dir):
        print('Creating output directory:', output_dir)
        os.makedirs(output_dir, exist_ok=True)

    with open(params_json_path, 'r') as f:
        params = json.load(f)
    pprint(params)

    # override saved params with kwargs
    for key, new in kwargs.items():
        if new is None:
            continue
        if key in params:
            saved = params[key]
            print(f'Overriding saved {key}. Saved: {saved}. '
                  f'Override with: {new}.')
        else:
            print(f'Did not find {key} in saved params. Using value {new}.')
        params[key] = new

    model_name: str = params['model_name']
    dataset_dir: str = params['dataset_dir']

    if 'efficientnet' in model_name:
        img_size = efficientnet.EfficientNet.get_image_size(model_name)
    else:
        img_size = 224

    # For now, we don't weight crops by detection confidence during
    # evaluation. But consider changing this.
    print('Creating dataloaders')
    loaders, label_names = train_classifier.create_dataloaders(
        dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
        label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
        splits_json_path=os.path.join(dataset_dir, 'splits.json'),
        cropped_images_dir=params['cropped_images_dir'],
        img_size=img_size,
        multilabel=params['multilabel'],
        label_weighted=params['label_weighted'],
        weight_by_detection_conf=False,
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
        augment_train=False)
    num_labels = len(label_names)

    # create model, compile with TorchScript if given checkpoint is not compiled
    print('Loading model from checkpoint')
    try:
        model = torch.jit.load(ckpt_path, map_location='cpu')
    except RuntimeError:
        compiled_path = trace_model(model_name, ckpt_path, num_labels, img_size)
        model = torch.jit.load(compiled_path, map_location='cpu')
    model, device = train_classifier.prep_device(model)

    if len(splits) == 0:
        print('No splits given! Exiting.')
        return

    target_cols_map = None
    if target_mapping_json_path is not None:
        assert label_index_json_path is not None

        # verify that target names matches original "label names" from dataset
        with open(target_mapping_json_path, 'r') as f:
            target_names_map = json.load(f)
        target_names = set(target_names_map.keys())

        # if the dataset does not already have a 'other' category, then the
        # 'other' category must come last in label_names to avoid conflicting
        # with an existing label_id
        if target_names != set(label_names):
            assert target_names == set(label_names) | {'other'}
            label_names.append('other')

            # NOTE(review): this write is assumed to belong to the 'other'
            # branch; the listing this was restored from carried no
            # indentation -- confirm against upstream
            ct_utils.write_json(os.path.join(output_dir, 'label_index.json'), dict(enumerate(label_names)), indent=None)

        with open(label_index_json_path, 'r') as f:
            idx_to_label = json.load(f)
        classifier_name_to_idx = {
            idx_to_label[str(k)]: k for k in range(len(idx_to_label))
        }

        # target label id => list of classifier output columns to merge
        target_cols_map = {}
        for i_target, label_name in enumerate(label_names):
            classifier_names = target_names_map[label_name]
            target_cols_map[i_target] = [
                classifier_name_to_idx[classifier_name]
                for classifier_name in classifier_names
            ]

    # define loss function (criterion)
    loss_fn: torch.nn.Module
    if params['multilabel']:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)

    split_metrics = {}
    split_label_stats = {}
    cms = {}
    for split in splits:
        print(f'Evaluating {split}...')
        df, metrics, cm = test_epoch(
            model, loaders[split], weighted=True, device=device,
            label_names=label_names, loss_fn=loss_fn,
            target_mapping=target_cols_map)

        # this file ends up being huge, so we GZIP compress it
        output_csv_path = os.path.join(output_dir, f'outputs_{split}.csv.gz')
        df.to_csv(output_csv_path, index=False, compression='gzip')

        split_metrics[split] = metrics
        cms[split] = cm
        split_label_stats[split] = calc_per_label_stats(cm, label_names)

        # double check that the accuracy metrics are computed properly
        preds = df[label_names].to_numpy().argmax(axis=1)
        preds = np.asarray(label_names)[preds]
        assert np.isclose(metrics['acc_top1'] / 100.,
                          sum(preds == df['label']) / len(df))
        assert np.isclose(metrics['acc_weighted_top1'] / 100.,
                          sum((preds == df['label']) * df['weight']) / len(df))

    metrics_df = pd.concat(split_metrics, names=['split']).unstack(level=1)
    metrics_df.to_csv(os.path.join(output_dir, 'overall_metrics.csv'))

    # save the confusion matrices to .npz
    npz_path = os.path.join(output_dir, 'confusion_matrices.npz')
    np.savez_compressed(npz_path, **cms)

    # save per-label statistics
    label_stats_df = pd.concat(
        split_label_stats, names=['split', 'label']).reset_index()
    label_stats_csv_path = os.path.join(output_dir, 'label_stats.csv')
    label_stats_df.to_csv(label_stats_csv_path, index=False)
460
+
461
+
462
+ #%% Command-line driver
463
+
464
def _parse_args() -> argparse.Namespace:
    """
    Build the command-line parser for model evaluation and parse sys.argv.

    Returns: argparse.Namespace with params_json, ckpt_path, output_dir,
        splits, target_mapping, label_index, and the optional override values
        model_name, batch_size, num_workers, dataset_dir
    """

    cli = argparse.ArgumentParser(
        description='Evaluate trained model.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # required inputs and general options
    cli.add_argument('params_json', help='path to params.json')
    cli.add_argument(
        'ckpt_path',
        help='path to checkpoint file (normal or TorchScript-compiled)')
    cli.add_argument(
        '-o', '--output-dir', required=True,
        help='(required) path to output directory')
    cli.add_argument(
        '--splits', nargs='*', choices=SPLITS, default=[],
        help='which splits to evaluate model on. If no splits are given, then '
             'only compiles normal checkpoint file using TorchScript without '
             'actually running the model.')

    # evaluating against a label set other than the training one
    remap_group = cli.add_argument_group(
        'arguments for evaluating a model (e.g., MegaClassifier) on a '
        'different set of labels than what the model was trained on')
    remap_group.add_argument(
        '--target-mapping',
        help='path to JSON file mapping target categories to classifier labels')
    remap_group.add_argument(
        '--label-index',
        help='path to label index JSON file for classifier')

    # values that replace entries of params_json when given
    overrides = cli.add_argument_group(
        'optional arguments to override values in params_json')
    overrides.add_argument(
        '--model-name',
        help='which EfficientNet or Resnet model')
    overrides.add_argument(
        '--batch-size', type=int,
        help='batch size for evaluating model')
    overrides.add_argument(
        '--num-workers', type=int,
        help='number of workers for data loading')
    overrides.add_argument(
        '--dataset-dir',
        help='path to directory containing classification_ds.csv, '
             'label_index.json, and splits.json')

    return cli.parse_args()
510
+
511
+
512
if __name__ == '__main__':

    cli_args = _parse_args()

    # values that, when not None, replace the corresponding saved params
    param_overrides = {
        'model_name': cli_args.model_name,
        'batch_size': cli_args.batch_size,
        'num_workers': cli_args.num_workers,
        'dataset_dir': cli_args.dataset_dir,
    }

    main(params_json_path=cli_args.params_json,
         ckpt_path=cli_args.ckpt_path,
         output_dir=cli_args.output_dir,
         splits=cli_args.splits,
         target_mapping_json_path=cli_args.target_mapping,
         label_index_json_path=cli_args.label_index,
         **param_overrides)
@@ -0,0 +1,152 @@
1
+ """
2
+
3
+ identify_mislabeled_candidates.py
4
+
5
+ Identify images that may have been mislabeled.
6
+
7
+ A "mislabeled candidate" is defined as an image meeting both criteria:
8
+
9
+ * according to the ground-truth label, the model made an incorrect prediction
10
+
11
+ * the model's prediction confidence exceeds its confidence for the ground-truth
12
+ label by at least <margin>
13
+
14
+ This script outputs for each dataset a text file containing the filenames of
15
+ mislabeled candidates, one per line. The text files are saved to:
16
+
17
+ <logdir>/mislabeled_candidates_{split}_{dataset}.txt
18
+
19
+ This list of files can then be passed to AzCopy to be downloaded:
20
+
21
+ ""
22
+ azcopy cp "http://<url_of_container>?<sas_token>" "/save/files/here" \
23
+ --list-of-files "/path/to/mislabeled_candidates_{split}_{dataset}.txt"
24
+ ""
25
+
26
+ To save the filename as <dataset_name>/<blob_name> (instead of just <blob_name>
27
+ by default), pass the --include-dataset-in-filename flag. Then, the images can
28
+ be downloaded with:
29
+
30
+ ""
31
+ python data_management/megadb/download_images.py txt \
32
+ "/path/to/mislabeled_candidates_{split}_{dataset}.txt" \
33
+ /save/files/here \
34
+ --threads 50
35
+ ""
36
+
37
+ Assumes the following directory layout:
38
+ <base_logdir>/
39
+ label_index.json
40
+ <logdir>/
41
+ outputs_{split}.csv.gz
42
+
43
+ """
44
+
45
+ #%% Imports
46
+
47
+ from __future__ import annotations
48
+
49
+ import argparse
50
+ from collections import defaultdict
51
+ from collections.abc import Iterable, Sequence
52
+ import json
53
+ import os
54
+
55
+ import pandas as pd
56
+ from tqdm import tqdm
57
+
58
+
59
+ #%% Example usage
60
+
61
+ """
62
+ python identify_mislabeled_candidates.py <base_logdir>/<logdir> \
63
+ --margin 0.5 --splits val test
64
+ """
65
+
66
+
67
+ #%% Main function
68
+
69
def main(logdir: str, splits: Iterable[str], margin: float,
         include_dataset_in_filename: bool) -> None:
    """
    Write one text file of mislabeled-candidate image names per dataset.

    Reads label_index.json from the parent of logdir and, for each split,
    outputs_{split}.csv.gz from logdir. Results are written to
    <logdir>/mislabeled_candidates_{split}_{dataset}.txt, one name per line.

    Args:
        logdir: str, folder containing outputs_{split}.csv.gz
        splits: iterable of str, splits to process
        margin: float, confidence margin to count as a mislabeled candidate
        include_dataset_in_filename: bool, if True write
            <dataset>/<img_file> instead of just <img_file>
    """

    # load files
    logdir = os.path.normpath(logdir)  # removes any trailing slash
    label_index_path = os.path.join(
        os.path.dirname(logdir), 'label_index.json')
    with open(label_index_path, 'r') as f:
        idx_to_label = json.load(f)
    label_names = [idx_to_label[str(i)] for i in range(len(idx_to_label))]

    for split in splits:
        candidates_df = get_candidates_df(
            os.path.join(logdir, f'outputs_{split}.csv.gz'),
            label_names, margin)

        # dataset => set of img_file
        files_by_dataset: defaultdict[str, set[str]] = defaultdict(set)

        for crop_path in tqdm(candidates_df['path']):
            # crop_path: <dataset>/<img_file>___cropXX_mdvY.Y.jpg
            #            [----<img_path>----]
            img_path = crop_path.split('___crop')[0]
            ds, img_file = img_path.split('/', maxsplit=1)
            entry = img_path if include_dataset_in_filename else img_file
            files_by_dataset[ds].add(entry)

        for ds in sorted(files_by_dataset):
            names = files_by_dataset[ds]
            print(f'{ds} contains {len(names)} mislabeled candidates.')
            out_path = os.path.join(
                logdir, f'mislabeled_candidates_{split}_{ds}.txt')
            with open(out_path, 'w') as f:
                f.writelines(name + '\n' for name in sorted(names))
105
+
106
+
107
+ #%% Support functions
108
+
109
def get_candidates_df(outputs_csv_path: str, label_names: Sequence[str],
                      margin: float) -> pd.DataFrame:
    """
    Returns a DataFrame containing crops only from mislabeled candidate
    images.

    A row is a candidate when the confidence of the top-scoring label exceeds
    the confidence of the ground-truth label by at least `margin`.

    Args:
        outputs_csv_path: str, path to a CSV with a 'label' column and one
            confidence column per entry of label_names
        label_names: list of str, names of the confidence columns
        margin: float, minimum confidence gap

    Returns: pd.DataFrame, the candidate rows of the CSV with added columns
        'pred' (name of top-scoring label), 'pred_conf' and 'label_conf'

    Raises:
        KeyError: if a value in the 'label' column is not in label_names
    """

    # function-scope import: numpy is not imported at the top of this file
    import numpy as np

    label_names = list(label_names)
    df = pd.read_csv(outputs_csv_path, float_precision='high')

    # DataFrame.lookup() was removed in pandas 2.0, so pick per-row values by
    # position from the confidence matrix instead
    confs = df[label_names].to_numpy()
    row_idx = np.arange(len(df))
    name_index = pd.Index(label_names)

    df['pred'] = df[label_names].idxmax(axis=1)

    label_cols = name_index.get_indexer(df['label'])
    missing = label_cols < 0
    if missing.any():
        unknown = df.loc[missing, 'label'].unique().tolist()
        raise KeyError(f'labels not found in label_names: {unknown}')
    pred_cols = name_index.get_indexer(df['pred'])

    df['pred_conf'] = confs[row_idx, pred_cols]
    df['label_conf'] = confs[row_idx, label_cols]

    candidate_mask = df['pred_conf'] >= df['label_conf'] + margin
    return df[candidate_mask].copy()
124
+
125
+
126
+ #%% Command-line driver
127
+
128
+ def _parse_args() -> argparse.Namespace:
129
+
130
+ parser = argparse.ArgumentParser(
131
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
132
+ description='Identify mislabeled candidate images.')
133
+ parser.add_argument(
134
+ 'logdir',
135
+ help='folder inside <base_logdir> containing `outputs_<split>.csv.gz`')
136
+ parser.add_argument(
137
+ '--margin', type=float, default=0.5,
138
+ help='confidence margin to count as a mislabeled candidate')
139
+ parser.add_argument(
140
+ '--splits', nargs='+', choices=['train', 'val', 'test'],
141
+ help='which splits to identify mislabeled candidates on')
142
+ parser.add_argument(
143
+ '-d', '--include-dataset-in-filename', action='store_true',
144
+ help='save filename as <dataset_name>/<blob_name>')
145
+ return parser.parse_args()
146
+
147
+
148
if __name__ == '__main__':

    cli_args = _parse_args()
    main(
        logdir=cli_args.logdir,
        splits=cli_args.splits,
        margin=cli_args.margin,
        include_dataset_in_filename=cli_args.include_dataset_in_filename,
    )