megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
write_html_image_list.py
|
|
4
|
+
|
|
5
|
+
Given a list of image file names, writes an HTML file that
|
|
6
|
+
shows all those images, with optional one-line headers above each.
|
|
7
|
+
|
|
8
|
+
Each "filename" can also be a dict with elements 'filename','title',
|
|
9
|
+
'imageStyle','textStyle', 'linkTarget'
|
|
10
|
+
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
#%% Constants and imports
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import math
|
|
17
|
+
import urllib
|
|
18
|
+
|
|
19
|
+
from megadetector.utils import path_utils
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
#%% write_html_image_list
|
|
23
|
+
|
|
24
|
+
def write_html_image_list(filename=None,images=None,options=None):
    """
    Given a list of image file names, writes an HTML file that shows all those images,
    with optional one-line headers above each.

    Args:
        filename (str, optional): the .html output file; if None, just returns a valid
            options dict
        images (list, optional): the images to write to the .html file; if None, just returns
            a valid options dict.  This can be a flat list of image filenames, or this can
            be a list of dictionaries with one or more of the following fields:

            - filename (image filename) (required, all other fields are optional)
            - imageStyle (css style for this image)
            - textStyle (css style for the title associated with this image)
            - title (text label for this image)
            - linkTarget (URL to which this image should link on click)

            String entries in this list are replaced in place with dicts, and missing
            fields are filled in on dicts supplied by the caller.

        options (dict, optional): a dict with one or more of the following fields:

            - f_html (file pointer to write to, used for splitting write operations over multiple calls)
            - pageTitle (HTML page title)
            - headerHtml (html text to include before the image list)
            - subPageHeaderHtml (html text to include before the images when images are broken into pages)
            - trailerHtml (html text to include after the image list)
            - defaultImageStyle (default css style for images)
            - defaultTextStyle (default css style for image titles)
            - maxFiguresPerHtmlFile (max figures for a single HTML file; overflow will be handled by creating
              multiple files and a TOC with links)
            - urlEncodeFilenames (default True, e.g. '#' will be replaced by '%23')
            - urlEncodeLinkTargets (default True, e.g. '#' will be replaced by '%23')

            Missing (or None) fields are filled in with defaults on the dict that is passed in.

    Returns:
        dict: the populated options dict when [filename] or [images] is None, or when the
        image list had to be split across multiple pages; None when a single page is written.

    Raises:
        ValueError: if the image list has to be paged and the caller supplied options['f_html']
    """

    # A bare "import urllib" does not guarantee that the urllib.parse submodule has been
    # loaded, so import it explicitly here before using urllib.parse.quote().
    import urllib.parse

    # returns an options struct
    if options is None:
        options = {}

    # -1 is the sentinel for "no file handle supplied by the caller"
    if 'f_html' not in options:
        options['f_html'] = -1

    # Fill in any option that is missing or explicitly None.
    #
    # Possibly split the html output for figures into multiple files; Chrome gets sad with
    # thousands of images in a single tab.  math.inf means "never split".
    option_defaults = {
        'pageTitle':'',
        'headerHtml':'',
        'subPageHeaderHtml':'',
        'trailerHtml':'',
        'defaultTextStyle':
            "font-family:calibri,verdana,arial;font-weight:bold;font-size:150%;text-align:left;margin:0px;",
        'defaultImageStyle':
            "margin:0px;margin-top:5px;margin-bottom:5px;",
        'urlEncodeFilenames':True,
        'urlEncodeLinkTargets':True,
        'maxFiguresPerHtmlFile':math.inf
    }

    for option_name,default_value in option_defaults.items():
        if option_name not in options or options[option_name] is None:
            options[option_name] = default_value

    if filename is None or images is None:
        return options

    # images may be a list of images or a list of image/style/title dictionaries,
    # enforce that it's the latter to simplify downstream code
    for i_image,image_info in enumerate(images):
        if isinstance(image_info,str):
            image_info = {'filename':image_info}
        if 'filename' not in image_info:
            image_info['filename'] = ''
        if 'imageStyle' not in image_info:
            image_info['imageStyle'] = options['defaultImageStyle']
        if 'title' not in image_info:
            image_info['title'] = ''
        if 'linkTarget' not in image_info:
            image_info['linkTarget'] = ''
        if 'textStyle' not in image_info:
            image_info['textStyle'] = options['defaultTextStyle']
        images[i_image] = image_info

    n_images = len(images)

    # If we need to break this up into multiple files...
    if n_images > options['maxFiguresPerHtmlFile']:

        # You can't supply your own file handle in this case
        if options['f_html'] != -1:
            raise ValueError(
                "You can't supply your own file handle if we have to page the image set")

        figure_file_starting_indices = list(range(0,n_images,options['maxFiguresPerHtmlFile']))

        assert len(figure_file_starting_indices) > 1

        # Open the meta-output (index) file; the context manager guarantees that it gets
        # closed even if writing one of the sub-pages fails.
        with open(filename,'w') as f_meta:

            # Write header stuff
            title_string = '<title>Index page</title>'
            if len(options['pageTitle']) > 0:
                title_string = '<title>Index page for: {}</title>'.format(options['pageTitle'])
            f_meta.write('<html><head>{}</head><body>\n'.format(title_string))
            f_meta.write(options['headerHtml'])
            f_meta.write('<table border = 0 cellpadding = 2>\n')

            for starting_index in figure_file_starting_indices:

                i_start = starting_index
                i_end = min(starting_index + options['maxFiguresPerHtmlFile'] - 1, n_images - 1)

                trailer = 'image_{:05d}_{:05d}'.format(i_start,i_end)
                local_figures_html_filename = path_utils.insert_before_extension(filename,trailer)
                f_meta.write('<tr><td>\n')
                f_meta.write(
                    '<p style="padding-bottom:0px;margin-bottom:0px;text-align:left;'
                    'font-family:segoe ui,calibri,arial;font-size:100%;'
                    'text-decoration:none;font-weight:bold;">')
                f_meta.write('<a href="{}">Figures for images {} through {}</a></p></td></tr>\n'.format(
                    os.path.basename(local_figures_html_filename),i_start,i_end))

                local_images = images[i_start:i_end+1]

                # Sub-pages get their own header and no trailer
                local_options = options.copy()
                local_options['headerHtml'] = options['subPageHeaderHtml']
                local_options['trailerHtml'] = ''

                # Make a recursive call for this image set
                write_html_image_list(local_figures_html_filename,local_images,local_options)

            # ...for each page of images

            f_meta.write('</table></body>\n')
            f_meta.write(options['trailerHtml'])
            f_meta.write('</html>\n')

        return options

    # ...if we have to make multiple sub-pages

    # Only close the file if we opened it; a caller-supplied handle stays open
    b_clean_up_file = (options['f_html'] == -1)

    if b_clean_up_file:
        f_html = open(filename,'w')
    else:
        f_html = options['f_html']

    try:

        title_string = ''
        if len(options['pageTitle']) > 0:
            title_string = '<title>{}</title>'.format(options['pageTitle'])

        f_html.write('<html><head>{}</head><body>\n'.format(title_string))

        f_html.write(options['headerHtml'])

        # Write out images
        for i_image,image in enumerate(images):

            title = image['title']
            image_style = image['imageStyle']
            text_style = image['textStyle']
            image_filename = image['filename']
            link_target = image['linkTarget']

            # Remove unicode characters
            title = title.encode('ascii','ignore').decode('ascii')
            image_filename = image_filename.encode('ascii','ignore').decode('ascii')

            image_filename = image_filename.replace('\\','/')
            if options['urlEncodeFilenames']:
                image_filename = urllib.parse.quote(image_filename)

            if len(title) > 0:
                f_html.write('<p style="{}">{}</p>\n'.format(text_style,title))

            link_target = link_target.replace('\\','/')
            if options['urlEncodeLinkTargets']:
                # These are typically absolute paths, so we only want to mess with certain characters
                link_target = urllib.parse.quote(link_target,safe=':/')

            if len(link_target) > 0:
                f_html.write('<a href="{}">'.format(link_target))

            f_html.write('<img src="{}" style="{}">\n'.format(image_filename,image_style))

            if len(link_target) > 0:
                f_html.write('</a>')

            # No line break after the last image
            if i_image != len(images)-1:
                f_html.write('<br/>')

        # ...for each image we need to write

        f_html.write(options['trailerHtml'])

        f_html.write('</body></html>\n')

    finally:

        if b_clean_up_file:
            f_html.close()

