megadetector 5.0.11__py3-none-any.whl → 5.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (203) hide show
  1. megadetector/api/__init__.py +0 -0
  2. megadetector/api/batch_processing/__init__.py +0 -0
  3. megadetector/api/batch_processing/api_core/__init__.py +0 -0
  4. megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
  5. megadetector/api/batch_processing/api_core/batch_service/score.py +439 -0
  6. megadetector/api/batch_processing/api_core/server.py +294 -0
  7. megadetector/api/batch_processing/api_core/server_api_config.py +97 -0
  8. megadetector/api/batch_processing/api_core/server_app_config.py +55 -0
  9. megadetector/api/batch_processing/api_core/server_batch_job_manager.py +220 -0
  10. megadetector/api/batch_processing/api_core/server_job_status_table.py +149 -0
  11. megadetector/api/batch_processing/api_core/server_orchestration.py +360 -0
  12. megadetector/api/batch_processing/api_core/server_utils.py +88 -0
  13. megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
  14. megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +46 -0
  15. megadetector/api/batch_processing/api_support/__init__.py +0 -0
  16. megadetector/api/batch_processing/api_support/summarize_daily_activity.py +152 -0
  17. megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
  18. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  19. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  20. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  21. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  22. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  23. megadetector/api/synchronous/__init__.py +0 -0
  24. megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  25. megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +152 -0
  26. megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +263 -0
  27. megadetector/api/synchronous/api_core/animal_detection_api/config.py +35 -0
  28. megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
  29. megadetector/api/synchronous/api_core/tests/load_test.py +110 -0
  30. megadetector/classification/__init__.py +0 -0
  31. megadetector/classification/aggregate_classifier_probs.py +108 -0
  32. megadetector/classification/analyze_failed_images.py +227 -0
  33. megadetector/classification/cache_batchapi_outputs.py +198 -0
  34. megadetector/classification/create_classification_dataset.py +627 -0
  35. megadetector/classification/crop_detections.py +516 -0
  36. megadetector/classification/csv_to_json.py +226 -0
  37. megadetector/classification/detect_and_crop.py +855 -0
  38. megadetector/classification/efficientnet/__init__.py +9 -0
  39. megadetector/classification/efficientnet/model.py +415 -0
  40. megadetector/classification/efficientnet/utils.py +607 -0
  41. megadetector/classification/evaluate_model.py +520 -0
  42. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  43. megadetector/classification/json_to_azcopy_list.py +63 -0
  44. megadetector/classification/json_validator.py +699 -0
  45. megadetector/classification/map_classification_categories.py +276 -0
  46. megadetector/classification/merge_classification_detection_output.py +506 -0
  47. megadetector/classification/prepare_classification_script.py +194 -0
  48. megadetector/classification/prepare_classification_script_mc.py +228 -0
  49. megadetector/classification/run_classifier.py +287 -0
  50. megadetector/classification/save_mislabeled.py +110 -0
  51. megadetector/classification/train_classifier.py +827 -0
  52. megadetector/classification/train_classifier_tf.py +725 -0
  53. megadetector/classification/train_utils.py +323 -0
  54. megadetector/data_management/__init__.py +0 -0
  55. megadetector/data_management/annotations/__init__.py +0 -0
  56. megadetector/data_management/annotations/annotation_constants.py +34 -0
  57. megadetector/data_management/camtrap_dp_to_coco.py +237 -0
  58. megadetector/data_management/cct_json_utils.py +404 -0
  59. megadetector/data_management/cct_to_md.py +176 -0
  60. megadetector/data_management/cct_to_wi.py +289 -0
  61. megadetector/data_management/coco_to_labelme.py +283 -0
  62. megadetector/data_management/coco_to_yolo.py +662 -0
  63. megadetector/data_management/databases/__init__.py +0 -0
  64. megadetector/data_management/databases/add_width_and_height_to_db.py +33 -0
  65. megadetector/data_management/databases/combine_coco_camera_traps_files.py +206 -0
  66. megadetector/data_management/databases/integrity_check_json_db.py +493 -0
  67. megadetector/data_management/databases/subset_json_db.py +115 -0
  68. megadetector/data_management/generate_crops_from_cct.py +149 -0
  69. megadetector/data_management/get_image_sizes.py +189 -0
  70. megadetector/data_management/importers/add_nacti_sizes.py +52 -0
  71. megadetector/data_management/importers/add_timestamps_to_icct.py +79 -0
  72. megadetector/data_management/importers/animl_results_to_md_results.py +158 -0
  73. megadetector/data_management/importers/auckland_doc_test_to_json.py +373 -0
  74. megadetector/data_management/importers/auckland_doc_to_json.py +201 -0
  75. megadetector/data_management/importers/awc_to_json.py +191 -0
  76. megadetector/data_management/importers/bellevue_to_json.py +273 -0
  77. megadetector/data_management/importers/cacophony-thermal-importer.py +793 -0
  78. megadetector/data_management/importers/carrizo_shrubfree_2018.py +269 -0
  79. megadetector/data_management/importers/carrizo_trail_cam_2017.py +289 -0
  80. megadetector/data_management/importers/cct_field_adjustments.py +58 -0
  81. megadetector/data_management/importers/channel_islands_to_cct.py +913 -0
  82. megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +180 -0
  83. megadetector/data_management/importers/eMammal/eMammal_helpers.py +249 -0
  84. megadetector/data_management/importers/eMammal/make_eMammal_json.py +223 -0
  85. megadetector/data_management/importers/ena24_to_json.py +276 -0
  86. megadetector/data_management/importers/filenames_to_json.py +386 -0
  87. megadetector/data_management/importers/helena_to_cct.py +283 -0
  88. megadetector/data_management/importers/idaho-camera-traps.py +1407 -0
  89. megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +294 -0
  90. megadetector/data_management/importers/jb_csv_to_json.py +150 -0
  91. megadetector/data_management/importers/mcgill_to_json.py +250 -0
  92. megadetector/data_management/importers/missouri_to_json.py +490 -0
  93. megadetector/data_management/importers/nacti_fieldname_adjustments.py +79 -0
  94. megadetector/data_management/importers/noaa_seals_2019.py +181 -0
  95. megadetector/data_management/importers/pc_to_json.py +365 -0
  96. megadetector/data_management/importers/plot_wni_giraffes.py +123 -0
  97. megadetector/data_management/importers/prepare-noaa-fish-data-for-lila.py +359 -0
  98. megadetector/data_management/importers/prepare_zsl_imerit.py +131 -0
  99. megadetector/data_management/importers/rspb_to_json.py +356 -0
  100. megadetector/data_management/importers/save_the_elephants_survey_A.py +320 -0
  101. megadetector/data_management/importers/save_the_elephants_survey_B.py +329 -0
  102. megadetector/data_management/importers/snapshot_safari_importer.py +758 -0
  103. megadetector/data_management/importers/snapshot_safari_importer_reprise.py +665 -0
  104. megadetector/data_management/importers/snapshot_serengeti_lila.py +1067 -0
  105. megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +150 -0
  106. megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +153 -0
  107. megadetector/data_management/importers/sulross_get_exif.py +65 -0
  108. megadetector/data_management/importers/timelapse_csv_set_to_json.py +490 -0
  109. megadetector/data_management/importers/ubc_to_json.py +399 -0
  110. megadetector/data_management/importers/umn_to_json.py +507 -0
  111. megadetector/data_management/importers/wellington_to_json.py +263 -0
  112. megadetector/data_management/importers/wi_to_json.py +442 -0
  113. megadetector/data_management/importers/zamba_results_to_md_results.py +181 -0
  114. megadetector/data_management/labelme_to_coco.py +547 -0
  115. megadetector/data_management/labelme_to_yolo.py +272 -0
  116. megadetector/data_management/lila/__init__.py +0 -0
  117. megadetector/data_management/lila/add_locations_to_island_camera_traps.py +97 -0
  118. megadetector/data_management/lila/add_locations_to_nacti.py +147 -0
  119. megadetector/data_management/lila/create_lila_blank_set.py +558 -0
  120. megadetector/data_management/lila/create_lila_test_set.py +152 -0
  121. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  122. megadetector/data_management/lila/download_lila_subset.py +178 -0
  123. megadetector/data_management/lila/generate_lila_per_image_labels.py +516 -0
  124. megadetector/data_management/lila/get_lila_annotation_counts.py +170 -0
  125. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  126. megadetector/data_management/lila/lila_common.py +300 -0
  127. megadetector/data_management/lila/test_lila_metadata_urls.py +132 -0
  128. megadetector/data_management/ocr_tools.py +870 -0
  129. megadetector/data_management/read_exif.py +809 -0
  130. megadetector/data_management/remap_coco_categories.py +84 -0
  131. megadetector/data_management/remove_exif.py +66 -0
  132. megadetector/data_management/rename_images.py +187 -0
  133. megadetector/data_management/resize_coco_dataset.py +189 -0
  134. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  135. megadetector/data_management/yolo_output_to_md_output.py +446 -0
  136. megadetector/data_management/yolo_to_coco.py +676 -0
  137. megadetector/detection/__init__.py +0 -0
  138. megadetector/detection/detector_training/__init__.py +0 -0
  139. megadetector/detection/detector_training/model_main_tf2.py +114 -0
  140. megadetector/detection/process_video.py +846 -0
  141. megadetector/detection/pytorch_detector.py +355 -0
  142. megadetector/detection/run_detector.py +779 -0
  143. megadetector/detection/run_detector_batch.py +1219 -0
  144. megadetector/detection/run_inference_with_yolov5_val.py +1087 -0
  145. megadetector/detection/run_tiled_inference.py +934 -0
  146. megadetector/detection/tf_detector.py +192 -0
  147. megadetector/detection/video_utils.py +698 -0
  148. megadetector/postprocessing/__init__.py +0 -0
  149. megadetector/postprocessing/add_max_conf.py +64 -0
  150. megadetector/postprocessing/categorize_detections_by_size.py +165 -0
  151. megadetector/postprocessing/classification_postprocessing.py +716 -0
  152. megadetector/postprocessing/combine_api_outputs.py +249 -0
  153. megadetector/postprocessing/compare_batch_results.py +966 -0
  154. megadetector/postprocessing/convert_output_format.py +396 -0
  155. megadetector/postprocessing/load_api_results.py +195 -0
  156. megadetector/postprocessing/md_to_coco.py +310 -0
  157. megadetector/postprocessing/md_to_labelme.py +330 -0
  158. megadetector/postprocessing/merge_detections.py +412 -0
  159. megadetector/postprocessing/postprocess_batch_results.py +1908 -0
  160. megadetector/postprocessing/remap_detection_categories.py +170 -0
  161. megadetector/postprocessing/render_detection_confusion_matrix.py +660 -0
  162. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +211 -0
  163. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +83 -0
  164. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1635 -0
  165. megadetector/postprocessing/separate_detections_into_folders.py +730 -0
  166. megadetector/postprocessing/subset_json_detector_output.py +700 -0
  167. megadetector/postprocessing/top_folders_to_bottom.py +223 -0
  168. megadetector/taxonomy_mapping/__init__.py +0 -0
  169. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  170. megadetector/taxonomy_mapping/map_new_lila_datasets.py +150 -0
  171. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +142 -0
  172. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +588 -0
  173. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  174. megadetector/taxonomy_mapping/simple_image_download.py +219 -0
  175. megadetector/taxonomy_mapping/species_lookup.py +834 -0
  176. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  177. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  178. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  179. megadetector/utils/__init__.py +0 -0
  180. megadetector/utils/azure_utils.py +178 -0
  181. megadetector/utils/ct_utils.py +613 -0
  182. megadetector/utils/directory_listing.py +246 -0
  183. megadetector/utils/md_tests.py +1164 -0
  184. megadetector/utils/path_utils.py +1045 -0
  185. megadetector/utils/process_utils.py +160 -0
  186. megadetector/utils/sas_blob_utils.py +509 -0
  187. megadetector/utils/split_locations_into_train_val.py +228 -0
  188. megadetector/utils/string_utils.py +92 -0
  189. megadetector/utils/url_utils.py +323 -0
  190. megadetector/utils/write_html_image_list.py +225 -0
  191. megadetector/visualization/__init__.py +0 -0
  192. megadetector/visualization/plot_utils.py +293 -0
  193. megadetector/visualization/render_images_with_thumbnails.py +275 -0
  194. megadetector/visualization/visualization_utils.py +1536 -0
  195. megadetector/visualization/visualize_db.py +552 -0
  196. megadetector/visualization/visualize_detector_output.py +405 -0
  197. {megadetector-5.0.11.dist-info → megadetector-5.0.13.dist-info}/LICENSE +0 -0
  198. {megadetector-5.0.11.dist-info → megadetector-5.0.13.dist-info}/METADATA +2 -2
  199. megadetector-5.0.13.dist-info/RECORD +201 -0
  200. megadetector-5.0.13.dist-info/top_level.txt +1 -0
  201. megadetector-5.0.11.dist-info/RECORD +0 -5
  202. megadetector-5.0.11.dist-info/top_level.txt +0 -1
  203. {megadetector-5.0.11.dist-info → megadetector-5.0.13.dist-info}/WHEEL +0 -0
