megadetector 5.0.9__py3-none-any.whl → 5.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic.

Files changed (226)
  1. {megadetector-5.0.9.dist-info → megadetector-5.0.11.dist-info}/LICENSE +0 -0
  2. {megadetector-5.0.9.dist-info → megadetector-5.0.11.dist-info}/METADATA +12 -11
  3. megadetector-5.0.11.dist-info/RECORD +5 -0
  4. megadetector-5.0.11.dist-info/top_level.txt +1 -0
  5. api/__init__.py +0 -0
  6. api/batch_processing/__init__.py +0 -0
  7. api/batch_processing/api_core/__init__.py +0 -0
  8. api/batch_processing/api_core/batch_service/__init__.py +0 -0
  9. api/batch_processing/api_core/batch_service/score.py +0 -439
  10. api/batch_processing/api_core/server.py +0 -294
  11. api/batch_processing/api_core/server_api_config.py +0 -98
  12. api/batch_processing/api_core/server_app_config.py +0 -55
  13. api/batch_processing/api_core/server_batch_job_manager.py +0 -220
  14. api/batch_processing/api_core/server_job_status_table.py +0 -152
  15. api/batch_processing/api_core/server_orchestration.py +0 -360
  16. api/batch_processing/api_core/server_utils.py +0 -92
  17. api/batch_processing/api_core_support/__init__.py +0 -0
  18. api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
  19. api/batch_processing/api_support/__init__.py +0 -0
  20. api/batch_processing/api_support/summarize_daily_activity.py +0 -152
  21. api/batch_processing/data_preparation/__init__.py +0 -0
  22. api/batch_processing/data_preparation/manage_local_batch.py +0 -2391
  23. api/batch_processing/data_preparation/manage_video_batch.py +0 -327
  24. api/batch_processing/integration/digiKam/setup.py +0 -6
  25. api/batch_processing/integration/digiKam/xmp_integration.py +0 -465
  26. api/batch_processing/integration/eMammal/test_scripts/config_template.py +0 -5
  27. api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +0 -126
  28. api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +0 -55
  29. api/batch_processing/postprocessing/__init__.py +0 -0
  30. api/batch_processing/postprocessing/add_max_conf.py +0 -64
  31. api/batch_processing/postprocessing/categorize_detections_by_size.py +0 -163
  32. api/batch_processing/postprocessing/combine_api_outputs.py +0 -249
  33. api/batch_processing/postprocessing/compare_batch_results.py +0 -958
  34. api/batch_processing/postprocessing/convert_output_format.py +0 -397
  35. api/batch_processing/postprocessing/load_api_results.py +0 -195
  36. api/batch_processing/postprocessing/md_to_coco.py +0 -310
  37. api/batch_processing/postprocessing/md_to_labelme.py +0 -330
  38. api/batch_processing/postprocessing/merge_detections.py +0 -401
  39. api/batch_processing/postprocessing/postprocess_batch_results.py +0 -1904
  40. api/batch_processing/postprocessing/remap_detection_categories.py +0 -170
  41. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +0 -661
  42. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +0 -211
  43. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +0 -82
  44. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +0 -1631
  45. api/batch_processing/postprocessing/separate_detections_into_folders.py +0 -731
  46. api/batch_processing/postprocessing/subset_json_detector_output.py +0 -696
  47. api/batch_processing/postprocessing/top_folders_to_bottom.py +0 -223
  48. api/synchronous/__init__.py +0 -0
  49. api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  50. api/synchronous/api_core/animal_detection_api/api_backend.py +0 -152
  51. api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -266
  52. api/synchronous/api_core/animal_detection_api/config.py +0 -35
  53. api/synchronous/api_core/animal_detection_api/data_management/annotations/annotation_constants.py +0 -47
  54. api/synchronous/api_core/animal_detection_api/detection/detector_training/copy_checkpoints.py +0 -43
  55. api/synchronous/api_core/animal_detection_api/detection/detector_training/model_main_tf2.py +0 -114
  56. api/synchronous/api_core/animal_detection_api/detection/process_video.py +0 -543
  57. api/synchronous/api_core/animal_detection_api/detection/pytorch_detector.py +0 -304
  58. api/synchronous/api_core/animal_detection_api/detection/run_detector.py +0 -627
  59. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +0 -1029
  60. api/synchronous/api_core/animal_detection_api/detection/run_inference_with_yolov5_val.py +0 -581
  61. api/synchronous/api_core/animal_detection_api/detection/run_tiled_inference.py +0 -754
  62. api/synchronous/api_core/animal_detection_api/detection/tf_detector.py +0 -165
  63. api/synchronous/api_core/animal_detection_api/detection/video_utils.py +0 -495
  64. api/synchronous/api_core/animal_detection_api/md_utils/azure_utils.py +0 -174
  65. api/synchronous/api_core/animal_detection_api/md_utils/ct_utils.py +0 -262
  66. api/synchronous/api_core/animal_detection_api/md_utils/directory_listing.py +0 -251
  67. api/synchronous/api_core/animal_detection_api/md_utils/matlab_porting_tools.py +0 -97
  68. api/synchronous/api_core/animal_detection_api/md_utils/path_utils.py +0 -416
  69. api/synchronous/api_core/animal_detection_api/md_utils/process_utils.py +0 -110
  70. api/synchronous/api_core/animal_detection_api/md_utils/sas_blob_utils.py +0 -509
  71. api/synchronous/api_core/animal_detection_api/md_utils/string_utils.py +0 -59
  72. api/synchronous/api_core/animal_detection_api/md_utils/url_utils.py +0 -144
  73. api/synchronous/api_core/animal_detection_api/md_utils/write_html_image_list.py +0 -226
  74. api/synchronous/api_core/animal_detection_api/md_visualization/visualization_utils.py +0 -841
  75. api/synchronous/api_core/tests/__init__.py +0 -0
  76. api/synchronous/api_core/tests/load_test.py +0 -110
  77. classification/__init__.py +0 -0
  78. classification/aggregate_classifier_probs.py +0 -108
  79. classification/analyze_failed_images.py +0 -227
  80. classification/cache_batchapi_outputs.py +0 -198
  81. classification/create_classification_dataset.py +0 -627
  82. classification/crop_detections.py +0 -516
  83. classification/csv_to_json.py +0 -226
  84. classification/detect_and_crop.py +0 -855
  85. classification/efficientnet/__init__.py +0 -9
  86. classification/efficientnet/model.py +0 -415
  87. classification/efficientnet/utils.py +0 -610
  88. classification/evaluate_model.py +0 -520
  89. classification/identify_mislabeled_candidates.py +0 -152
  90. classification/json_to_azcopy_list.py +0 -63
  91. classification/json_validator.py +0 -695
  92. classification/map_classification_categories.py +0 -276
  93. classification/merge_classification_detection_output.py +0 -506
  94. classification/prepare_classification_script.py +0 -194
  95. classification/prepare_classification_script_mc.py +0 -228
  96. classification/run_classifier.py +0 -286
  97. classification/save_mislabeled.py +0 -110
  98. classification/train_classifier.py +0 -825
  99. classification/train_classifier_tf.py +0 -724
  100. classification/train_utils.py +0 -322
  101. data_management/__init__.py +0 -0
  102. data_management/annotations/__init__.py +0 -0
  103. data_management/annotations/annotation_constants.py +0 -34
  104. data_management/camtrap_dp_to_coco.py +0 -238
  105. data_management/cct_json_utils.py +0 -395
  106. data_management/cct_to_md.py +0 -176
  107. data_management/cct_to_wi.py +0 -289
  108. data_management/coco_to_labelme.py +0 -272
  109. data_management/coco_to_yolo.py +0 -662
  110. data_management/databases/__init__.py +0 -0
  111. data_management/databases/add_width_and_height_to_db.py +0 -33
  112. data_management/databases/combine_coco_camera_traps_files.py +0 -206
  113. data_management/databases/integrity_check_json_db.py +0 -477
  114. data_management/databases/subset_json_db.py +0 -115
  115. data_management/generate_crops_from_cct.py +0 -149
  116. data_management/get_image_sizes.py +0 -188
  117. data_management/importers/add_nacti_sizes.py +0 -52
  118. data_management/importers/add_timestamps_to_icct.py +0 -79
  119. data_management/importers/animl_results_to_md_results.py +0 -158
  120. data_management/importers/auckland_doc_test_to_json.py +0 -372
  121. data_management/importers/auckland_doc_to_json.py +0 -200
  122. data_management/importers/awc_to_json.py +0 -189
  123. data_management/importers/bellevue_to_json.py +0 -273
  124. data_management/importers/cacophony-thermal-importer.py +0 -796
  125. data_management/importers/carrizo_shrubfree_2018.py +0 -268
  126. data_management/importers/carrizo_trail_cam_2017.py +0 -287
  127. data_management/importers/cct_field_adjustments.py +0 -57
  128. data_management/importers/channel_islands_to_cct.py +0 -913
  129. data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
  130. data_management/importers/eMammal/eMammal_helpers.py +0 -249
  131. data_management/importers/eMammal/make_eMammal_json.py +0 -223
  132. data_management/importers/ena24_to_json.py +0 -275
  133. data_management/importers/filenames_to_json.py +0 -385
  134. data_management/importers/helena_to_cct.py +0 -282
  135. data_management/importers/idaho-camera-traps.py +0 -1407
  136. data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
  137. data_management/importers/jb_csv_to_json.py +0 -150
  138. data_management/importers/mcgill_to_json.py +0 -250
  139. data_management/importers/missouri_to_json.py +0 -489
  140. data_management/importers/nacti_fieldname_adjustments.py +0 -79
  141. data_management/importers/noaa_seals_2019.py +0 -181
  142. data_management/importers/pc_to_json.py +0 -365
  143. data_management/importers/plot_wni_giraffes.py +0 -123
  144. data_management/importers/prepare-noaa-fish-data-for-lila.py +0 -359
  145. data_management/importers/prepare_zsl_imerit.py +0 -131
  146. data_management/importers/rspb_to_json.py +0 -356
  147. data_management/importers/save_the_elephants_survey_A.py +0 -320
  148. data_management/importers/save_the_elephants_survey_B.py +0 -332
  149. data_management/importers/snapshot_safari_importer.py +0 -758
  150. data_management/importers/snapshot_safari_importer_reprise.py +0 -665
  151. data_management/importers/snapshot_serengeti_lila.py +0 -1067
  152. data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
  153. data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
  154. data_management/importers/sulross_get_exif.py +0 -65
  155. data_management/importers/timelapse_csv_set_to_json.py +0 -490
  156. data_management/importers/ubc_to_json.py +0 -399
  157. data_management/importers/umn_to_json.py +0 -507
  158. data_management/importers/wellington_to_json.py +0 -263
  159. data_management/importers/wi_to_json.py +0 -441
  160. data_management/importers/zamba_results_to_md_results.py +0 -181
  161. data_management/labelme_to_coco.py +0 -548
  162. data_management/labelme_to_yolo.py +0 -272
  163. data_management/lila/__init__.py +0 -0
  164. data_management/lila/add_locations_to_island_camera_traps.py +0 -97
  165. data_management/lila/add_locations_to_nacti.py +0 -147
  166. data_management/lila/create_lila_blank_set.py +0 -557
  167. data_management/lila/create_lila_test_set.py +0 -151
  168. data_management/lila/create_links_to_md_results_files.py +0 -106
  169. data_management/lila/download_lila_subset.py +0 -177
  170. data_management/lila/generate_lila_per_image_labels.py +0 -515
  171. data_management/lila/get_lila_annotation_counts.py +0 -170
  172. data_management/lila/get_lila_image_counts.py +0 -111
  173. data_management/lila/lila_common.py +0 -300
  174. data_management/lila/test_lila_metadata_urls.py +0 -132
  175. data_management/ocr_tools.py +0 -874
  176. data_management/read_exif.py +0 -681
  177. data_management/remap_coco_categories.py +0 -84
  178. data_management/remove_exif.py +0 -66
  179. data_management/resize_coco_dataset.py +0 -189
  180. data_management/wi_download_csv_to_coco.py +0 -246
  181. data_management/yolo_output_to_md_output.py +0 -441
  182. data_management/yolo_to_coco.py +0 -676
  183. detection/__init__.py +0 -0
  184. detection/detector_training/__init__.py +0 -0
  185. detection/detector_training/model_main_tf2.py +0 -114
  186. detection/process_video.py +0 -703
  187. detection/pytorch_detector.py +0 -337
  188. detection/run_detector.py +0 -779
  189. detection/run_detector_batch.py +0 -1219
  190. detection/run_inference_with_yolov5_val.py +0 -917
  191. detection/run_tiled_inference.py +0 -935
  192. detection/tf_detector.py +0 -188
  193. detection/video_utils.py +0 -606
  194. docs/source/conf.py +0 -43
  195. md_utils/__init__.py +0 -0
  196. md_utils/azure_utils.py +0 -174
  197. md_utils/ct_utils.py +0 -612
  198. md_utils/directory_listing.py +0 -246
  199. md_utils/md_tests.py +0 -968
  200. md_utils/path_utils.py +0 -1044
  201. md_utils/process_utils.py +0 -157
  202. md_utils/sas_blob_utils.py +0 -509
  203. md_utils/split_locations_into_train_val.py +0 -228
  204. md_utils/string_utils.py +0 -92
  205. md_utils/url_utils.py +0 -323
  206. md_utils/write_html_image_list.py +0 -225
  207. md_visualization/__init__.py +0 -0
  208. md_visualization/plot_utils.py +0 -293
  209. md_visualization/render_images_with_thumbnails.py +0 -275
  210. md_visualization/visualization_utils.py +0 -1537
  211. md_visualization/visualize_db.py +0 -551
  212. md_visualization/visualize_detector_output.py +0 -406
  213. megadetector-5.0.9.dist-info/RECORD +0 -224
  214. megadetector-5.0.9.dist-info/top_level.txt +0 -8
  215. taxonomy_mapping/__init__.py +0 -0
  216. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +0 -491
  217. taxonomy_mapping/map_new_lila_datasets.py +0 -154
  218. taxonomy_mapping/prepare_lila_taxonomy_release.py +0 -142
  219. taxonomy_mapping/preview_lila_taxonomy.py +0 -591
  220. taxonomy_mapping/retrieve_sample_image.py +0 -71
  221. taxonomy_mapping/simple_image_download.py +0 -218
  222. taxonomy_mapping/species_lookup.py +0 -834
  223. taxonomy_mapping/taxonomy_csv_checker.py +0 -159
  224. taxonomy_mapping/taxonomy_graph.py +0 -346
  225. taxonomy_mapping/validate_lila_category_mappings.py +0 -83
  226. {megadetector-5.0.9.dist-info → megadetector-5.0.11.dist-info}/WHEEL +0 -0