# ...function
|
|
File without changes
|
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
plot_utils.py
|
|
4
|
+
|
|
5
|
+
Utility functions for plotting, particularly for plotting confusion matrices
|
|
6
|
+
and precision-recall curves.
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
#%% Imports
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
# This also imports mpl.{cm, axes, colors}
|
|
15
|
+
import matplotlib.figure
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
#%% Plotting functions
|
|
19
|
+
|
|
20
|
+
def plot_confusion_matrix(matrix,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=matplotlib.cm.Blues,
                          vmax=None,
                          use_colorbar=True,
                          y_label=True,
                          fmt= '{:.0f}',
                          fig=None):
    """
    Renders a square confusion matrix as a colored grid with a number in each cell.

    Cell colors (and the black/white choice for the cell text) are always driven by the
    row-normalized matrix; [normalize] only controls which values are printed in the cells.

    Args:
        matrix (np.ndarray): shape [num_classes, num_classes], confusion matrix
            where rows are ground-truth classes and columns are predicted classes
        classes (list of str): class names for each row/column
        normalize (bool, optional): whether to print row-normalized values in the
            cells; if False, the values in [matrix] are printed as supplied
        title (str, optional): figure title
        cmap (matplotlib.colors.colormap, optional): colormap for cell backgrounds
        vmax (float, optional): value corresponding to the largest value of the colormap;
            passed through to imshow()
        use_colorbar (bool, optional): whether to show colorbar
        y_label (bool, optional): whether to show class names on the y axis
        fmt (str, optional): format string for rendering numeric values
        fig (Figure, optional): existing figure to which we should render, otherwise
            creates a new figure

    Returns:
        matplotlib.figure.Figure: the figure we rendered to or created
    """

    num_classes = matrix.shape[0]
    assert matrix.shape[1] == num_classes
    assert len(classes) == num_classes

    # Row-wise fractions; the small epsilon avoids dividing by zero for empty rows
    row_totals = matrix.sum(axis=1, keepdims=True) + 1e-7
    row_fractions = matrix.astype(np.float64) / row_totals

    cell_values = row_fractions if normalize else matrix

    # Figure size grows with the number of classes, plus a little room for the colorbar
    fig_h = 3 + 0.3 * num_classes
    fig_w = (fig_h + 0.5) if use_colorbar else fig_h

    if fig is None:
        fig = matplotlib.figure.Figure(figsize=(fig_w, fig_h), tight_layout=True)

    ax = fig.subplots(1, 1)
    image = ax.imshow(row_fractions, interpolation='nearest', cmap=cmap, vmax=vmax)
    ax.set_title(title)

    if use_colorbar:
        colorbar = fig.colorbar(image, fraction=0.046, pad=0.04,
                                ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
        colorbar.set_ticklabels(['0%', '25%', '50%', '75%', '100%'])

    class_positions = np.arange(num_classes)
    ax.set_xticks(class_positions)
    ax.set_yticks(class_positions)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_xlabel('Predicted class')

    if y_label:
        ax.set_yticklabels(classes)
        ax.set_ylabel('Ground-truth class')

    # Print each cell's value; switch to white text on dark (> 50%) cells
    for (i_row, i_col), cell_value in np.ndenumerate(cell_values):
        text_color = 'black'
        if row_fractions[i_row, i_col] > 0.5:
            text_color = 'white'
        ax.text(i_col, i_row, fmt.format(cell_value),
                horizontalalignment='center',
                verticalalignment='center',
                color=text_color)

    return fig

# ...def plot_confusion_matrix(...)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def plot_precision_recall_curve(precisions,
                                recalls,
                                title='Precision/recall curve',
                                xlim=(0.0,1.05),
                                ylim=(0.0,1.05)):
    """
    Draws a step-style precision/recall curve (with the area under it shaded) from
    parallel, ordered lists of precision and recall values.

    Args:
        precisions (list of float): precision for corresponding recall values,
            should have same length as [recalls].
        recalls (list of float): recall for corresponding precision values,
            should have same length as [precisions].
        title (str, optional): plot title
        xlim (tuple, optional): x-axis limits as a length-2 tuple
        ylim (tuple, optional): y-axis limits as a length-2 tuple

    Returns:
        matplotlib.figure.Figure: the (new) figure
    """

    assert len(precisions) == len(recalls)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)

    # Outline and fill share the same color, transparency, and step convention
    curve_color = 'b'
    curve_alpha = 0.2
    ax.step(recalls, precisions, color=curve_color, alpha=curve_alpha, where='post')
    ax.fill_between(recalls, precisions, alpha=curve_alpha, color=curve_color, step='post')

    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title(title)

    x_min, x_max = xlim[0], xlim[1]
    y_min, y_max = ylim[0], ylim[1]
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    return fig