@@ -0,0 +1,225 @@
1
+ """
2
+
3
+ write_html_image_list.py
4
+
5
+ Given a list of image file names, writes an HTML file that
6
+ shows all those images, with optional one-line headers above each.
7
+
8
+ Each "filename" can also be a dict with elements 'filename','title',
9
+ 'imageStyle','textStyle', 'linkTarget'
10
+
11
+ """
12
+
13
+ #%% Constants and imports
14
+
15
+ import os
16
+ import math
17
+ import urllib
18
+
19
+ from megadetector.utils import path_utils
20
+
21
+
22
+ #%% write_html_image_list
23
+
24
def write_html_image_list(filename=None,images=None,options=None):
    """
    Given a list of image file names, writes an HTML file that shows all those images,
    with optional one-line headers above each.

    Args:
        filename (str, optional): the .html output file; if None, just returns a valid
            options dict
        images (list, optional): the images to write to the .html file; if None, just returns
            a valid options dict.  This can be a flat list of image filenames, or this can
            be a list of dictionaries with one or more of the following fields:

            - filename (image filename) (required, all other fields are optional)
            - imageStyle (css style for this image)
            - textStyle (css style for the title associated with this image)
            - title (text label for this image)
            - linkTarget (URL to which this image should link on click)

            The list and the dicts in it are not modified; defaults are filled in on
            internal copies.

        options (dict, optional): a dict with one or more of the following fields:

            - fHtml (file pointer to write to, used for splitting write operations over
              multiple calls)
            - headerHtml (html text to include before the image list)
            - trailerHtml (html text to include after the image list)
            - defaultImageStyle (default css style for images)
            - defaultTextStyle (default css style for image titles)
            - maxFiguresPerHtmlFile (max figures for a single HTML file; overflow will be
              handled by creating multiple files and a TOC with links)
            - urlEncodeFilenames (default True, e.g. '#' will be replaced by '%23')
            - urlEncodeLinkTargets (default True, e.g. '#' will be replaced by '%23')

            Missing (or None) fields are filled in with defaults, in place.

    Returns:
        dict: the options dict, with defaults filled in

    Raises:
        ValueError: if the image list has to be split across multiple files and the
            caller supplied their own file handle via options['fHtml']
    """

    # A bare "import urllib" does not guarantee that the urllib.parse submodule has
    # been loaded, so import it explicitly where it is used.
    import urllib.parse

    if options is None:
        options = {}

    # -1 is the sentinel for "no file handle supplied, open [filename] ourselves"
    option_defaults = {
        'fHtml': -1,
        'headerHtml': '',
        'trailerHtml': '',
        'defaultTextStyle':
            "font-family:calibri,verdana,arial;font-weight:bold;font-size:150%;text-align:left;margin:0px;",
        'defaultImageStyle':
            "margin:0px;margin-top:5px;margin-bottom:5px;",
        'urlEncodeFilenames': True,
        'urlEncodeLinkTargets': True,
        # Possibly split the html output for figures into multiple files; Chrome gets
        # sad with thousands of images in a single tab.
        'maxFiguresPerHtmlFile': math.inf
    }
    for option_name, default_value in option_defaults.items():
        if options.get(option_name) is None:
            options[option_name] = default_value

    if filename is None or images is None:
        return options

    # images may be a list of filenames or a list of image/style/title dictionaries;
    # normalize to the latter (on copies, so the caller's data is left alone) to
    # simplify downstream code.
    normalized_images = []
    for image_info in images:
        if isinstance(image_info,str):
            image_info = {'filename':image_info}
        else:
            image_info = dict(image_info)
        image_info.setdefault('filename','')
        image_info.setdefault('imageStyle',options['defaultImageStyle'])
        image_info.setdefault('title','')
        image_info.setdefault('linkTarget','')
        image_info.setdefault('textStyle',options['defaultTextStyle'])
        normalized_images.append(image_info)
    images = normalized_images

    n_images = len(images)
    max_figures = options['maxFiguresPerHtmlFile']

    # If we need to break this up into multiple files...
    if n_images > max_figures:

        # You can't supply your own file handle in this case
        if options['fHtml'] != -1:
            raise ValueError(
                "You can't supply your own file handle if we have to page the image set")

        starting_indices = list(range(0,n_images,max_figures))

        assert len(starting_indices) > 1

        # Write the table-of-contents file; "with" guarantees it is closed even if
        # writing one of the sub-pages fails.
        with open(filename,'w') as f_meta:

            f_meta.write('<html><body>\n')
            f_meta.write(options['headerHtml'])
            f_meta.write('<table border = 0 cellpadding = 2>\n')

            for i_start in starting_indices:

                # Inclusive index of the last image on this page
                i_end = min(i_start + max_figures, n_images) - 1

                trailer = 'image_{:05d}_{:05d}'.format(i_start,i_end)
                page_filename = path_utils.insert_before_extension(filename,trailer)
                f_meta.write('<tr><td>\n')
                f_meta.write('<p style="padding-bottom:0px;margin-bottom:0px;text-align:left;font-family:\'segoe ui\',calibri,arial;font-size:100%;text-decoration:none;font-weight:bold;">')
                f_meta.write('<a href="{}">Figures for images {} through {}</a></p></td></tr>\n'.format(
                    os.path.basename(page_filename),i_start,i_end))

                # The header and trailer belong to the TOC page only
                page_options = options.copy()
                page_options['headerHtml'] = ''
                page_options['trailerHtml'] = ''

                # Make a recursive call for this image set
                write_html_image_list(page_filename,images[i_start:i_end+1],page_options)

            # ...for each page of images

            f_meta.write('</table></body>\n')
            f_meta.write(options['trailerHtml'])
            f_meta.write('</html>\n')

        return options

    # ...if we have to make multiple sub-pages

    # Only close the file if we opened it; a caller-supplied handle stays open
    owns_file = (options['fHtml'] == -1)
    f_html = open(filename,'w') if owns_file else options['fHtml']

    try:

        f_html.write('<html><body>\n')
        f_html.write(options['headerHtml'])

        # Write out images
        for i_image,image in enumerate(images):

            # Remove non-ASCII characters
            title = image['title'].encode('ascii','ignore').decode('ascii')
            image_filename = image['filename'].encode('ascii','ignore').decode('ascii')

            image_filename = image_filename.replace('\\','/')
            if options['urlEncodeFilenames']:
                image_filename = urllib.parse.quote(image_filename)

            if len(title) > 0:
                f_html.write('<p style="{}">{}</p>\n'.format(image['textStyle'],title))

            link_target = image['linkTarget'].replace('\\','/')
            if options['urlEncodeLinkTargets']:
                # These are typically absolute paths, so we only want to mess with
                # certain characters
                link_target = urllib.parse.quote(link_target,safe=':/')

            if len(link_target) > 0:
                f_html.write('<a href="{}">'.format(link_target))

            f_html.write('<img src="{}" style="{}">\n'.format(image_filename,image['imageStyle']))

            if len(link_target) > 0:
                f_html.write('</a>')

            if i_image != len(images)-1:
                f_html.write('<br/>')

        # ...for each image we need to write

        f_html.write(options['trailerHtml'])
        f_html.write('</body></html>\n')

    finally:

        if owns_file:
            f_html.close()

    return options