classification/evaluate_model.py +0 -520
@@ -1,520 +0,0 @@
- """
-
- evaluate_model.py
-
- Evaluate a species classifier.
-
- Currently the implementation of multi-label multi-class classification is
- non-functional.
-
- Outputs the following files:
-
- 1) outputs_{split}.csv, one file per split, contains columns:
-     - 'path': str, path to cropped image
-     - 'label': str
-     - 'weight': float
-     - [label names]: float, confidence in each label
-
- 2) overall_metrics.csv, contains columns:
-     - 'split': str
-     - 'loss': float, mean per-example loss over entire epoch
-     - 'acc_top{k}': float, accuracy@k over the entire epoch
-     - 'loss_weighted' and 'acc_weighted_top{k}': float, weighted versions
-
- 3) confusion_matrices.npz
-     - keys ['train', 'val', 'test']
-     - values are np.ndarray, confusion matrices
-
- 4) label_stats.csv, per-label statistics, columns
-     - 'split': str
-     - 'label': str
-     - 'precision': float
-     - 'recall': float
-
- """
-
- #%% Imports and constants
-
- from __future__ import annotations
-
- import argparse
- from collections.abc import Mapping, Sequence
- import json
- import os
- from pprint import pprint
- from typing import Any, Optional
-
- import numpy as np
- import pandas as pd
- import sklearn.metrics
- import torch
- import torchvision
- import tqdm
-
- from classification import efficientnet, train_classifier
-
-
- #%% Example usage
-
- """
- python evaluate_model.py \
-     $BASE_LOGDIR/$LOGDIR/params.json \
-     $BASE_LOGDIR/$LOGDIR/ckpt_XX.pt \
-     --output-dir $BASE_LOGDIR/$LOGDIR \
-     --splits train val test \
-     --batch-size 256
- """
-
- SPLITS = ['train', 'val', 'test']
-
-
- #%% Support functions
-
- def check_override(params: Mapping[str, Any], key: str,
-                    override: Optional[Any]) -> Any:
-     """
-     Return desired value, with optional override.
-     """
-
-     if override is None:
-         return params[key]
-     saved = params.get(key, None)
-     print(f'Overriding saved {key}. Saved: {saved}. Override with: {override}.')
-     return override
-
-
- def trace_model(model_name: str, ckpt_path: str, num_classes: int,
-                 img_size: int) -> str:
-     """
-     Use TorchScript tracing to compile trained model into standalone file.
-
-     For now, we have to use tracing instead of scripting. See
-     https://github.com/lukemelas/EfficientNet-PyTorch/issues/89
-     https://github.com/lukemelas/EfficientNet-PyTorch/issues/218
-
-     Args:
-         model_name: str
-         ckpt_path: str, path to checkpoint file
-         num_labels: int, number of classification classes
-         img_size: int, size of input image, used for tracing
-
-     Returns: str, name of file for compiled model. For example, if ckpt_path is
-         '/path/to/ckpt_16.pt', then the returned path is
-         '/path/to/ckpt_16_compiled.pt'.
-     """
-
-     root, ext = os.path.splitext(ckpt_path)
-     compiled_path = root + '_compiled' + ext
-     if os.path.exists(compiled_path):
-         return compiled_path
-
-     model = train_classifier.build_model(model_name, num_classes=num_classes,
-                                          pretrained=ckpt_path, finetune=False)
-     if 'efficientnet' in model_name:
-         model.set_swish(memory_efficient=False)
-     model.eval()
-
-     ex_img = torch.rand(1, 3, img_size, img_size)
-     scripted_model = torch.jit.trace(model, (ex_img,))
-
-     scripted_model.save(compiled_path)
-     print('Saved TorchScript compiled model to', compiled_path)
-     return compiled_path
-
-
- def calc_per_label_stats(cm: np.ndarray, label_names: Sequence[str]
-                          ) -> pd.DataFrame:
-     """
-     Args:
-         cm: np.ndarray, type int, confusion matrix C such that C[i,j] is the #
-             of observations from group i that are predicted to be in group j
-         label_names: list of str, label names in order of label id
-
-     Returns: pd.DataFrame, index 'label', columns ['precision', 'recall']
-         precision values are in [0, 1]
-         recall values are in [0, 1], or np.nan if that label had 0 ground-truth
-             observations
-     """
-
-     tp = np.diag(cm)  # true positives
-
-     predicted_positives = cm.sum(axis=0, dtype=np.float64)  # tp + fp
-     predicted_positives[predicted_positives == 0] += 1e-8
-
-     all_positives = cm.sum(axis=1, dtype=np.float64)  # tp + fn
-     all_positives[all_positives == 0] = np.nan
-
-     df = pd.DataFrame()
-     df['label'] = label_names
-     df['precision'] = tp / predicted_positives
-     df['recall'] = tp / all_positives
-     df.set_index('label', inplace=True)
-     return df
-
-
- def test_epoch(model: torch.nn.Module,
-                loader: torch.utils.data.DataLoader,
-                weighted: bool,
-                device: torch.device,
-                label_names: Sequence[str],
-                top: Sequence[int] = (1, 3),
-                loss_fn: Optional[torch.nn.Module] = None,
-                target_mapping: Mapping[int, Sequence[int]] = None
-                ) -> tuple[pd.DataFrame, pd.Series, np.ndarray]:
-     """
-     Runs for 1 epoch.
-
-     Args:
-         model: torch.nn.Module
-         loader: torch.utils.data.DataLoader
-         weighted: bool, whether to calculate weighted accuracy statistics
-         device: torch.device
-         label_names: list of str, label names in order of label id
-         top: tuple of int, list of values of k for calculating top-K accuracy
-         loss_fn: optional loss function, calculates per-example loss
-         target_mapping: optional dict, label_id => list of ids from classifier
-             that should map to the label_id
-
-     Returns:
-         df: pd.DataFrame, columns ['img_file', 'label', 'weight', label_names]
-         metrics: pd.Series, type float, index includes:
-             'loss': mean per-example loss over entire epoch,
-                 only included if loss_fn is not None
-             'acc_top{k}': accuracy@k over the entire epoch
-             'loss_weighted' and 'acc_weighted_top{k}': weighted versions, only
-                 included if weighted=True
-         cm: np.ndarray, confusion matrix C such that C[i,j] is the # of
-             observations known to be in group i and predicted to be in group j
-     """
-
-     # set dropout and BN layers to eval mode
-     model.eval()
-
-     if loss_fn is not None:
-         losses = train_classifier.AverageMeter()
-     accuracies_topk = {k: train_classifier.AverageMeter() for k in top}  # acc@k
-     if weighted:
-         accs_weighted = {k: train_classifier.AverageMeter() for k in top}
-         losses_weighted = train_classifier.AverageMeter()
-
-     num_examples = len(loader.dataset)
-     num_labels = len(label_names)
-
-     all_img_files = []
-     all_probs = np.zeros([num_examples, len(label_names)], dtype=np.float32)
-     all_labels = np.zeros(num_examples, dtype=np.int32)
-     if weighted:
-         all_weights = np.zeros(num_examples, dtype=np.float32)
-
-     batch_slice = slice(0, 0)
-     tqdm_loader = tqdm.tqdm(loader)
-     with torch.no_grad():
-         for batch in tqdm_loader:
-             if weighted:
-                 inputs, labels, img_files, weights = batch
-             else:
-                 # even if batch contains sample weights, don't use them
-                 inputs, labels, img_files = batch[0:3]
-                 weights = None
-
-             all_img_files.append(img_files)
-
-             batch_size = labels.size(0)
-             batch_slice = slice(batch_slice.stop, batch_slice.stop + batch_size)
-             all_labels[batch_slice] = labels
-             if weighted:
-                 all_weights[batch_slice] = weights
-                 weights = weights.to(device, non_blocking=True)
-
-             inputs = inputs.to(device, non_blocking=True)
-             labels = labels.to(device, non_blocking=True)
-             outputs = model(inputs)
-
-             # Do target mapping on the outputs (unnormalized logits) instead of
-             # the normalized (softmax) probabilities, because the loss function
-             # uses unnormalized logits. Summing probabilities is equivalent to
-             # log-sum-exp of unnormalized logits.
-             if target_mapping is not None:
-                 outputs_mapped = torch.zeros(
-                     [batch_size, num_labels], dtype=outputs.dtype,
-                     device=outputs.device)
-                 for target, cols in target_mapping.items():
-                     outputs_mapped[:, target] = torch.logsumexp(
-                         outputs[:, cols], dim=1)
-                 outputs = outputs_mapped
-
-             probs = torch.nn.functional.softmax(outputs, dim=1).cpu()
-             all_probs[batch_slice] = probs
-
-             desc = []
-             if loss_fn is not None:
-                 loss = loss_fn(outputs, labels)
-                 losses.update(loss.mean().item(), n=batch_size)
-                 desc.append(f'Loss {losses.val:.3f} ({losses.avg:.3f})')
-                 if weights is not None:
-                     loss_weighted = (loss * weights).mean()
-                     losses_weighted.update(loss_weighted.item(), n=batch_size)
-
-             top_correct = train_classifier.correct(
-                 outputs, labels, weights=None, top=top)
-             for k, acc in accuracies_topk.items():
-                 acc.update(top_correct[k] * (100. / batch_size), n=batch_size)
-                 desc.append(f'Acc@{k} {acc.val:.2f} ({acc.avg:.2f})')
-
-             if weighted:
-                 top_correct = train_classifier.correct(
-                     outputs, labels, weights=weights, top=top)
-                 for k, acc in accs_weighted.items():
-                     acc.update(top_correct[k] * (100. / batch_size),
-                                n=batch_size)
-                     desc.append(f'Acc_w@{k} {acc.val:.2f} ({acc.avg:.2f})')
-
-             tqdm_loader.set_description(' '.join(desc))
-
-     # a confusion matrix C is such that C[i,j] is the # of observations known to
-     # be in group i and predicted to be in group j.
-     all_preds = all_probs.argmax(axis=1)
-     cm = sklearn.metrics.confusion_matrix(
-         y_true=all_labels, y_pred=all_preds, labels=np.arange(num_labels))
-
-     df = pd.DataFrame()
-     df['path'] = np.concatenate(all_img_files)
-     df['label'] = list(map(label_names.__getitem__, all_labels))
-     df['weight'] = all_weights
-     df[label_names] = all_probs
-
-     metrics = {}
-     if loss_fn is not None:
-         metrics['loss'] = losses.avg
-         if weighted:
-             metrics['loss_weighted'] = losses_weighted.avg
-     for k, acc in accuracies_topk.items():
-         metrics[f'acc_top{k}'] = acc.avg
-     if weighted:
-         for k, acc in accs_weighted.items():
-             metrics[f'acc_weighted_top{k}'] = acc.avg
-     return df, pd.Series(metrics), cm
-
-
- #%% Main function
-
- def main(params_json_path: str, ckpt_path: str, output_dir: str,
-          splits: Sequence[str], target_mapping_json_path: Optional[str] = None,
-          label_index_json_path: Optional[str] = None,
-          **kwargs: Any) -> None:
-
-     # input validation
-     assert os.path.exists(params_json_path)
-     assert os.path.exists(ckpt_path)
-     assert (target_mapping_json_path is None) == (label_index_json_path is None)
-     if target_mapping_json_path is not None:
-         assert label_index_json_path is not None
-         assert os.path.exists(target_mapping_json_path)
-         assert os.path.exists(label_index_json_path)
-
-     # Evaluating with accimage is much faster than Pillow or Pillow-SIMD, but accimage
-     # is Linux-only.
-     try:
-         import accimage  # noqa
-         torchvision.set_image_backend('accimage')
-     except:
-         print('Warning: could not start accimage backend (ignore this if you\'re not using Linux)')
-
-     # create output directory
-     if not os.path.exists(output_dir):
-         print('Creating output directory:', output_dir)
-         os.makedirs(output_dir, exist_ok=True)
-
-     with open(params_json_path, 'r') as f:
-         params = json.load(f)
-     pprint(params)
-
-     # override saved params with kwargs
-     for key, new in kwargs.items():
-         if new is None:
-             continue
-         if key in params:
-             saved = params[key]
-             print(f'Overriding saved {key}. Saved: {saved}. '
-                   f'Override with: {new}.')
-         else:
-             print(f'Did not find {key} in saved params. Using value {new}.')
-         params[key] = new
-
-     model_name: str = params['model_name']
-     dataset_dir: str = params['dataset_dir']
-
-     if 'efficientnet' in model_name:
-         img_size = efficientnet.EfficientNet.get_image_size(model_name)
-     else:
-         img_size = 224
-
-     # For now, we don't weight crops by detection confidence during
-     # evaluation. But consider changing this.
-     print('Creating dataloaders')
-     loaders, label_names = train_classifier.create_dataloaders(
-         dataset_csv_path=os.path.join(dataset_dir, 'classification_ds.csv'),
-         label_index_json_path=os.path.join(dataset_dir, 'label_index.json'),
-         splits_json_path=os.path.join(dataset_dir, 'splits.json'),
-         cropped_images_dir=params['cropped_images_dir'],
-         img_size=img_size,
-         multilabel=params['multilabel'],
-         label_weighted=params['label_weighted'],
-         weight_by_detection_conf=False,
-         batch_size=params['batch_size'],
-         num_workers=params['num_workers'],
-         augment_train=False)
-     num_labels = len(label_names)
-
-     # create model, compile with TorchScript if given checkpoint is not compiled
-     print('Loading model from checkpoint')
-     try:
-         model = torch.jit.load(ckpt_path, map_location='cpu')
-     except RuntimeError:
-         compiled_path = trace_model(model_name, ckpt_path, num_labels, img_size)
-         model = torch.jit.load(compiled_path, map_location='cpu')
-     model, device = train_classifier.prep_device(model)
-
-     if len(splits) == 0:
-         print('No splits given! Exiting.')
-         return
-
-     target_cols_map = None
-     if target_mapping_json_path is not None:
-         assert label_index_json_path is not None
-
-         # verify that target names matches original "label names" from dataset
-         with open(target_mapping_json_path, 'r') as f:
-             target_names_map = json.load(f)
-         target_names = set(target_names_map.keys())
-
-         # if the dataset does not already have a 'other' category, then the
-         # 'other' category must come last in label_names to avoid conflicting
-         # with an existing label_id
-         if target_names != set(label_names):
-             assert target_names == set(label_names) | {'other'}
-             label_names.append('other')
-
-         with open(os.path.join(output_dir, 'label_index.json'), 'w') as f:
-             json.dump(dict(enumerate(label_names)), f)
-
-         with open(label_index_json_path, 'r') as f:
-             idx_to_label = json.load(f)
-         classifier_name_to_idx = {
-             idx_to_label[str(k)]: k for k in range(len(idx_to_label))
-         }
-
-         target_cols_map = {}
-         for i_target, label_name in enumerate(label_names):
-             classifier_names = target_names_map[label_name]
-             target_cols_map[i_target] = [
-                 classifier_name_to_idx[classifier_name]
-                 for classifier_name in classifier_names
-             ]
-
-     # define loss function (criterion)
-     loss_fn: torch.nn.Module
-     if params['multilabel']:
-         loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none').to(device)
-     else:
-         loss_fn = torch.nn.CrossEntropyLoss(reduction='none').to(device)
-
-     split_metrics = {}
-     split_label_stats = {}
-     cms = {}
-     for split in splits:
-         print(f'Evaluating {split}...')
-         df, metrics, cm = test_epoch(
-             model, loaders[split], weighted=True, device=device,
-             label_names=label_names, loss_fn=loss_fn,
-             target_mapping=target_cols_map)
-
-         # this file ends up being huge, so we GZIP compress it
-         output_csv_path = os.path.join(output_dir, f'outputs_{split}.csv.gz')
-         df.to_csv(output_csv_path, index=False, compression='gzip')
-
-         split_metrics[split] = metrics
-         cms[split] = cm
-         split_label_stats[split] = calc_per_label_stats(cm, label_names)
-
-         # double check that the accuracy metrics are computed properly
-         preds = df[label_names].to_numpy().argmax(axis=1)
-         preds = np.asarray(label_names)[preds]
-         assert np.isclose(metrics['acc_top1'] / 100.,
-                           sum(preds == df['label']) / len(df))
-         assert np.isclose(metrics['acc_weighted_top1'] / 100.,
-                           sum((preds == df['label']) * df['weight']) / len(df))
-
-     metrics_df = pd.concat(split_metrics, names=['split']).unstack(level=1)
-     metrics_df.to_csv(os.path.join(output_dir, 'overall_metrics.csv'))
-
-     # save the confusion matrices to .npz
-     npz_path = os.path.join(output_dir, 'confusion_matrices.npz')
-     np.savez_compressed(npz_path, **cms)
-
-     # save per-label statistics
-     label_stats_df = pd.concat(
-         split_label_stats, names=['split', 'label']).reset_index()
-     label_stats_csv_path = os.path.join(output_dir, 'label_stats.csv')
-     label_stats_df.to_csv(label_stats_csv_path, index=False)
-
-
- #%% Command-line driver
-
- def _parse_args() -> argparse.Namespace:
-
-     parser = argparse.ArgumentParser(
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-         description='Evaluate trained model.')
-     parser.add_argument(
-         'params_json',
-         help='path to params.json')
-     parser.add_argument(
-         'ckpt_path',
-         help='path to checkpoint file (normal or TorchScript-compiled)')
-     parser.add_argument(
-         '-o', '--output-dir', required=True,
-         help='(required) path to output directory')
-     parser.add_argument(
-         '--splits', nargs='*', choices=SPLITS, default=[],
-         help='which splits to evaluate model on. If no splits are given, then '
-              'only compiles normal checkpoint file using TorchScript without '
-              'actually running the model.')
-
-     other_classifier = parser.add_argument_group(
-         'arguments for evaluating a model (e.g., MegaClassifier) on a '
-         'different set of labels than what the model was trained on')
-     other_classifier.add_argument(
-         '--target-mapping',
-         help='path to JSON file mapping target categories to classifier labels')
-     other_classifier.add_argument(
-         '--label-index',
-         help='path to label index JSON file for classifier')
-
-     override_group = parser.add_argument_group(
-         'optional arguments to override values in params_json')
-     override_group.add_argument(
-         '--model-name',
-         help='which EfficientNet or Resnet model')
-     override_group.add_argument(
-         '--batch-size', type=int,
-         help='batch size for evaluating model')
-     override_group.add_argument(
-         '--num-workers', type=int,
-         help='number of workers for data loading')
-     override_group.add_argument(
-         '--dataset-dir',
-         help='path to directory containing classification_ds.csv, '
-              'label_index.json, and splits.json')
-     return parser.parse_args()
-
-
- if __name__ == '__main__':
-
-     args = _parse_args()
-     main(params_json_path=args.params_json, ckpt_path=args.ckpt_path,
-          output_dir=args.output_dir, splits=args.splits,
-          target_mapping_json_path=args.target_mapping,
-          label_index_json_path=args.label_index,
-          model_name=args.model_name, batch_size=args.batch_size,
-          num_workers=args.num_workers, dataset_dir=args.dataset_dir)
classification/identify_mislabeled_candidates.py +0 -152
@@ -1,152 +0,0 @@
- """
-
- identify_mislabeled_candidates.py
-
- Identify images that may have been mislabeled.
-
- A "mislabeled candidate" is defined as an image meeting both criteria:
-
- * according to the ground-truth label, the model made an incorrect prediction
-
- * the model's prediction confidence exceeds its confidence for the ground-truth
-   label by at least <margin>
-
- This script outputs for each dataset a text file containing the filenames of
- mislabeled candidates, one per line. The text files are saved to:
-
-     <logdir>/mislabeled_candidates_{split}_{dataset}.txt
-
- To this list of files can then be passed to AzCopy to be downloaded:
-
- ""
- azcopy cp "http://<url_of_container>?<sas_token>" "/save/files/here" \
-     --list-of-files "/path/to/mislabeled_candidates_{split}_{dataset}.txt"
- ""
-
- To save the filename as <dataset_name>/<blob_name> (instead of just <blob_name>
- by default), pass the --include-dataset-in-filename flag. Then, the images can
- be downloaded with:
-
- ""
- python data_management/megadb/download_images.py txt \
-     "/path/to/mislabeled_candidates_{split}_{dataset}.txt" \
-     /save/files/here \
-     --threads 50
- ""
-
- Assumes the following directory layout:
-     <base_logdir>/
-         label_index.json
-         <logdir>/
-             outputs_{split}.csv.gz
-
- """
-
- #%% Imports
-
- from __future__ import annotations
-
- import argparse
- from collections import defaultdict
- from collections.abc import Iterable, Sequence
- import json
- import os
-
- import pandas as pd
- from tqdm import tqdm
-
-
- #%% Example usage
-
- """
- python identify_mislabeled_candidates.py <base_logdir>/<logdir> \
-     --margin 0.5 --splits val test
- """
-
-
- #%% Main function
-
- def main(logdir: str, splits: Iterable[str], margin: float,
-          include_dataset_in_filename: bool) -> None:
-
-     # load files
-     logdir = os.path.normpath(logdir)  # removes any trailing slash
-     base_logdir = os.path.dirname(logdir)
-     idx_to_label_json_path = os.path.join(base_logdir, 'label_index.json')
-     with open(idx_to_label_json_path, 'r') as f:
-         idx_to_label = json.load(f)
-     label_names = [idx_to_label[str(i)] for i in range(len(idx_to_label))]
-
-     for split in splits:
-         outputs_csv_path = os.path.join(logdir, f'outputs_{split}.csv.gz')
-         candidates_df = get_candidates_df(outputs_csv_path, label_names, margin)
-
-         # dataset => set of img_file
-         candidate_image_files: defaultdict[str, set[str]] = defaultdict(set)
-
-         for crop_path in tqdm(candidates_df['path']):
-             # crop_path: <dataset>/<img_file>___cropXX_mdvY.Y.jpg
-             #            [----<img_path>----]
-             img_path = crop_path.split('___crop')[0]
-             ds, img_file = img_path.split('/', maxsplit=1)
-             if include_dataset_in_filename:
-                 candidate_image_files[ds].add(img_path)
-             else:
-                 candidate_image_files[ds].add(img_file)
-
-         for ds in sorted(candidate_image_files.keys()):
-             img_files = candidate_image_files[ds]
-             print(f'{ds} contains {len(img_files)} mislabeled candidates.')
-             save_path = os.path.join(
-                 logdir, f'mislabeled_candidates_{split}_{ds}.txt')
-             with open(save_path, 'w') as f:
-                 for img_file in sorted(img_files):
-                     f.write(img_file + '\n')
-
-
- #%% Support functions
-
- def get_candidates_df(outputs_csv_path: str, label_names: Sequence[str],
-                       margin: float) -> pd.DataFrame:
-     """
-     Returns a DataFrame containing crops only from mislabeled candidate
-     images.
-     """
-
-     df = pd.read_csv(outputs_csv_path, float_precision='high')
-     all_rows = range(len(df))
-     df['pred'] = df[label_names].idxmax(axis=1)
-     df['pred_conf'] = df.lookup(row_labels=all_rows, col_labels=df['pred'])
-     df['label_conf'] = df.lookup(row_labels=all_rows, col_labels=df['label'])
-     candidate_mask = df['pred_conf'] >= df['label_conf'] + margin
-     candidates_df = df[candidate_mask].copy()
-     return candidates_df
-
-
- #%% Command-line driver
-
- def _parse_args() -> argparse.Namespace:
-
-     parser = argparse.ArgumentParser(
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-         description='Identify mislabeled candidate images.')
-     parser.add_argument(
-         'logdir',
-         help='folder inside <base_logdir> containing `outputs_<split>.csv.gz`')
-     parser.add_argument(
-         '--margin', type=float, default=0.5,
-         help='confidence margin to count as a mislabeled candidate')
-     parser.add_argument(
-         '--splits', nargs='+', choices=['train', 'val', 'test'],
-         help='which splits to identify mislabeled candidates on')
-     parser.add_argument(
-         '-d', '--include-dataset-in-filename', action='store_true',
-         help='save filename as <dataset_name>/<blob_name>')
-     return parser.parse_args()
-
-
- if __name__ == '__main__':
-
-     args = _parse_args()
-     main(logdir=args.logdir, splits=args.splits, margin=args.margin,
-          include_dataset_in_filename=args.include_dataset_in_filename)