megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (147) hide show
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,239 @@
1
+ """
2
+
3
+ write_html_image_list.py
4
+
5
+ Given a list of image file names, writes an HTML file that
6
+ shows all those images, with optional one-line headers above each.
7
+
8
+ Each "filename" can also be a dict with elements 'filename','title',
9
+ 'imageStyle','textStyle', 'linkTarget'
10
+
11
+ """
12
+
13
+ #%% Constants and imports
14
+
15
+ import os
16
+ import math
17
+ import urllib
18
+
19
+ from megadetector.utils import path_utils
20
+
21
+
22
+ #%% write_html_image_list
23
+
24
def _normalize_image_entries(images, options):
    """
    Converts [images] into a new list of dicts, each containing 'filename', 'imageStyle',
    'title', 'linkTarget', and 'textStyle'.

    Plain strings are treated as filenames. Missing fields are filled with '' or with the
    default styles from [options]. The caller's list and dicts are left unmodified, so the
    same image list can safely be rendered more than once with different options.

    Args:
        images (list): image filenames and/or dicts describing images
        options (dict): a fully-populated options dict (see write_html_image_list)

    Returns:
        list: a new list of dicts, one per input image
    """

    normalized = []
    for image_info in images:
        if isinstance(image_info, str):
            entry = {'filename': image_info}
        else:
            # Shallow copy so we never add keys to the caller's dict
            entry = dict(image_info)
        entry.setdefault('filename', '')
        entry.setdefault('imageStyle', options['defaultImageStyle'])
        entry.setdefault('title', '')
        entry.setdefault('linkTarget', '')
        entry.setdefault('textStyle', options['defaultTextStyle'])
        normalized.append(entry)
    return normalized


def _write_paged_index(filename, images, options):
    """
    Writes [filename] as an index page whose rows link to sub-pages of at most
    options['maxFiguresPerHtmlFile'] images each.

    Each sub-page is written by a recursive call to write_html_image_list(). Its name is
    formed by inserting 'image_<start>_<end>' before the extension of [filename].

    Args:
        filename (str): the index .html file to write
        images (list): normalized image dicts (see _normalize_image_entries)
        options (dict): a fully-populated options dict

    Raises:
        ValueError: if the caller supplied their own file handle via options['f_html'],
            or if options['maxFiguresPerHtmlFile'] is less than 1
    """

    # You can't supply your own file handle in this case
    if options['f_html'] != -1:
        raise ValueError(
            "You can't supply your own file handle if we have to page the image set")

    # int() so that e.g. 100.0 works as a page size (range() rejects floats)
    page_size = int(options['maxFiguresPerHtmlFile'])
    if page_size < 1:
        raise ValueError('maxFiguresPerHtmlFile must be at least 1')

    n_images = len(images)

    title_string = '<title>Index page</title>'
    if len(options['pageTitle']) > 0:
        title_string = '<title>Index page for: {}</title>'.format(options['pageTitle'])

    link_paragraph_open = (
        '<p style="padding-bottom:0px;margin-bottom:0px;text-align:left;'
        'font-family:segoe ui,calibri,arial;font-size:100%;text-decoration:none;'
        'font-weight:bold;">')

    # "with" guarantees the index file is closed even if a sub-page write fails
    with open(filename, 'w', encoding='utf-8') as f_meta:

        f_meta.write('<html><head><meta charset="utf-8">{}</head><body>\n'.format(title_string))
        f_meta.write(options['headerHtml'])
        f_meta.write('<table border = 0 cellpadding = 2>\n')

        for i_start in range(0, n_images, page_size):

            # Inclusive index of the last image on this sub-page
            i_end = min(i_start + page_size, n_images) - 1

            trailer = 'image_{:05d}_{:05d}'.format(i_start, i_end)
            local_figures_html_filename = path_utils.insert_before_extension(filename, trailer)

            f_meta.write('<tr><td>\n')
            f_meta.write(link_paragraph_open)
            f_meta.write('<a href="{}">Figures for images {} through {}</a></p></td></tr>\n'.format(
                os.path.basename(local_figures_html_filename), i_start, i_end))

            # Sub-pages get the sub-page header and no trailer
            local_options = options.copy()
            local_options['headerHtml'] = options['subPageHeaderHtml']
            local_options['trailerHtml'] = ''

            # Make a recursive call for this image set; each slice fits on one page
            write_html_image_list(local_figures_html_filename,
                                  images[i_start:i_end + 1],
                                  local_options)

        # ...for each page of images

        f_meta.write('</table></body>\n')
        f_meta.write(options['trailerHtml'])
        f_meta.write('</html>\n')