# ...function
File without changes
@@ -0,0 +1,293 @@
1
+ """
2
+
3
+ plot_utils.py
4
+
5
+ Utility functions for plotting, particularly for plotting confusion matrices
6
+ and precision-recall curves.
7
+
8
+ """
9
+
10
+ #%% Imports
11
+
12
+ import numpy as np
13
+
14
+ # This also imports mpl.{cm, axes, colors}
15
+ import matplotlib.figure
16
+
17
+
18
+ #%% Plotting functions
19
+
20
def plot_confusion_matrix(matrix,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=matplotlib.cm.Blues,
                          vmax=None,
                          use_colorbar=True,
                          y_label=True,
                          fmt='{:.0f}',
                          fig=None):
    """
    Renders a confusion matrix as a colored grid with a numeric label in every cell.

    Cell colors always come from the row-normalized matrix; the numbers printed in the
    cells are the row-normalized values if [normalize] is True, otherwise the values
    in [matrix] as supplied.

    Args:
        matrix (np.ndarray): shape [num_classes, num_classes], confusion matrix
            where rows are ground-truth classes and columns are predicted classes
        classes (list of str): class names for each row/column
        normalize (bool, optional): whether to print row-normalized values in the cells
            rather than the values in [matrix]
        title (str, optional): figure title
        cmap (matplotlib.colors.colormap): colormap for cell backgrounds
        vmax (float, optional): value corresponding to the largest value of the colormap,
            passed through to imshow()
        use_colorbar (bool, optional): whether to show a colorbar
        y_label (bool, optional): whether to show class names and an axis label on the
            y axis
        fmt (str): format string for rendering numeric values
        fig (Figure): existing figure to which we should render, otherwise creates
            a new figure

    Returns:
        matplotlib.figure.Figure: the figure we rendered to or created
    """

    n_classes = matrix.shape[0]
    assert matrix.shape[1] == n_classes
    assert len(classes) == n_classes

    # The small epsilon avoids dividing by zero for rows with no samples
    row_totals = matrix.sum(axis=1, keepdims=True)
    row_normalized = matrix.astype(np.float64) / (row_totals + 1e-7)
    cell_values = row_normalized if normalize else matrix

    # Scale the figure with the number of classes, leaving room for the colorbar
    height = 3 + 0.3 * n_classes
    width = height + 0.5 if use_colorbar else height

    if fig is None:
        fig = matplotlib.figure.Figure(figsize=(width, height), tight_layout=True)

    ax = fig.subplots(1, 1)
    image = ax.imshow(row_normalized, interpolation='nearest', cmap=cmap, vmax=vmax)
    ax.set_title(title)

    if use_colorbar:
        colorbar = fig.colorbar(image, fraction=0.046, pad=0.04,
                                ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
        colorbar.set_ticklabels(['0%', '25%', '50%', '75%', '100%'])

    tick_positions = np.arange(n_classes)
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_xlabel('Predicted class')

    if y_label:
        ax.set_yticklabels(classes)
        ax.set_ylabel('Ground-truth class')

    # Label every cell, switching to white text on the darker (> 50%) cells
    for (row, col), value in np.ndenumerate(cell_values):
        text_color = 'black'
        if row_normalized[row, col] > 0.5:
            text_color = 'white'
        ax.text(col, row, fmt.format(value),
                horizontalalignment='center',
                verticalalignment='center',
                color=text_color)

    return fig

# ...def plot_confusion_matrix(...)
98
+
99
+
100
def plot_precision_recall_curve(precisions,
                                recalls,
                                title='Precision/recall curve',
                                xlim=(0.0,1.05),
                                ylim=(0.0,1.05)):
    """
    Plots a precision/recall curve given lists of (ordered) precision
    and recall values.

    The curve is drawn as a filled step plot on a newly created figure.

    Args:
        precisions (list of float): precision for corresponding recall values,
            should have same length as [recalls].
        recalls (list of float): recall for corresponding precision values,
            should have same length as [precisions].
        title (str, optional): plot title
        xlim (tuple, optional): x-axis limits as a length-2 tuple
        ylim (tuple, optional): y-axis limits as a length-2 tuple

    Returns:
        matplotlib.figure.Figure: the (new) figure
    """

    assert len(precisions) == len(recalls)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)
    ax.step(recalls, precisions, color='b', alpha=0.2, where='post')
    ax.fill_between(recalls, precisions, alpha=0.2, color='b', step='post')

    # Use the explicit setters.  Axes.set() has no "x_label"/"y_label"/"x_lim"/"y_lim"
    # properties (the property names are xlabel/ylabel/xlim/ylim), so calling it with
    # those keywords can only raise.
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title(title)
    ax.set_xlim(xlim[0],xlim[1])
    ax.set_ylim(ylim[0],ylim[1])

    return fig