# ...def plot_precision_recall_curve(...)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def plot_stacked_bar_chart(data,
                           series_labels=None,
                           col_labels=None,
                           x_label=None,
                           y_label=None,
                           log_scale=False):
    """
    Draws a stacked bar chart with one bar per column and one colored segment per
    series, e.g. for plotting species distribution across locations.

    Reference: https://stackoverflow.com/q/44309507

    Args:
        data (np.ndarray or list of list): data to plot; rows (series) are species, columns
            are locations
        series_labels (list of str, optional): series labels, typically species names;
            if None, series are labeled "series_00", "series_01", ...
        col_labels (list of str, optional): column labels, typically location names; when
            there are 25 or more, only every 20th label is drawn
        x_label (str, optional): x-axis label
        y_label (str, optional): y-axis label
        log_scale (bool, optional): whether to plot the y axis in log-scale

    Returns:
        matplotlib.figure.Figure: the (new) figure
    """

    data = np.asarray(data)
    num_series, num_columns = data.shape
    bar_positions = np.arange(num_columns)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)
    series_colors = matplotlib.cm.rainbow(np.linspace(0, 1, num_series))

    # Each series is drawn on top of the running total of the series before it
    stack_height = np.zeros(num_columns)
    for i_series, series_values in enumerate(data):
        if series_labels is not None:
            series_name = series_labels[i_series]
        else:
            series_name = 'series_{}'.format(str(i_series).zfill(2))
        ax.bar(bar_positions, series_values, bottom=stack_height, label=series_name,
               color=series_colors[i_series])
        stack_height += series_values

    if col_labels is not None:
        if len(col_labels) < 25:
            # Few enough columns to label every one
            ax.set_xticks(bar_positions)
            ax.set_xticklabels(col_labels, rotation=90)
        else:
            # Too many columns; label every 20th
            ax.set_xticks(list(range(0, len(col_labels), 20)))
            ax.set_xticklabels(col_labels[::20], rotation=90)

    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if log_scale:
        ax.set_yscale('log')

    # To fit the legend in, shrink current axis by 20%
    axis_box = ax.get_position()
    ax.set_position([axis_box.x0, axis_box.y0, axis_box.width * 0.8, axis_box.height])

    # Put a legend to the right of the current axis
    ax.legend(loc='center left', bbox_to_anchor=(0.99, 0.5), frameon=False)

    return fig