def write_html_image_list(filename=None, images=None, options=None):
    """
    Given a list of image file names, writes an HTML file that shows all those images,
    with optional one-line headers above each.

    Args:
        filename (str, optional): the .html output file; if None, just returns a valid
            options dict
        images (list, optional): the images to write to the .html file; if None, just returns
            a valid options dict. This can be a flat list of image filenames, or this can
            be a list of dictionaries with one or more of the following fields:

            - filename (image filename) (required, all other fields are optional)
            - imageStyle (css style for this image)
            - textStyle (css style for the title associated with this image)
            - title (text label for this image)
            - linkTarget (URL to which this image should link on click)

            This list (and any dicts in it) is not modified.

        options (dict, optional): a dict with one or more of the following fields:

            - f_html (file pointer to write to, used for splitting write operations over multiple calls)
            - pageTitle (HTML page title)
            - headerHtml (html text to include before the image list)
            - subPageHeaderHtml (html text to include before the images when images are broken into pages)
            - trailerHtml (html text to include after the image list)
            - defaultImageStyle (default css style for images)
            - defaultTextStyle (default css style for image titles)
            - maxFiguresPerHtmlFile (max figures for a single HTML file; overflow will be handled by creating
              multiple files and a TOC with links)
            - urlEncodeFilenames (default True, e.g. '#' will be replaced by '%23')
            - urlEncodeLinkTargets (default True, e.g. '#' will be replaced by '%23')

            Missing fields are filled in with defaults, in place.

    Returns:
        dict: the (default-populated) options dict

    Raises:
        ValueError: if the image set must be split into multiple pages but the caller
            supplied their own file handle via options['f_html']
    """

    # "import urllib" alone does not guarantee that the urllib.parse submodule is loaded
    import urllib.parse

    if options is None:
        options = {}

    # f_html is only defaulted when absent (not when None)
    if 'f_html' not in options:
        options['f_html'] = -1

    # These options are filled in when they are missing or None
    default_values = {
        'pageTitle': '',
        'headerHtml': '',
        'subPageHeaderHtml': '',
        'trailerHtml': '',
        'defaultTextStyle':
            'font-family:calibri,verdana,arial;font-weight:bold;font-size:150%;text-align:left;margin:0px;',
        'defaultImageStyle': 'margin:0px;margin-top:5px;margin-bottom:5px;',
        'urlEncodeFilenames': True,
        'urlEncodeLinkTargets': True,
        # Possibly split the html output for figures into multiple files; Chrome gets sad with
        # thousands of images in a single tab.
        'maxFiguresPerHtmlFile': math.inf,
    }
    for key, value in default_values.items():
        if options.get(key) is None:
            options[key] = value

    if filename is None or images is None:
        return options

    # Work on normalized copies so the caller's images are never modified
    images = _normalize_image_entries(images, options)
    n_images = len(images)

    # If we need to break this up into multiple files...
    if n_images > options['maxFiguresPerHtmlFile']:
        _write_paged_index(filename, images, options)
        return options

    # Only close the file if we opened it; a caller-supplied handle stays open
    own_file = (options['f_html'] == -1)
    f_html = open(filename, 'w', encoding='utf-8') if own_file else options['f_html']

    try:

        # Declare the charset only when we chose the file's encoding ourselves
        head = '<meta charset="utf-8">' if own_file else ''
        if len(options['pageTitle']) > 0:
            head += '<title>{}</title>'.format(options['pageTitle'])

        f_html.write('<html><head>{}</head><body>\n'.format(head))
        f_html.write(options['headerHtml'])

        # Write out images
        for i_image, image in enumerate(images):

            # Remove unicode characters from titles
            title = image['title'].encode('ascii', 'ignore').decode('ascii')

            image_filename = image['filename'].replace('\\', '/')
            if options['urlEncodeFilenames']:
                # quote() percent-encodes non-ASCII characters as UTF-8, so they survive
                image_filename = urllib.parse.quote(image_filename)
            else:
                # Without URL-encoding, fall back to stripping non-ASCII characters
                image_filename = image_filename.encode('ascii', 'ignore').decode('ascii')

            if len(title) > 0:
                f_html.write('<p style="{}">{}</p>\n'.format(image['textStyle'], title))

            link_target = image['linkTarget'].replace('\\', '/')
            if options['urlEncodeLinkTargets']:
                # These are typically absolute paths, so we only want to mess with certain characters
                link_target = urllib.parse.quote(link_target, safe=':/')

            if len(link_target) > 0:
                f_html.write('<a href="{}">'.format(link_target))

            f_html.write('<img src="{}" style="{}">\n'.format(image_filename, image['imageStyle']))

            if len(link_target) > 0:
                f_html.write('</a>')

            if i_image != n_images - 1:
                f_html.write('<br/>')

        # ...for each image we need to write

        f_html.write(options['trailerHtml'])
        f_html.write('</body></html>\n')

    finally:
        if own_file:
            f_html.close()

    return options