141
+
142
+
143
def plot_stacked_bar_chart(data, series_labels=None, col_labels=None,
                           x_label=None, y_label=None, log_scale=False):
    """
    Plot a stacked bar chart, for plotting e.g. species distribution across locations.

    Reference: https://stackoverflow.com/q/44309507

    Args:
        data (np.ndarray or list of list): data to plot; rows (series) are species, columns
            are locations
        series_labels (list of str, optional): series labels, typically species names;
            if None, the bars are unlabeled and no legend is drawn
        col_labels (list of str, optional): column labels, typically location names; with
            25 or more columns, only every 20th column gets a tick and a label
        x_label (str, optional): x-axis label
        y_label (str, optional): y-axis label
        log_scale (bool, optional): whether to plot the y axis in log-scale

    Returns:
        matplotlib.figure.Figure: the (new) figure
    """

    data = np.asarray(data)
    num_series, num_columns = data.shape
    ind = np.arange(num_columns)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)
    colors = matplotlib.cm.rainbow(np.linspace(0, 1, num_series))

    # stacked bar charts are made with each segment starting from a y position
    cumulative_size = np.zeros(num_columns)
    for i, row_data in enumerate(data):
        # series_labels is optional, so only index into it when it was supplied
        label = series_labels[i] if series_labels is not None else None
        ax.bar(ind, row_data, bottom=cumulative_size, label=label,
               color=colors[i])
        cumulative_size += row_data

    # Explicit None/length checks, so array-like labels don't hit an ambiguous
    # truth-value test
    if col_labels is not None and len(col_labels) > 0:
        if len(col_labels) < 25:
            ax.set_xticks(ind)
            ax.set_xticklabels(col_labels, rotation=90)
        else:
            # Too many columns to label individually; label every 20th column, and
            # pass only the labels that belong to those ticks so the tick and label
            # counts match.
            tick_positions = list(range(0, len(col_labels), 20))
            ax.set_xticks(tick_positions)
            ax.set_xticklabels([col_labels[i] for i in tick_positions], rotation=90)

    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if log_scale:
        ax.set_yscale('log')

    # A legend only makes sense when the series were labeled
    if series_labels is not None:

        # To fit the legend in, shrink current axis by 20%
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

        # Put a legend to the right of the current axis
        ax.legend(loc='center left', bbox_to_anchor=(0.99, 0.5), frameon=False)

    return fig
