megadetector 5.0.10__py3-none-any.whl → 5.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (226)
  1. {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/LICENSE +0 -0
  2. {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/METADATA +12 -11
  3. megadetector-5.0.11.dist-info/RECORD +5 -0
  4. megadetector-5.0.11.dist-info/top_level.txt +1 -0
  5. api/__init__.py +0 -0
  6. api/batch_processing/__init__.py +0 -0
  7. api/batch_processing/api_core/__init__.py +0 -0
  8. api/batch_processing/api_core/batch_service/__init__.py +0 -0
  9. api/batch_processing/api_core/batch_service/score.py +0 -439
  10. api/batch_processing/api_core/server.py +0 -294
  11. api/batch_processing/api_core/server_api_config.py +0 -98
  12. api/batch_processing/api_core/server_app_config.py +0 -55
  13. api/batch_processing/api_core/server_batch_job_manager.py +0 -220
  14. api/batch_processing/api_core/server_job_status_table.py +0 -152
  15. api/batch_processing/api_core/server_orchestration.py +0 -360
  16. api/batch_processing/api_core/server_utils.py +0 -92
  17. api/batch_processing/api_core_support/__init__.py +0 -0
  18. api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
  19. api/batch_processing/api_support/__init__.py +0 -0
  20. api/batch_processing/api_support/summarize_daily_activity.py +0 -152
  21. api/batch_processing/data_preparation/__init__.py +0 -0
  22. api/batch_processing/data_preparation/manage_local_batch.py +0 -2391
  23. api/batch_processing/data_preparation/manage_video_batch.py +0 -327
  24. api/batch_processing/integration/digiKam/setup.py +0 -6
  25. api/batch_processing/integration/digiKam/xmp_integration.py +0 -465
  26. api/batch_processing/integration/eMammal/test_scripts/config_template.py +0 -5
  27. api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +0 -126
  28. api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +0 -55
  29. api/batch_processing/postprocessing/__init__.py +0 -0
  30. api/batch_processing/postprocessing/add_max_conf.py +0 -64
  31. api/batch_processing/postprocessing/categorize_detections_by_size.py +0 -163
  32. api/batch_processing/postprocessing/combine_api_outputs.py +0 -249
  33. api/batch_processing/postprocessing/compare_batch_results.py +0 -958
  34. api/batch_processing/postprocessing/convert_output_format.py +0 -397
  35. api/batch_processing/postprocessing/load_api_results.py +0 -195
  36. api/batch_processing/postprocessing/md_to_coco.py +0 -310
  37. api/batch_processing/postprocessing/md_to_labelme.py +0 -330
  38. api/batch_processing/postprocessing/merge_detections.py +0 -401
  39. api/batch_processing/postprocessing/postprocess_batch_results.py +0 -1904
  40. api/batch_processing/postprocessing/remap_detection_categories.py +0 -170
  41. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +0 -661
  42. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +0 -211
  43. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +0 -82
  44. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +0 -1631
  45. api/batch_processing/postprocessing/separate_detections_into_folders.py +0 -731
  46. api/batch_processing/postprocessing/subset_json_detector_output.py +0 -696
  47. api/batch_processing/postprocessing/top_folders_to_bottom.py +0 -223
  48. api/synchronous/__init__.py +0 -0
  49. api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  50. api/synchronous/api_core/animal_detection_api/api_backend.py +0 -152
  51. api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -266
  52. api/synchronous/api_core/animal_detection_api/config.py +0 -35
  53. api/synchronous/api_core/animal_detection_api/data_management/annotations/annotation_constants.py +0 -47
  54. api/synchronous/api_core/animal_detection_api/detection/detector_training/copy_checkpoints.py +0 -43
  55. api/synchronous/api_core/animal_detection_api/detection/detector_training/model_main_tf2.py +0 -114
  56. api/synchronous/api_core/animal_detection_api/detection/process_video.py +0 -543
  57. api/synchronous/api_core/animal_detection_api/detection/pytorch_detector.py +0 -304
  58. api/synchronous/api_core/animal_detection_api/detection/run_detector.py +0 -627
  59. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +0 -1029
  60. api/synchronous/api_core/animal_detection_api/detection/run_inference_with_yolov5_val.py +0 -581
  61. api/synchronous/api_core/animal_detection_api/detection/run_tiled_inference.py +0 -754
  62. api/synchronous/api_core/animal_detection_api/detection/tf_detector.py +0 -165
  63. api/synchronous/api_core/animal_detection_api/detection/video_utils.py +0 -495
  64. api/synchronous/api_core/animal_detection_api/md_utils/azure_utils.py +0 -174
  65. api/synchronous/api_core/animal_detection_api/md_utils/ct_utils.py +0 -262
  66. api/synchronous/api_core/animal_detection_api/md_utils/directory_listing.py +0 -251
  67. api/synchronous/api_core/animal_detection_api/md_utils/matlab_porting_tools.py +0 -97
  68. api/synchronous/api_core/animal_detection_api/md_utils/path_utils.py +0 -416
  69. api/synchronous/api_core/animal_detection_api/md_utils/process_utils.py +0 -110
  70. api/synchronous/api_core/animal_detection_api/md_utils/sas_blob_utils.py +0 -509
  71. api/synchronous/api_core/animal_detection_api/md_utils/string_utils.py +0 -59
  72. api/synchronous/api_core/animal_detection_api/md_utils/url_utils.py +0 -144
  73. api/synchronous/api_core/animal_detection_api/md_utils/write_html_image_list.py +0 -226
  74. api/synchronous/api_core/animal_detection_api/md_visualization/visualization_utils.py +0 -841
  75. api/synchronous/api_core/tests/__init__.py +0 -0
  76. api/synchronous/api_core/tests/load_test.py +0 -110
  77. classification/__init__.py +0 -0
  78. classification/aggregate_classifier_probs.py +0 -108
  79. classification/analyze_failed_images.py +0 -227
  80. classification/cache_batchapi_outputs.py +0 -198
  81. classification/create_classification_dataset.py +0 -627
  82. classification/crop_detections.py +0 -516
  83. classification/csv_to_json.py +0 -226
  84. classification/detect_and_crop.py +0 -855
  85. classification/efficientnet/__init__.py +0 -9
  86. classification/efficientnet/model.py +0 -415
  87. classification/efficientnet/utils.py +0 -610
  88. classification/evaluate_model.py +0 -520
  89. classification/identify_mislabeled_candidates.py +0 -152
  90. classification/json_to_azcopy_list.py +0 -63
  91. classification/json_validator.py +0 -695
  92. classification/map_classification_categories.py +0 -276
  93. classification/merge_classification_detection_output.py +0 -506
  94. classification/prepare_classification_script.py +0 -194
  95. classification/prepare_classification_script_mc.py +0 -228
  96. classification/run_classifier.py +0 -286
  97. classification/save_mislabeled.py +0 -110
  98. classification/train_classifier.py +0 -825
  99. classification/train_classifier_tf.py +0 -724
  100. classification/train_utils.py +0 -322
  101. data_management/__init__.py +0 -0
  102. data_management/annotations/__init__.py +0 -0
  103. data_management/annotations/annotation_constants.py +0 -34
  104. data_management/camtrap_dp_to_coco.py +0 -238
  105. data_management/cct_json_utils.py +0 -395
  106. data_management/cct_to_md.py +0 -176
  107. data_management/cct_to_wi.py +0 -289
  108. data_management/coco_to_labelme.py +0 -272
  109. data_management/coco_to_yolo.py +0 -662
  110. data_management/databases/__init__.py +0 -0
  111. data_management/databases/add_width_and_height_to_db.py +0 -33
  112. data_management/databases/combine_coco_camera_traps_files.py +0 -206
  113. data_management/databases/integrity_check_json_db.py +0 -477
  114. data_management/databases/subset_json_db.py +0 -115
  115. data_management/generate_crops_from_cct.py +0 -149
  116. data_management/get_image_sizes.py +0 -188
  117. data_management/importers/add_nacti_sizes.py +0 -52
  118. data_management/importers/add_timestamps_to_icct.py +0 -79
  119. data_management/importers/animl_results_to_md_results.py +0 -158
  120. data_management/importers/auckland_doc_test_to_json.py +0 -372
  121. data_management/importers/auckland_doc_to_json.py +0 -200
  122. data_management/importers/awc_to_json.py +0 -189
  123. data_management/importers/bellevue_to_json.py +0 -273
  124. data_management/importers/cacophony-thermal-importer.py +0 -796
  125. data_management/importers/carrizo_shrubfree_2018.py +0 -268
  126. data_management/importers/carrizo_trail_cam_2017.py +0 -287
  127. data_management/importers/cct_field_adjustments.py +0 -57
  128. data_management/importers/channel_islands_to_cct.py +0 -913
  129. data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
  130. data_management/importers/eMammal/eMammal_helpers.py +0 -249
  131. data_management/importers/eMammal/make_eMammal_json.py +0 -223
  132. data_management/importers/ena24_to_json.py +0 -275
  133. data_management/importers/filenames_to_json.py +0 -385
  134. data_management/importers/helena_to_cct.py +0 -282
  135. data_management/importers/idaho-camera-traps.py +0 -1407
  136. data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
  137. data_management/importers/jb_csv_to_json.py +0 -150
  138. data_management/importers/mcgill_to_json.py +0 -250
  139. data_management/importers/missouri_to_json.py +0 -489
  140. data_management/importers/nacti_fieldname_adjustments.py +0 -79
  141. data_management/importers/noaa_seals_2019.py +0 -181
  142. data_management/importers/pc_to_json.py +0 -365
  143. data_management/importers/plot_wni_giraffes.py +0 -123
  144. data_management/importers/prepare-noaa-fish-data-for-lila.py +0 -359
  145. data_management/importers/prepare_zsl_imerit.py +0 -131
  146. data_management/importers/rspb_to_json.py +0 -356
  147. data_management/importers/save_the_elephants_survey_A.py +0 -320
  148. data_management/importers/save_the_elephants_survey_B.py +0 -332
  149. data_management/importers/snapshot_safari_importer.py +0 -758
  150. data_management/importers/snapshot_safari_importer_reprise.py +0 -665
  151. data_management/importers/snapshot_serengeti_lila.py +0 -1067
  152. data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
  153. data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
  154. data_management/importers/sulross_get_exif.py +0 -65
  155. data_management/importers/timelapse_csv_set_to_json.py +0 -490
  156. data_management/importers/ubc_to_json.py +0 -399
  157. data_management/importers/umn_to_json.py +0 -507
  158. data_management/importers/wellington_to_json.py +0 -263
  159. data_management/importers/wi_to_json.py +0 -441
  160. data_management/importers/zamba_results_to_md_results.py +0 -181
  161. data_management/labelme_to_coco.py +0 -548
  162. data_management/labelme_to_yolo.py +0 -272
  163. data_management/lila/__init__.py +0 -0
  164. data_management/lila/add_locations_to_island_camera_traps.py +0 -97
  165. data_management/lila/add_locations_to_nacti.py +0 -147
  166. data_management/lila/create_lila_blank_set.py +0 -557
  167. data_management/lila/create_lila_test_set.py +0 -151
  168. data_management/lila/create_links_to_md_results_files.py +0 -106
  169. data_management/lila/download_lila_subset.py +0 -177
  170. data_management/lila/generate_lila_per_image_labels.py +0 -515
  171. data_management/lila/get_lila_annotation_counts.py +0 -170
  172. data_management/lila/get_lila_image_counts.py +0 -111
  173. data_management/lila/lila_common.py +0 -300
  174. data_management/lila/test_lila_metadata_urls.py +0 -132
  175. data_management/ocr_tools.py +0 -874
  176. data_management/read_exif.py +0 -681
  177. data_management/remap_coco_categories.py +0 -84
  178. data_management/remove_exif.py +0 -66
  179. data_management/resize_coco_dataset.py +0 -189
  180. data_management/wi_download_csv_to_coco.py +0 -246
  181. data_management/yolo_output_to_md_output.py +0 -441
  182. data_management/yolo_to_coco.py +0 -676
  183. detection/__init__.py +0 -0
  184. detection/detector_training/__init__.py +0 -0
  185. detection/detector_training/model_main_tf2.py +0 -114
  186. detection/process_video.py +0 -703
  187. detection/pytorch_detector.py +0 -337
  188. detection/run_detector.py +0 -779
  189. detection/run_detector_batch.py +0 -1219
  190. detection/run_inference_with_yolov5_val.py +0 -917
  191. detection/run_tiled_inference.py +0 -935
  192. detection/tf_detector.py +0 -188
  193. detection/video_utils.py +0 -606
  194. docs/source/conf.py +0 -43
  195. md_utils/__init__.py +0 -0
  196. md_utils/azure_utils.py +0 -174
  197. md_utils/ct_utils.py +0 -612
  198. md_utils/directory_listing.py +0 -246
  199. md_utils/md_tests.py +0 -968
  200. md_utils/path_utils.py +0 -1044
  201. md_utils/process_utils.py +0 -157
  202. md_utils/sas_blob_utils.py +0 -509
  203. md_utils/split_locations_into_train_val.py +0 -228
  204. md_utils/string_utils.py +0 -92
  205. md_utils/url_utils.py +0 -323
  206. md_utils/write_html_image_list.py +0 -225
  207. md_visualization/__init__.py +0 -0
  208. md_visualization/plot_utils.py +0 -293
  209. md_visualization/render_images_with_thumbnails.py +0 -275
  210. md_visualization/visualization_utils.py +0 -1537
  211. md_visualization/visualize_db.py +0 -551
  212. md_visualization/visualize_detector_output.py +0 -406
  213. megadetector-5.0.10.dist-info/RECORD +0 -224
  214. megadetector-5.0.10.dist-info/top_level.txt +0 -8
  215. taxonomy_mapping/__init__.py +0 -0
  216. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +0 -491
  217. taxonomy_mapping/map_new_lila_datasets.py +0 -154
  218. taxonomy_mapping/prepare_lila_taxonomy_release.py +0 -142
  219. taxonomy_mapping/preview_lila_taxonomy.py +0 -591
  220. taxonomy_mapping/retrieve_sample_image.py +0 -71
  221. taxonomy_mapping/simple_image_download.py +0 -218
  222. taxonomy_mapping/species_lookup.py +0 -834
  223. taxonomy_mapping/taxonomy_csv_checker.py +0 -159
  224. taxonomy_mapping/taxonomy_graph.py +0 -346
  225. taxonomy_mapping/validate_lila_category_mappings.py +0 -83
  226. {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/WHEEL +0 -0
@@ -1,225 +0,0 @@
1
- """
2
-
3
- write_html_image_list.py
4
-
5
- Given a list of image file names, writes an HTML file that
6
- shows all those images, with optional one-line headers above each.
7
-
8
- Each "filename" can also be a dict with elements 'filename','title',
9
- 'imageStyle','textStyle', 'linkTarget'
10
-
11
- """
12
-
13
- #%% Constants and imports
14
-
15
- import os
16
- import math
17
- import urllib
18
-
19
- from md_utils import path_utils
20
-
21
-
22
- #%% write_html_image_list
23
-
24
def write_html_image_list(filename=None,images=None,options=None):
    """
    Given a list of image file names, writes an HTML file that shows all those images,
    with optional one-line headers above each.

    Args:
        filename (str, optional): the .html output file; if None, just returns a valid
            options dict
        images (list, optional): the images to write to the .html file; if None, just returns
            a valid options dict. This can be a flat list of image filenames, or this can
            be a list of dictionaries with one or more of the following fields:

            - filename (image filename) (required, all other fields are optional)
            - imageStyle (css style for this image)
            - textStyle (css style for the title associated with this image)
            - title (text label for this image)
            - linkTarget (URL to which this image should link on click)

            The caller's list and dicts are not modified.

        options (dict, optional): a dict with one or more of the following fields:

            - fHtml (file pointer to write to, used for splitting write operations over multiple calls)
            - headerHtml (html text to include before the image list)
            - trailerHtml (html text to include after the image list)
            - defaultImageStyle (default css style for images)
            - defaultTextStyle (default css style for image titles)
            - maxFiguresPerHtmlFile (max figures for a single HTML file; overflow will be handled by creating
              multiple files and a TOC with links)
            - urlEncodeFilenames (default True, e.g. '#' will be replaced by '%23')
            - urlEncodeLinkTargets (default True, e.g. '#' will be replaced by '%23')

    Returns:
        dict: [options], with defaults filled in for every missing (or None) field. When the
        caller supplies [options], that same dict is updated in place.

    Raises:
        ValueError: if the images must be split across multiple pages but the caller
            supplied their own file handle via options['fHtml']
    """

    # urllib.parse is a submodule; a bare "import urllib" does not guarantee it is loaded
    import urllib.parse

    if options is None:
        options = {}

    # -1 is the sentinel for "no file handle supplied, open [filename] ourselves"
    if options.get('fHtml') is None:
        options['fHtml'] = -1

    defaults = {
        'headerHtml': '',
        'trailerHtml': '',
        'defaultTextStyle':
            "font-family:calibri,verdana,arial;font-weight:bold;font-size:150%;text-align:left;margin:0px;",
        'defaultImageStyle': "margin:0px;margin-top:5px;margin-bottom:5px;",
        'urlEncodeFilenames': True,
        'urlEncodeLinkTargets': True,
        # Possibly split the html output for figures into multiple files; Chrome gets sad with
        # thousands of images in a single tab.
        'maxFiguresPerHtmlFile': math.inf,
    }
    for key, default_value in defaults.items():
        if options.get(key) is None:
            options[key] = default_value

    if filename is None or images is None:
        return options

    # images may be a list of filenames or a list of image/style/title dictionaries;
    # normalize to complete dictionaries (copies, so the caller's data is untouched)
    normalized_images = []
    for image_info in images:
        if isinstance(image_info, str):
            image_info = {'filename': image_info}
        else:
            image_info = dict(image_info)
        image_info.setdefault('filename', '')
        image_info.setdefault('imageStyle', options['defaultImageStyle'])
        image_info.setdefault('title', '')
        image_info.setdefault('linkTarget', '')
        image_info.setdefault('textStyle', options['defaultTextStyle'])
        normalized_images.append(image_info)
    images = normalized_images

    n_images = len(images)
    max_figures = options['maxFiguresPerHtmlFile']

    # If we need to break this up into multiple files...
    if n_images > max_figures:

        # Each page is written to its own file, so a caller-supplied handle can't be used
        if options['fHtml'] != -1:
            raise ValueError(
                "You can't supply your own file handle if we have to page the image set")

        with open(filename,'w') as f_meta:

            # Write header stuff for the table-of-contents file
            f_meta.write('<html><body>\n')
            f_meta.write(options['headerHtml'])
            f_meta.write('<table border = 0 cellpadding = 2>\n')

            for i_start in range(0, n_images, max_figures):

                # Inclusive index of the last image on this page
                i_end = min(i_start + max_figures, n_images) - 1

                trailer = 'image_{:05d}_{:05d}'.format(i_start,i_end)
                local_html_filename = path_utils.insert_before_extension(filename,trailer)
                f_meta.write('<tr><td>\n')
                f_meta.write('<p style="padding-bottom:0px;margin-bottom:0px;text-align:left;'
                             "font-family:'segoe ui',calibri,arial;font-size:100%;"
                             'text-decoration:none;font-weight:bold;">')
                f_meta.write('<a href="{}">Figures for images {} through {}</a></p></td></tr>\n'.format(
                    os.path.basename(local_html_filename),i_start,i_end))

                # Sub-pages get no header/trailer; those belong to the TOC file
                local_options = options.copy()
                local_options['headerHtml'] = ''
                local_options['trailerHtml'] = ''

                # Make a recursive call for this image set; it fits on one page by construction
                write_html_image_list(local_html_filename,images[i_start:i_end+1],local_options)

            # ...for each page of images

            f_meta.write('</table></body>\n')
            f_meta.write(options['trailerHtml'])
            f_meta.write('</html>\n')

        return options

    # ...if we have to make multiple sub-pages

    # Only close the handle if we opened it ourselves
    close_when_done = options['fHtml'] == -1
    f_html = open(filename,'w') if close_when_done else options['fHtml']

    try:

        f_html.write('<html><body>\n')
        f_html.write(options['headerHtml'])

        # Write out images
        for i_image, image in enumerate(images):

            # Remove unicode characters
            title = (image['title'] or '').encode('ascii','ignore').decode('ascii')
            image_filename = (image['filename'] or '').encode('ascii','ignore').decode('ascii')

            image_filename = image_filename.replace('\\','/')
            if options['urlEncodeFilenames']:
                image_filename = urllib.parse.quote(image_filename)

            if len(title) > 0:
                f_html.write('<p style="{}">{}</p>\n'.format(image['textStyle'],title))

            link_target = (image['linkTarget'] or '').replace('\\','/')
            if options['urlEncodeLinkTargets']:
                # These are typically absolute paths, so we only want to mess with certain characters
                link_target = urllib.parse.quote(link_target,safe=':/')

            if len(link_target) > 0:
                f_html.write('<a href="{}">'.format(link_target))

            f_html.write('<img src="{}" style="{}">\n'.format(image_filename,image['imageStyle']))

            if len(link_target) > 0:
                f_html.write('</a>')

            if i_image != n_images-1:
                f_html.write('<br/>')

        # ...for each image we need to write

        f_html.write(options['trailerHtml'])
        f_html.write('</body></html>\n')

    finally:
        if close_when_done:
            f_html.close()

    return options

# ...function
File without changes
@@ -1,293 +0,0 @@
1
- """
2
-
3
- plot_utils.py
4
-
5
- Utility functions for plotting, particularly for plotting confusion matrices
6
- and precision-recall curves.
7
-
8
- """
9
-
10
- #%% Imports
11
-
12
- import numpy as np
13
-
14
- # This also imports mpl.{cm, axes, colors}
15
- import matplotlib.figure
16
-
17
-
18
- #%% Plotting functions
19
-
20
def plot_confusion_matrix(matrix,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=matplotlib.cm.Blues,
                          vmax=None,
                          use_colorbar=True,
                          y_label=True,
                          fmt= '{:.0f}',
                          fig=None):
    """
    Renders a square confusion matrix as a colored grid with a numeric label in each cell.

    Cell colors (and the black/white choice for each cell's text) always come from the
    row-normalized matrix; the numbers printed in the cells come from [matrix] itself,
    or from the row-normalized matrix when [normalize] is True.

    Args:
        matrix (np.ndarray): [num_classes, num_classes] array; rows are ground-truth
            classes, columns are predicted classes
        classes (list of str): one name per row/column
        normalize (bool, optional): print row-normalized values instead of raw values
        title (str, optional): axes title
        cmap (matplotlib.colors.Colormap, optional): colormap for the cell backgrounds
        vmax (float, optional): upper limit of the colormap, passed to imshow
        use_colorbar (bool, optional): draw a percentage colorbar beside the grid
        y_label (bool, optional): label the y axis with the class names
        fmt (str, optional): str.format pattern used for each cell's number
        fig (matplotlib.figure.Figure, optional): figure to draw into; a new one sized
            to the number of classes is created when None

    Returns:
        matplotlib.figure.Figure: the figure that was drawn into
    """

    n = matrix.shape[0]
    assert matrix.shape[1] == n
    assert len(classes) == n

    # The small epsilon avoids dividing by zero for rows with no samples
    row_fractions = matrix.astype(np.float64) / (matrix.sum(axis=1, keepdims=True) + 1e-7)
    displayed = row_fractions if normalize else matrix

    if fig is None:
        height = 3 + 0.3 * n
        width = height
        if use_colorbar:
            width += 0.5
        fig = matplotlib.figure.Figure(figsize=(width, height), tight_layout=True)

    ax = fig.subplots(1, 1)
    image = ax.imshow(row_fractions, interpolation='nearest', cmap=cmap, vmax=vmax)
    ax.set_title(title)

    if use_colorbar:
        colorbar = fig.colorbar(image, fraction=0.046, pad=0.04,
                                ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
        colorbar.set_ticklabels(['0%', '25%', '50%', '75%', '100%'])

    positions = np.arange(n)
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_xlabel('Predicted class')

    if y_label:
        ax.set_yticklabels(classes)
        ax.set_ylabel('Ground-truth class')

    # Dark cells get white text, light cells get black text
    for row, col in np.ndindex(displayed.shape):
        text_color = 'white' if row_fractions[row, col] > 0.5 else 'black'
        ax.text(col, row, fmt.format(displayed[row, col]),
                horizontalalignment='center',
                verticalalignment='center',
                color=text_color)

    return fig

# ...def plot_confusion_matrix(...)
98
-
99
-
100
def plot_precision_recall_curve(precisions,
                                recalls,
                                title='Precision/recall curve',
                                xlim=(0.0,1.05),
                                ylim=(0.0,1.05)):
    """
    Plots a precision/recall curve given lists of (ordered) precision
    and recall values.

    Args:
        precisions (list of float): precision for corresponding recall values,
            should have same length as [recalls].
        recalls (list of float): recall for corresponding precision values,
            should have same length as [precisions].
        title (str, optional): plot title
        xlim (tuple, optional): x-axis limits as a length-2 tuple
        ylim (tuple, optional): y-axis limits as a length-2 tuple

    Returns:
        matplotlib.figure.Figure: the (new) figure
    """

    assert len(precisions) == len(recalls)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)

    # Step plot plus a shaded area under the curve
    ax.step(recalls, precisions, color='b', alpha=0.2, where='post')
    ax.fill_between(recalls, precisions, alpha=0.2, color='b', step='post')

    # Use the explicit setters; Axes.set() has no 'x_label'/'x_lim'-style keywords,
    # so passing those always raised.
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title(title)
    ax.set_xlim(xlim[0],xlim[1])
    ax.set_ylim(ylim[0],ylim[1])

    return fig
141
-
142
-
143
def plot_stacked_bar_chart(data, series_labels=None, col_labels=None,
                           x_label=None, y_label=None, log_scale=False):
    """
    Plot a stacked bar chart, for plotting e.g. species distribution across locations.

    Reference: https://stackoverflow.com/q/44309507

    Args:
        data (np.ndarray or list of list): data to plot; rows (series) are species, columns
            are locations
        series_labels (list of str, optional): series labels, typically species names; when
            None, bars are unlabeled and no legend is drawn
        col_labels (list of str, optional): column labels, typically location names; with
            25 or more columns, only every 20th column is labeled
        x_label (str, optional): x-axis label
        y_label (str, optional): y-axis label
        log_scale (bool, optional) whether to plot the y axis in log-scale

    Returns:
        matplotlib.figure.Figure: the (new) figure
    """

    data = np.asarray(data)
    num_series, num_columns = data.shape
    ind = np.arange(num_columns)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)
    colors = matplotlib.cm.rainbow(np.linspace(0, 1, num_series))

    # stacked bar charts are made with each segment starting from a y position
    cumulative_size = np.zeros(num_columns)
    for i, row_data in enumerate(data):
        label = None if series_labels is None else series_labels[i]
        ax.bar(ind, row_data, bottom=cumulative_size, label=label,
               color=colors[i])
        cumulative_size += row_data

    # Explicit checks (rather than truthiness) so numpy arrays of labels also work
    if col_labels is not None and len(col_labels) > 0:
        if len(col_labels) < 25:
            ax.set_xticks(ind)
            ax.set_xticklabels(col_labels, rotation=90)
        else:
            # Too many columns to label them all: label every 20th column, passing only
            # those labels so the number of labels matches the number of ticks
            tick_positions = list(range(0, len(col_labels), 20))
            ax.set_xticks(tick_positions)
            ax.set_xticklabels([col_labels[i] for i in tick_positions], rotation=90)

    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if log_scale:
        ax.set_yscale('log')

    # To fit the legend in, shrink current axis by 20%
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

    # Put a legend to the right of the current axis (only if there is something to label)
    if series_labels is not None:
        ax.legend(loc='center left', bbox_to_anchor=(0.99, 0.5), frameon=False)

    return fig
200
-
201
-
202
def calibration_ece(true_scores, pred_scores, num_bins):
    r"""
    Computes the expected calibration error (ECE), equation (3) of
    Guo et al. "On Calibration of Modern Neural Networks." (2017).

    Adapted from sklearn.calibration.calibration_curve() to also return the ECE; see:

    https://github.com/scikit-learn/scikit-learn/issues/18268

    Args:
        true_scores (list of int): length-N binary labels (0 = neg, 1 = pos)
        pred_scores (list of float): length-N confidences; pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of equal-width confidence bins (`M` in eq. (3) of Guo 2017)

    Returns:
        tuple: (accs, confs, ece), where
            - accs: np.ndarray of float64, fraction of positives in each non-empty bin
              (empty bins are omitted, so there may be fewer than num_bins entries)
            - confs: np.ndarray of float64, mean predicted confidence in each non-empty bin
            - ece: |accs - confs| averaged over bins, weighted by the fraction of
              examples in each bin
    """

    assert len(true_scores) == len(pred_scores)

    # The tiny bump past 1.0 puts a confidence of exactly 1.0 into the last bin
    edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
    n_edges = len(edges)
    bin_index = np.digitize(pred_scores, edges) - 1

    counts = np.bincount(bin_index, minlength=n_edges)
    conf_sums = np.bincount(bin_index, weights=pred_scores, minlength=n_edges)
    true_sums = np.bincount(bin_index, weights=true_scores, minlength=n_edges)

    occupied = counts != 0
    occupied_counts = counts[occupied]
    accs = true_sums[occupied] / occupied_counts
    confs = conf_sums[occupied] / occupied_counts

    bin_weights = occupied_counts / len(true_scores)
    ece = np.abs(accs - confs) @ bin_weights
    return accs, confs, ece
243
-
244
-
245
def plot_calibration_curve(true_scores, pred_scores, num_bins,
                           name='calibration', plot_perf=True, plot_hist=True,
                           ax=None, **fig_kwargs):
    """
    Draws a reliability (calibration) curve: mean confidence vs. observed accuracy per bin.

    Args:
        true_scores (list of int): length-N binary labels (0 = neg, 1 = pos)
        pred_scores (list of float): length-N confidences; pred_scores[i] is the
            predicted confidence that example i is positive
        num_bins (int): number of confidence bins (`M` in eq. (3) of Guo 2017)
        name (str, optional): legend label for the calibration curve
        plot_perf (bool, optional): also draw the y=x "perfect calibration" line
        plot_hist (bool, optional): overlay a histogram of confidences on a twin y-axis
        ax (Axes, optional): axes to draw into; when given, no figure-level legend is
            drawn and [fig_kwargs] are ignored
        fig_kwargs (dict, optional): forwarded to matplotlib.figure.Figure when [ax] is None

    Returns:
        matplotlib.figure.Figure: the newly created figure when [ax] is None, otherwise
        the figure that [ax] belongs to
    """

    accs, confs, ece = calibration_ece(true_scores, pred_scores, num_bins)

    own_figure = ax is None
    if own_figure:
        ax = matplotlib.figure.Figure(**fig_kwargs).subplots(1, 1)

    # 's-': square markers joined by a line
    ax.plot(confs, accs, 's-', label=name)
    ax.set(xlabel='Model confidence', ylabel='Actual accuracy',
           title=f'Calibration plot (ECE: {ece:.02g})')
    ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])
    if plot_perf:
        ax.plot([0, 1], [0, 1], color='black', label='perfect calibration')
    ax.grid(True)

    if plot_hist:
        hist_ax = ax.twinx()
        edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
        hist_result = hist_ax.hist(pred_scores, alpha=0.5, label='histogram of examples',
                                   bins=edges, color='tab:red')
        tallest = np.max(hist_result[0])
        hist_ax.set_ylim([-0.05 * tallest, 1.05 * tallest])
        hist_ax.set_ylabel('Count')

    if own_figure:
        ax.figure.legend(loc='upper left', bbox_to_anchor=(0.15, 0.85))

    return ax.figure