# ...function
File without changes
@@ -0,0 +1,309 @@
1
+ """
2
+
3
+ plot_utils.py
4
+
5
+ Utility functions for plotting, particularly for plotting confusion matrices
6
+ and precision-recall curves.
7
+
8
+ """
9
+
10
+ #%% Imports
11
+
12
+ import numpy as np
13
+
14
+ # This also imports mpl.{cm, axes, colors}
15
+ import matplotlib.figure
16
+
17
+
18
+ #%% Plotting functions
19
+
20
def plot_confusion_matrix(matrix,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=matplotlib.cm.Blues,
                          vmax=None,
                          use_colorbar=True,
                          y_label=True,
                          fmt='{:.0f}',
                          fig=None):
    """
    Plots a confusion matrix.

    Cell colors (and the colorbar) always reflect the row-normalized matrix, i.e. the
    fraction of each ground-truth class assigned to each predicted class. [normalize]
    only controls which values are printed inside the cells.

    Args:
        matrix (np.ndarray): shape [num_classes, num_classes], confusion matrix
            where rows are ground-truth classes and columns are predicted classes
        classes (list of str): class names for each row/column
        normalize (bool, optional): if True, print row-normalized values (fractions in
            [0, 1]) in each cell; if False, print the raw values from [matrix]. With
            normalize=True you probably want a [fmt] like '{:.2f}'.
        title (str, optional): figure title
        cmap (matplotlib.colors.Colormap, optional): colormap for cell backgrounds
        vmax (float, optional): row-normalized value mapped to the top of the colormap;
            if None, the maximum value of the row-normalized matrix is used. The bottom
            of the colormap is always 0.
        use_colorbar (bool, optional): whether to show colorbar
        y_label (bool, optional): whether to show class names on the y axis
        fmt (str, optional): format string for rendering numeric values
        fig (Figure, optional): existing figure to which we should render, otherwise
            creates a new figure

    Returns:
        matplotlib.figure.Figure: the figure we rendered to or created
    """

    num_classes = matrix.shape[0]
    assert matrix.shape[1] == num_classes
    assert len(classes) == num_classes

    # Row-normalize; the epsilon avoids division by zero for rows with no samples
    normalized_matrix = matrix.astype(np.float64) / (
        matrix.sum(axis=1, keepdims=True) + 1e-7)
    if normalize:
        matrix = normalized_matrix

    fig_h = 3 + 0.3 * num_classes
    fig_w = fig_h
    if use_colorbar:
        fig_w += 0.5

    if fig is None:
        fig = matplotlib.figure.Figure(figsize=(fig_w, fig_h), tight_layout=True)
    ax = fig.subplots(1, 1)

    # Pin the bottom of the color scale to 0 so colors agree with the 0%..100% colorbar
    # labels and with the 0.5 text-color threshold below; without vmin, matplotlib would
    # stretch the colormap to start at the smallest cell value.
    im = ax.imshow(normalized_matrix, interpolation='nearest', cmap=cmap,
                   vmin=0.0, vmax=vmax)
    ax.set_title(title)

    if use_colorbar:
        cbar = fig.colorbar(im, fraction=0.046, pad=0.04,
                            ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
        cbar.set_ticklabels(['0%', '25%', '50%', '75%', '100%'])

    tick_marks = np.arange(num_classes)
    ax.set_xticks(tick_marks)
    ax.set_yticks(tick_marks)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_xlabel('Predicted class')

    if y_label:
        ax.set_yticklabels(classes)
        ax.set_ylabel('Ground-truth class')

    # Dark cells (high normalized value) get white text for contrast
    for i, j in np.ndindex(matrix.shape):
        v = matrix[i, j]
        ax.text(j, i, fmt.format(v),
                horizontalalignment='center',
                verticalalignment='center',
                color='white' if normalized_matrix[i, j] > 0.5 else 'black')

    return fig

# ...def plot_confusion_matrix(...)
98
+
99
+
100
def plot_precision_recall_curve(precisions,
                                recalls,
                                title='Precision/recall curve',
                                xlim=(0.0,1.05),
                                ylim=(0.0,1.05)):
    """
    Renders a step-style precision/recall curve, shading the area beneath it.

    Args:
        precisions (list of float): precision values, aligned element-wise with
            [recalls]; the two lists must have equal length
        recalls (list of float): recall values (x coordinates), aligned element-wise
            with [precisions]
        title (str, optional): title shown above the axes
        xlim (tuple, optional): (min, max) limits for the recall axis
        ylim (tuple, optional): (min, max) limits for the precision axis

    Returns:
        matplotlib.figure.Figure: a newly-created figure containing the plot
    """

    assert len(precisions) == len(recalls)

    figure = matplotlib.figure.Figure(tight_layout=True)
    axes = figure.subplots(1, 1)

    # Same translucent blue for both the outline and the shaded area
    shared_style = {'color': 'b', 'alpha': 0.2}
    axes.step(recalls, precisions, where='post', **shared_style)
    axes.fill_between(recalls, precisions, step='post', **shared_style)

    axes.set_xlabel('Recall')
    axes.set_ylabel('Precision')
    axes.set_title(title)

    axes.set_xlim(xlim[0], xlim[1])
    axes.set_ylim(ylim[0], ylim[1])

    return figure

# ...def plot_precision_recall_curve(...)
138
+
139
+
140
def plot_stacked_bar_chart(data,
                           series_labels=None,
                           col_labels=None,
                           x_label=None,
                           y_label=None,
                           log_scale=False):
    """
    Draws a stacked bar chart, e.g. the distribution of species across locations.

    Reference: https://stackoverflow.com/q/44309507

    Args:
        data (np.ndarray or list of list): 2D values; each row is one stacked series
            (e.g. a species), each column is one bar (e.g. a location)
        series_labels (list of str, optional): legend entry for each row; if omitted,
            rows are labeled series_00, series_01, ...
        col_labels (list of str, optional): x tick label for each column; when there
            are 25 or more, only every 20th label is shown
        x_label (str, optional): x-axis label
        y_label (str, optional): y-axis label
        log_scale (bool, optional): if True, use a logarithmic y axis

    Returns:
        matplotlib.figure.Figure: a newly-created figure containing the chart
    """

    data = np.asarray(data)
    num_series, num_columns = data.shape
    x_positions = np.arange(num_columns)

    figure = matplotlib.figure.Figure(tight_layout=True)
    axes = figure.subplots(1, 1)
    series_colors = matplotlib.cm.rainbow(np.linspace(0, 1, num_series))

    # Each series sits on top of the running total of all series drawn before it
    bar_bottoms = np.zeros(num_columns)
    for i_series in range(num_series):
        series_values = data[i_series]
        if series_labels is not None:
            legend_label = series_labels[i_series]
        else:
            legend_label = 'series_{}'.format(str(i_series).zfill(2))
        axes.bar(x_positions, series_values, bottom=bar_bottoms, label=legend_label,
                 color=series_colors[i_series])
        bar_bottoms += series_values

    if col_labels is not None:
        if len(col_labels) < 25:
            axes.set_xticks(x_positions)
            axes.set_xticklabels(col_labels, rotation=90)
        else:
            # Too many columns to label them all; label every 20th
            axes.set_xticks(list(range(0, len(col_labels), 20)))
            axes.set_xticklabels(col_labels[::20], rotation=90)

    if x_label is not None:
        axes.set_xlabel(x_label)
    if y_label is not None:
        axes.set_ylabel(y_label)
    if log_scale:
        axes.set_yscale('log')

    # Narrow the axes to 80% of their width so the legend fits on the right
    position = axes.get_position()
    axes.set_position([position.x0, position.y0, position.width * 0.8, position.height])
    axes.legend(loc='center left', bbox_to_anchor=(0.99, 0.5), frameon=False)

    return figure

# ...def plot_stacked_bar_chart(...)
207
+
208
+
209
def calibration_ece(true_scores, pred_scores, num_bins):
    r"""
    Computes the expected calibration error (ECE), equation (3) of
    Guo et al. "On Calibration of Modern Neural Networks." (2017).

    Adapted from sklearn.calibration.calibration_curve() to also produce the ECE; see:

    https://github.com/scikit-learn/scikit-learn/issues/18268

    Predictions are grouped into [num_bins] equal-width confidence bins over [0, 1].
    The ECE is the sample-weighted mean of |accuracy - confidence| over the bins.

    Args:
        true_scores (list of int): length-N binary labels (0 = neg, 1 = pos)
        pred_scores (list of float): length-N predicted confidences; pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of bins (`M` in eq. (3) of Guo 2017)

    Returns:
        tuple: (accs, confs, ece), where:
            - accs: np.ndarray of float64, fraction of positives in each non-empty bin
              (empty bins are omitted, so there may be fewer than num_bins entries)
            - confs: np.ndarray of float64, mean predicted confidence in each non-empty bin
            - ece: float, expected calibration error
    """

    assert len(true_scores) == len(pred_scores)

    n_examples = len(true_scores)

    # The tiny epsilon on the upper edge puts a confidence of exactly 1.0 in the last bin
    edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
    n_slots = len(edges)
    bin_index = np.digitize(pred_scores, edges) - 1

    conf_per_bin = np.bincount(bin_index, weights=pred_scores, minlength=n_slots)
    pos_per_bin = np.bincount(bin_index, weights=true_scores, minlength=n_slots)
    count_per_bin = np.bincount(bin_index, minlength=n_slots)

    occupied = count_per_bin != 0
    counts = count_per_bin[occupied]
    accs = pos_per_bin[occupied] / counts
    confs = conf_per_bin[occupied] / counts

    bin_weights = counts / n_examples
    ece = np.abs(accs - confs) @ bin_weights
    return accs, confs, ece

# ...def calibration_ece(...)
252
+
253
+
254
def plot_calibration_curve(true_scores,
                           pred_scores,
                           num_bins,
                           name='calibration',
                           plot_perf=True,
                           plot_hist=True,
                           ax=None,
                           **fig_kwargs):
    """
    Plots a calibration curve (mean confidence vs. accuracy per bin), with the ECE in
    the title.

    Args:
        true_scores (list of int): true values, length N, binary-valued (0 = neg, 1 = pos)
        pred_scores (list of float): predicted confidence values, length N, pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of bins to use (`M` in eq. (3) of Guo 2017)
        name (str, optional): label in legend for the calibration curve
        plot_perf (bool, optional): whether to plot y=x line indicating perfect calibration
        plot_hist (bool, optional): whether to plot a histogram of example counts per bin
            on a secondary y axis
        ax (Axes, optional): axes to draw into; if given then no legend is drawn, and
            fig_kwargs are ignored
        fig_kwargs (dict): passed to matplotlib.figure.Figure; only used if [ax] is None

    Returns:
        matplotlib.figure.Figure: a new figure if [ax] is None, otherwise the existing
            figure that owns [ax]
    """

    accs, confs, ece = calibration_ece(true_scores, pred_scores, num_bins)

    created_fig = ax is None
    if created_fig:
        fig = matplotlib.figure.Figure(**fig_kwargs)
        ax = fig.subplots(1, 1)

    ax.plot(confs, accs, 's-', label=name)  # 's-': squares on line
    ax.set(xlabel='Model confidence', ylabel='Actual accuracy',
           title=f'Calibration plot (ECE: {ece:.02g})')
    ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])
    if plot_perf:
        ax.plot([0, 1], [0, 1], color='black', label='perfect calibration')
    ax.grid(True)

    if plot_hist:
        ax1 = ax.twinx()
        # Same bin edges as calibration_ece(), so histogram bars line up with the curve
        bins = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
        counts = ax1.hist(pred_scores, alpha=0.5, label='histogram of examples',
                          bins=bins, color='tab:red')[0]
        # Guard against an all-zero histogram (e.g. empty input), which would
        # otherwise produce identical lower and upper y limits
        max_count = max(np.max(counts), 1)
        ax1.set_ylim([-0.05 * max_count, 1.05 * max_count])
        ax1.set_ylabel('Count')

    if created_fig:
        fig.legend(loc='upper left', bbox_to_anchor=(0.15, 0.85))

    return ax.figure

# ...def plot_calibration_curve(...)