200
+
201
+
202
def calibration_ece(true_scores, pred_scores, num_bins):
    r"""
    Computes the expected calibration error (ECE), as defined in equation (3) of
    Guo et al. "On Calibration of Modern Neural Networks." (2017).

    Adapted from sklearn.calibration.calibration_curve(), extended to also compute
    ECE.  See:

    https://github.com/scikit-learn/scikit-learn/issues/18268

    Args:
        true_scores (list of int): true values, length N, binary-valued (0 = neg, 1 = pos)
        pred_scores (list of float): predicted confidence values, length N, pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of bins to use (`M` in eq. (3) of Guo 2017)

    Returns:
        tuple: a length-three tuple containing:
        - accs: np.ndarray, shape [M], type float64, accuracy in each bin,
          M <= num_bins because bins with no samples are not returned
        - confs: np.ndarray, shape [M], type float64, mean model confidence in
          each bin
        - ece: float, expected calibration error
    """

    n_examples = len(true_scores)
    assert n_examples == len(pred_scores)

    # Equal-width bin edges; the upper edge is nudged just past 1.0 so a score of
    # exactly 1.0 lands in the last bin.
    edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
    n_slots = len(edges)

    # Zero-based bin index for every example
    bin_of_example = np.digitize(pred_scores, edges) - 1

    # Per-bin example count, summed confidence, and summed true label
    count_per_bin = np.bincount(bin_of_example, minlength=n_slots)
    conf_sum_per_bin = np.bincount(bin_of_example, weights=pred_scores, minlength=n_slots)
    true_sum_per_bin = np.bincount(bin_of_example, weights=true_scores, minlength=n_slots)

    # Drop empty bins before dividing
    occupied = count_per_bin != 0
    occupied_counts = count_per_bin[occupied]
    accs = true_sum_per_bin[occupied] / occupied_counts
    confs = conf_sum_per_bin[occupied] / occupied_counts

    # ECE is the |accuracy - confidence| gap per bin, weighted by the fraction of
    # examples that fall in that bin
    bin_weights = occupied_counts / n_examples
    ece = np.abs(accs - confs) @ bin_weights

    return accs, confs, ece