# ...def plot_stacked_bar_chart(...)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def calibration_ece(true_scores, pred_scores, num_bins):
    r"""
    Computes expected calibration error (ECE) as defined in equation (3) of
    Guo et al. "On Calibration of Modern Neural Networks." (2017).

    Predictions are grouped into [num_bins] equal-width confidence bins; ECE is the
    average of |accuracy - mean confidence| over the non-empty bins, weighted by the
    fraction of examples in each bin.

    Implementation modified from sklearn.calibration.calibration_curve()
    in order to implement ECE calculation.  See:

    https://github.com/scikit-learn/scikit-learn/issues/18268

    Args:
        true_scores (list of int): true values, length N, binary-valued (0 = neg, 1 = pos)
        pred_scores (list of float): predicted confidence values, length N, pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of bins to use (`M` in eq. (3) of Guo 2017)

    Returns:
        tuple: a length-three tuple containing:
        - accs: np.ndarray, shape [M], type float64, accuracy in each bin,
          M <= num_bins because bins with no samples are not returned
        - confs: np.ndarray, shape [M], type float64, mean model confidence in
          each bin
        - ece: float, expected calibration error
    """

    assert len(true_scores) == len(pred_scores)

    # The upper edge sits slightly above 1.0 so that a confidence of exactly 1.0
    # lands in the last bin
    bin_edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
    n_slots = len(bin_edges)
    bin_ids = np.digitize(pred_scores, bin_edges) - 1

    # Per-bin sum of confidences, sum of true labels, and example count
    confidence_sums = np.bincount(bin_ids, weights=pred_scores, minlength=n_slots)
    positive_sums = np.bincount(bin_ids, weights=true_scores, minlength=n_slots)
    example_counts = np.bincount(bin_ids, minlength=n_slots)

    # Drop empty bins before dividing
    occupied = example_counts != 0
    occupied_counts = example_counts[occupied]
    accs = positive_sums[occupied] / occupied_counts
    confs = confidence_sums[occupied] / occupied_counts

    bin_weights = occupied_counts / len(true_scores)
    ece = np.abs(accs - confs) @ bin_weights
    return accs, confs, ece

# ...def calibration_ece(...)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def plot_calibration_curve(true_scores,
                           pred_scores,
                           num_bins,
                           name='calibration',
                           plot_perf=True,
                           plot_hist=True,
                           ax=None,
                           **fig_kwargs):
    """
    Plots per-bin accuracy against per-bin mean confidence, with the expected
    calibration error (from calibration_ece()) shown in the title.

    Args:
        true_scores (list of int): true values, length N, binary-valued (0 = neg, 1 = pos)
        pred_scores (list of float): predicted confidence values, length N, pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of bins to use (`M` in eq. (3) of Guo 2017)
        name (str, optional): label in legend for the calibration curve
        plot_perf (bool, optional): whether to plot y=x line indicating perfect calibration
        plot_hist (bool, optional): whether to plot histogram of counts (on a twin y axis)
        ax (Axes, optional): if given then no legend is drawn, and fig_kwargs are ignored
        fig_kwargs (dict): passed to the Figure constructor; only used if [ax] is None

    Returns:
        matplotlib.figure.Figure: the figure that owns the axes we drew on (a new figure
        if [ax] was None)
    """

    accs, confs, ece = calibration_ece(true_scores, pred_scores, num_bins)

    # We only draw a legend on figures we create ourselves
    created_fig = (ax is None)
    if created_fig:
        fig = matplotlib.figure.Figure(**fig_kwargs)
        ax = fig.subplots(1, 1)

    # 's-': squares on line
    ax.plot(confs, accs, 's-', label=name)
    ax.set(xlabel='Model confidence', ylabel='Actual accuracy',
           title=f'Calibration plot (ECE: {ece:.02g})')
    ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])

    if plot_perf:
        ax.plot([0, 1], [0, 1], color='black', label='perfect calibration')

    ax.grid(True)

    if plot_hist:
        # Histogram of confidences on a secondary y axis, using the same bin edges
        # as calibration_ece()
        hist_ax = ax.twinx()
        hist_edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
        hist_result = hist_ax.hist(pred_scores, alpha=0.5, label='histogram of examples',
                                   bins=hist_edges, color='tab:red')
        tallest_bar = np.max(hist_result[0])
        hist_ax.set_ylim([-0.05 * tallest_bar, 1.05 * tallest_bar])
        hist_ax.set_ylabel('Count')

    if created_fig:
        fig.legend(loc='upper left', bbox_to_anchor=(0.15, 0.85))

    return ax.figure

# ...def plot_calibration_curve(...)
|