243
+
244
+
245
def plot_calibration_curve(true_scores, pred_scores, num_bins,
                           name='calibration', plot_perf=True, plot_hist=True,
                           ax=None, **fig_kwargs):
    """
    Plots a calibration curve (per-bin accuracy vs. mean confidence), with the
    expected calibration error shown in the plot title.

    Args:
        true_scores (list of int): true values, length N, binary-valued (0 = neg, 1 = pos)
        pred_scores (list of float): predicted confidence values, length N, pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of bins to use (`M` in eq. (3) of Guo 2017)
        name (str, optional): label in legend for the calibration curve
        plot_perf (bool, optional): whether to plot y=x line indicating perfect calibration
        plot_hist (bool, optional): whether to plot histogram of counts, on a secondary
            y axis
        ax (Axes, optional): axes to draw on; if given then no legend is drawn, and
            fig_kwargs are ignored
        fig_kwargs (dict, optional): passed to the Figure constructor, only used if
            [ax] is None

    Returns:
        matplotlib.figure.Figure: the figure that owns the axes we drew on (a new
        figure if [ax] was None)
    """

    accs, confs, ece = calibration_ece(true_scores, pred_scores, num_bins)

    # Remember whether the caller gave us axes; we only add a legend to figures we own
    new_fig = None
    if ax is None:
        new_fig = matplotlib.figure.Figure(**fig_kwargs)
        ax = new_fig.subplots(1, 1)

    # 's-': squares on line
    ax.plot(confs, accs, 's-', label=name)
    ax.set_xlabel('Model confidence')
    ax.set_ylabel('Actual accuracy')
    ax.set_title(f'Calibration plot (ECE: {ece:.02g})')
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])

    if plot_perf:
        ax.plot([0, 1], [0, 1], color='black', label='perfect calibration')

    ax.grid(True)

    if plot_hist:
        # Histogram of confidences on a secondary y axis, using the same bin edges
        # as the ECE computation
        hist_ax = ax.twinx()
        bin_edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
        counts, _, _ = hist_ax.hist(pred_scores, alpha=0.5,
                                    label='histogram of examples',
                                    bins=bin_edges, color='tab:red')
        tallest = np.max(counts)
        hist_ax.set_ylim([-0.05 * tallest, 1.05 * tallest])
        hist_ax.set_ylabel('Count')

    if new_fig is not None:
        new_fig.legend(loc='upper left', bbox_to_anchor=(0.15, 0.85))

    return ax.figure