edgefirst-validator 4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. deepview/modelpack/utils/argmax.py +16 -0
  2. edgefirst/validator/__init__.py +1 -0
  3. edgefirst/validator/__main__.py +375 -0
  4. edgefirst/validator/datasets/__init__.py +118 -0
  5. edgefirst/validator/datasets/cache.py +296 -0
  6. edgefirst/validator/datasets/core.py +250 -0
  7. edgefirst/validator/datasets/darknet.py +446 -0
  8. edgefirst/validator/datasets/database.py +1067 -0
  9. edgefirst/validator/datasets/instance/__init__.py +4 -0
  10. edgefirst/validator/datasets/instance/core.py +222 -0
  11. edgefirst/validator/datasets/instance/detection.py +145 -0
  12. edgefirst/validator/datasets/instance/multitask.py +80 -0
  13. edgefirst/validator/datasets/instance/segmentation.py +120 -0
  14. edgefirst/validator/datasets/utils/fetch.py +682 -0
  15. edgefirst/validator/datasets/utils/readers.py +425 -0
  16. edgefirst/validator/datasets/utils/transformations.py +1695 -0
  17. edgefirst/validator/evaluators/__init__.py +17 -0
  18. edgefirst/validator/evaluators/callbacks/__init__.py +3 -0
  19. edgefirst/validator/evaluators/callbacks/core.py +192 -0
  20. edgefirst/validator/evaluators/callbacks/plots.py +900 -0
  21. edgefirst/validator/evaluators/callbacks/studio.py +234 -0
  22. edgefirst/validator/evaluators/core.py +257 -0
  23. edgefirst/validator/evaluators/detection.py +749 -0
  24. edgefirst/validator/evaluators/multitask.py +270 -0
  25. edgefirst/validator/evaluators/parameters/__init__.py +53 -0
  26. edgefirst/validator/evaluators/parameters/core.py +554 -0
  27. edgefirst/validator/evaluators/parameters/dataset.py +239 -0
  28. edgefirst/validator/evaluators/parameters/model.py +338 -0
  29. edgefirst/validator/evaluators/parameters/validation.py +528 -0
  30. edgefirst/validator/evaluators/segmentation.py +729 -0
  31. edgefirst/validator/evaluators/utils/__init__.py +3 -0
  32. edgefirst/validator/evaluators/utils/classify.py +292 -0
  33. edgefirst/validator/evaluators/utils/match.py +262 -0
  34. edgefirst/validator/evaluators/utils/timer.py +132 -0
  35. edgefirst/validator/metrics/__init__.py +9 -0
  36. edgefirst/validator/metrics/data/__init__.py +7 -0
  37. edgefirst/validator/metrics/data/label.py +668 -0
  38. edgefirst/validator/metrics/data/metrics.py +759 -0
  39. edgefirst/validator/metrics/data/plots.py +476 -0
  40. edgefirst/validator/metrics/data/stats.py +507 -0
  41. edgefirst/validator/metrics/detection.py +595 -0
  42. edgefirst/validator/metrics/segmentation.py +173 -0
  43. edgefirst/validator/metrics/utils/math.py +717 -0
  44. edgefirst/validator/publishers/__init__.py +3 -0
  45. edgefirst/validator/publishers/console.py +147 -0
  46. edgefirst/validator/publishers/studio.py +128 -0
  47. edgefirst/validator/publishers/tensorboard.py +119 -0
  48. edgefirst/validator/publishers/utils/logger.py +111 -0
  49. edgefirst/validator/publishers/utils/table.py +403 -0
  50. edgefirst/validator/runners/__init__.py +8 -0
  51. edgefirst/validator/runners/core.py +727 -0
  52. edgefirst/validator/runners/deepviewrt.py +177 -0
  53. edgefirst/validator/runners/hailo.py +263 -0
  54. edgefirst/validator/runners/keras.py +150 -0
  55. edgefirst/validator/runners/kinara.py +265 -0
  56. edgefirst/validator/runners/offline.py +228 -0
  57. edgefirst/validator/runners/onnx.py +241 -0
  58. edgefirst/validator/runners/processing/decode.py +320 -0
  59. edgefirst/validator/runners/processing/dvapi.py +4192 -0
  60. edgefirst/validator/runners/processing/nms.py +637 -0
  61. edgefirst/validator/runners/processing/outputs.py +507 -0
  62. edgefirst/validator/runners/tensorrt.py +321 -0
  63. edgefirst/validator/runners/tflite.py +221 -0
  64. edgefirst/validator/validate.py +843 -0
  65. edgefirst/validator/visualize/__init__.py +3 -0
  66. edgefirst/validator/visualize/detection.py +623 -0
  67. edgefirst/validator/visualize/segmentation.py +281 -0
  68. edgefirst/validator/visualize/utils/plots.py +635 -0
  69. edgefirst_validator-4.2.1.dist-info/METADATA +111 -0
  70. edgefirst_validator-4.2.1.dist-info/RECORD +73 -0
  71. edgefirst_validator-4.2.1.dist-info/WHEEL +5 -0
  72. edgefirst_validator-4.2.1.dist-info/entry_points.txt +2 -0
  73. edgefirst_validator-4.2.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,635 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ from typing import TYPE_CHECKING, List
5
+
6
+ import numpy as np
7
+ import seaborn as sn
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.figure
11
+
12
+ from edgefirst.validator.metrics.utils.math import batch_iou
13
+
14
+ if TYPE_CHECKING:
15
+ from edgefirst.validator.datasets import DetectionInstance
16
+
17
# Force the non-interactive Agg backend so figures render without a display.
matplotlib.use(backend="Agg")
18
+
19
+
20
def figure2numpy(figure: matplotlib.figure.Figure) -> np.ndarray:
    """
    Converts a matplotlib.figure.Figure into a NumPy
    array so that it can be published to Tensorboard.

    The figure is rendered to an in-memory raw pixel buffer at the
    figure's own DPI, so the buffer size always matches
    ``figure.bbox`` regardless of the ``savefig.dpi`` rcParam.

    Parameters
    ----------
    figure: matplotlib.figure.Figure
        This is the figure to convert to a numpy array.

    Returns
    -------
    np.ndarray
        The figure as a (height, width, channels) uint8 array. The
        array is a read-only view over the rendered bytes.
    """
    with io.BytesIO() as io_buf:
        # dpi=figure.dpi keeps the rendered size consistent with
        # figure.bbox, which is expressed in pixels at figure.dpi.
        figure.savefig(io_buf, format='raw', dpi=figure.dpi)
        raw = io_buf.getvalue()

    height = int(figure.bbox.bounds[3])
    width = int(figure.bbox.bounds[2])
    # Positional shape avoids the `newshape=` keyword deprecated in NumPy 2.1.
    return np.frombuffer(raw, dtype=np.uint8).reshape(height, width, -1)
43
+
44
+
45
def plot_classification_detection(
    class_histogram_data: dict,
    model: str = "Model",
) -> matplotlib.figure.Figure:
    """
    Plots the bar charts showing the precision, recall, and accuracy per class.
    It also shows the number of true positives, false positives,
    and false negatives per class.

    Parameters
    ----------
    class_histogram_data: dict.
        This contains information about the metrics per class.

        .. code-block:: python

            {
                'label_1': {
                    'precision': "The calculated precision at
                                  IoU threshold 0.5 for the class",
                    'recall': "The calculated recall at
                               IoU threshold 0.5 for the class",
                    'accuracy': "The calculated accuracy at
                                 IoU threshold 0.5 for the class",
                    'tp': "The number of true positives for the class",
                    'fn': "The number of false negatives for the class",
                    'fp': "The number of localization and
                           classification false positives for the class",
                    'gt': "The number of grounds truths for the class"
                },
                'label_2': ...
            }

        The 'precision', 'recall', and 'accuracy' values are multiplied
        by 100 for display, so they are expected as fractions in [0, 1].

    model: str
        The name of the model.

    Returns
    -------
    matplotlib.figure.Figure
        The left chart compares the precision, recall, and accuracy
        (percent) per class. The right chart compares the number of
        true positives, false negatives, and false positives per class.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 10))
    X = np.arange(len(class_histogram_data))
    labels, precision, recall, accuracy = [], [], [], []
    tp, fp, fn = [], [], []

    for cls, value in class_histogram_data.items():
        labels.append(cls)
        precision.append(round(value.get('precision') * 100, 2))
        recall.append(round(value.get('recall') * 100, 2))
        accuracy.append(round(value.get('accuracy') * 100, 2))
        tp.append(value.get('tp'))
        fn.append(value.get('fn'))
        fp.append(value.get('fp'))

    # Labelling the bars themselves guarantees the legend swatches use the
    # exact bar colors (hand-built legend patches previously did not).
    ax1.bar(X + 0.0, precision, color='m', width=0.25, label='precision')
    ax1.bar(X + 0.25, recall, color='y', width=0.25, label='recall')
    ax1.bar(X + 0.5, accuracy, color='c', width=0.25, label='accuracy')

    ax2.bar(X + 0.0, tp, color='LimeGreen', width=0.25,
            label='true positives')
    ax2.bar(X + 0.25, fn, color='RoyalBlue', width=0.25,
            label='false negatives')
    ax2.bar(X + 0.5, fp, color='OrangeRed', width=0.25,
            label='false positives')

    ax1.set_ylim(0, 100)

    ax1.set_ylabel('Score (%)')
    ax2.set_ylabel("Total Number")
    fig.suptitle(f"{model} Evaluation Table")

    # Each class has three 0.25-wide bars at offsets 0, 0.25, 0.5, so the
    # group is centered at +0.25; put the class label there.
    ax1.set_xticks(X + 0.25, labels, rotation='vertical')
    ax2.set_xticks(X + 0.25, labels, rotation='vertical')

    ax1.legend()
    ax2.legend()
    return fig
135
+
136
+
137
def plot_classification_segmentation(
    class_histogram_data: dict,
    model: str = "Model"
) -> matplotlib.figure.Figure:
    """
    Plots the bar charts showing the precision,
    recall, and accuracy per class.
    It also shows the number of true predictions
    and false predictions per class.

    Parameters
    ----------
    class_histogram_data: dict.
        This contains information about the metrics per class.

        .. code-block:: python

            {
                'label_1': {
                    'precision': "The calculated precision for the class",
                    'recall': "The calculated recall for the class",
                    'accuracy': "The calculated accuracy for the class",
                    'true_predictions': "The number of true prediction
                                         pixels of the class",
                    'false_predictions': "The number of false prediction
                                          pixels of the class",
                    'gt': "The number of grounds truths for the class"
                },
                'label_2': ...
            }

        The 'precision', 'recall', and 'accuracy' values are multiplied
        by 100 for display, so they are expected as fractions in [0, 1].

    model: str
        The name of the model.

    Returns
    -------
    matplotlib.figure.Figure
        The left chart compares the precision, recall, and accuracy
        (percent) per class. The right chart compares the number of
        true prediction and false prediction pixels per class.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 10))
    X = np.arange(len(class_histogram_data))
    labels, precision, recall, accuracy = [], [], [], []
    true_predictions, false_predictions = [], []

    for cls, value in class_histogram_data.items():
        labels.append(cls)
        precision.append(round(value.get('precision') * 100, 2))
        recall.append(round(value.get('recall') * 100, 2))
        accuracy.append(round(value.get('accuracy') * 100, 2))
        true_predictions.append(value.get('true_predictions'))
        false_predictions.append(value.get('false_predictions'))

    # Labelling the bars themselves guarantees the legend swatches use the
    # exact bar colors (hand-built legend patches previously did not).
    ax1.bar(X + 0.0, precision, color='m', width=0.25, label='precision')
    ax1.bar(X + 0.25, recall, color='y', width=0.25, label='recall')
    ax1.bar(X + 0.5, accuracy, color='c', width=0.25, label='accuracy')

    ax2.bar(X + 0.0, true_predictions, color='LimeGreen', width=0.25,
            label='true predictions')
    ax2.bar(X + 0.25, false_predictions, color='OrangeRed', width=0.25,
            label='false predictions')

    ax1.set_ylim(0, 100)

    ax1.set_ylabel('Score (%)')
    ax2.set_ylabel("Total Number")
    fig.suptitle(f"{model} Evaluation Table")

    # Center each class label under its bar group: three bars (offsets
    # 0..0.5) on the left chart, two bars (offsets 0..0.25) on the right.
    ax1.set_xticks(X + 0.25, labels, rotation='vertical')
    ax2.set_xticks(X + 0.125, labels, rotation='vertical')

    ax1.legend()
    ax2.legend()
    return fig
222
+
223
+
224
def plot_score_histogram(
    tp_scores: np.ndarray,
    fp_scores: np.ndarray,
    model: str = "Model",
    title: str = "Histogram of TP vs FP Scores",
    xlabel: str = "Score",
    ylabel: str = "Count"
):
    """
    Draw side-by-side histograms of true-positive and false-positive
    scores so their distributions can be compared when picking a score
    threshold. Every non-empty bar is labelled with its count.

    Parameters
    ----------
    tp_scores: np.ndarray
        Scores of the true positives.
    fp_scores: np.ndarray
        Scores of the false positives.
    model: str
        Name of the evaluated model, used as a title prefix.
    title: str
        Title text appended after the model name.
    xlabel: str
        Label of the x-axis.
    ylabel: str
        Label of the y-axis.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding both histograms.
    """
    # Bin edges from 0.0 to 1.0 in steps of 0.05; scores outside that
    # range are not counted by np.histogram.
    edges = np.arange(0, 1.05, 0.05)
    centers = (edges[:-1] + edges[1:]) / 2
    tp_counts = np.histogram(tp_scores, bins=edges)[0]
    fp_counts = np.histogram(fp_scores, bins=edges)[0]

    fig, ax = plt.subplots(1, 1, figsize=(10, 5), tight_layout=True)

    # TP bar just left of each bin center, FP bar just right of it.
    tp_container = ax.bar(centers - 0.01, tp_counts, width=0.02,
                          label='True Positives', alpha=0.7, color='green')
    fp_container = ax.bar(centers + 0.01, fp_counts, width=0.02,
                          label='False Positives', alpha=0.7, color='red')

    paired = zip(tp_container, fp_container, tp_counts, fp_counts)
    for tp_rect, fp_rect, tp_n, fp_n in paired:
        for rect, count, shade in ((tp_rect, tp_n, 'green'),
                                   (fp_rect, fp_n, 'red')):
            if count <= 0:
                continue
            ax.text(rect.get_x() + rect.get_width() / 2,
                    rect.get_height() + 0.5,
                    str(count),
                    ha='center', va='bottom', fontsize=8, color=shade)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(f'{model} {title}')
    ax.set_xticks(edges)
    ax.legend()
    ax.grid(True)
    return fig
303
+
304
+
305
def plot_pr_curve(
    precision: np.ndarray,
    recall: np.ndarray,
    ap: np.ndarray,
    names: list | None = None,
    model: str = "Model",
    iou_threshold: float = 0.50
) -> matplotlib.figure.Figure:
    """
    Plot the precision-recall curve per class and the class-averaged curve.
    Use this method for YoloV5 implementation of precision recall
    curve.

    Parameters
    ----------
    precision: np.ndarray
        N per-class precision curves (N => number of classes), each with
        M values. They are stacked along axis 1 into an (M x N) array.
    recall: np.ndarray
        The M recall values shared by every class curve (the x-axis).
    ap: np.ndarray
        (N x K) average precision per class. Only column 0 is used; it is
        reported in the legend as the AP at ``iou_threshold``.
    names: list, optional
        Class names in the same order as ``precision``. Per-class curves
        are labelled only when 0 < len(names) < 21. Defaults to no names.
    model: str
        The name of the model evaluated.
    iou_threshold: float
        The iou threshold used for the mAP calculation.

    Returns
    -------
    matplotlib.figure.Figure
        The precision recall plot where recall is denoted
        on the x-axis and precision is denoted
        on the y-axis.
    """
    # None instead of a shared mutable [] default.
    names = [] if names is None else names

    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title(f'{model} Precision-Recall Curve')

    if len(precision) == 0:
        # Nothing to plot; skip legend() since there are no labelled
        # artists (it would only emit a warning).
        return fig

    curves = np.stack(precision, axis=1)  # (M, N): one column per class
    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(curves.T):
            ax.plot(recall, y, linewidth=1,
                    label=f"{names[i]} {ap[i, 0]:.3f}")
    else:
        ax.plot(recall, curves, linewidth=1, color="grey")

    ax.plot(
        recall,
        curves.mean(1),
        linewidth=3,
        color="blue",
        label=f"all classes {ap[:, 0].mean():.3f} mAP@{iou_threshold:.2f}")
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    return fig
373
+
374
+
375
def plot_mc_curve(
    px: np.ndarray,
    py: np.ndarray,
    names: list | None = None,
    model: str = "Model",
    xlabel: str = 'Confidence',
    ylabel: str = 'Metric'
) -> matplotlib.figure.Figure:
    """
    Plot a metric-versus-x curve per class plus the class average.
    Typically used for the F1 curve or the precision/recall versus
    confidence curves.

    Parameters
    ----------
    px: np.ndarray
        The M x-axis values shared by every curve (e.g. confidences).
    py: np.ndarray
        (N x M) metric values, one row per class (N => number of
        classes). This could be values for the F1, precision, or recall.
    names: list, optional
        Class names in the same order as the rows of ``py``. Per-class
        curves are labelled only when 0 < len(names) < 21.
        Defaults to no names.
    model: str
        The name of the model evaluated.
    xlabel: str
        The metric on the x-axis.
    ylabel: str
        The metric on the y-axis.

    Returns
    -------
    matplotlib.figure.Figure
        The plot with ``xlabel`` on the x-axis and ``ylabel`` on the
        y-axis. The bold blue curve is the mean over classes, and its
        legend entry reports that curve's maximum and the x value
        where the maximum occurs.
    """
    # None instead of a shared mutable [] default.
    names = [] if names is None else names
    py = np.asarray(py)

    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title(f'{model} {ylabel}-{xlabel} Curve')

    if py.size == 0:
        # argmax of an empty mean curve would raise; return the empty axes.
        return fig

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            ax.plot(px, y, linewidth=1, label=f'{names[i]}')
    else:
        ax.plot(px, py.T, linewidth=1, color='grey')

    mean_curve = py.mean(0)
    best = mean_curve.argmax()
    ax.plot(px, mean_curve, linewidth=3, color='blue',
            label=f'all classes {mean_curve.max():.2f} at {px[best]:.3f}')
    # Use this figure's axes explicitly instead of pyplot's "current" axes.
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    return fig
431
+
432
+
433
def plot_confusion_matrix(
    confusion_data: np.ndarray,
    labels: list,
    model: str = "Model"
) -> matplotlib.figure.Figure:
    """
    Plots the confusion matrix using the method defined below:
    https://stackoverflow.com/questions/5821125/how-to-plot-confusion-matrix-with-string-axis-rather-than-integer-in-python/74152927#74152927

    Cell colors show each row normalized by its row sum (rows summing to
    zero are drawn as zeros). Cell text shows the raw counts.

    Parameters
    ----------
    confusion_data: np.ndarray
        A 2-D matrix (normally square) of counts where the rows are the
        predictions and the columns are the ground truth.
    labels: list
        This contains the unique string labels in the dataset. The first
        ``n_cols`` label the x-axis and the first ``n_rows`` the y-axis.
    model: str
        The name of the model being validated.

    Returns
    --------
    matplotlib.figure.Figure
        The confusion matrix plot.
    """
    counts = np.asarray(confusion_data)
    n_rows, n_cols = counts.shape

    # Row-normalize; rows with a zero sum stay 0 instead of dividing by 0.
    row_totals = counts.sum(axis=1, keepdims=True).astype(float)
    norm_conf = np.divide(counts, row_totals,
                          out=np.zeros(counts.shape, dtype=float),
                          where=row_totals != 0)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    image = ax.imshow(norm_conf, cmap='jet', interpolation='nearest')

    # imshow places column index on x and row index on y.
    for row in range(n_rows):
        for col in range(n_cols):
            ax.annotate(str(int(counts[row, col])), xy=(col, row),
                        horizontalalignment='center',
                        verticalalignment='center')
    fig.colorbar(image, ax=ax)

    # Ticks are set on this figure's axes explicitly (not pyplot's
    # current axes): x ticks follow columns, y ticks follow rows.
    ax.set_xticks(range(n_cols), labels[:n_cols], rotation="vertical")
    ax.set_yticks(range(n_rows), labels[:n_rows])
    ax.set_ylabel("Prediction")
    ax.set_xlabel("Ground Truth")
    ax.set_title(f"{model} Confusion Matrix")
    return fig
490
+
491
+
492
def close_figures(figures: List[matplotlib.figure.Figure]):
    """
    Close every given matplotlib figure to release its resources and
    avoid errors such as "Fail to allocate bitmap." when many figures
    are created.

    Parameters
    ----------
    figures: List[matplotlib.figure.Figure]
        The pyplot figures to close. An empty list is a no-op.
    """
    if len(figures) == 0:
        return
    for open_figure in figures:
        plt.close(open_figure)
505
+
506
+
507
class ConfusionMatrix:
    """
    Object-detection confusion matrix adapted from the YoloV7 validation
    implementation.

    ``self.matrix[p, t]`` counts detections whose (offset) predicted class
    is ``p`` against ground truths whose (offset) class is ``t``, so rows
    are predictions and columns are ground truth. Index 0 on both axes is
    treated as background: missed ground truths go to row 0 and unmatched
    detections go to column 0.

    Parameters
    -----------
    nc: int
        The number of classes in the dataset.
    conf: float
        Only detections with a score strictly greater than this are counted.
    iou_thres: float
        A ground truth and a detection match only if their IoU is strictly
        greater than this.
    offset: int
        If the dataset labels already contains background,
        then this offset is 0. Otherwise the offset is +1
        to include the background class.
    """
    # Updated version of
    # https://github.com/kaanakan/object_detection_confusion_matrix

    def __init__(
        self,
        nc: int,
        conf: float = 0.25,
        iou_thres: float = 0.45,
        offset: int = 0
    ):
        self.matrix = np.zeros((nc + offset, nc + offset), dtype=np.int32)
        self.nc = nc  # number of classes
        self.conf = conf
        self.iou_thres = iou_thres
        self.offset = offset

    def process_batch(self, dt_instance: DetectionInstance,
                      gt_instance: DetectionInstance):
        """
        Accumulate one image's detections and ground truths into
        ``self.matrix``.

        Detections are first filtered by ``self.conf``. Each ground truth
        is then paired with at most one detection (highest IoU first,
        IoU > ``self.iou_thres``) using ``batch_iou``. The original notes
        say boxes are in (x1, y1, x2, y2) format.

        Parameters
        ----------
        dt_instance: DetectionInstance
            A prediction instance container of the boxes, labels, and
            scores. May be None, in which case every ground truth is
            counted as missed (row 0).
        gt_instance: DetectionInstance
            A ground truth instance container of the boxes and the labels.
        """
        gt_classes = gt_instance.labels.astype(np.int32)

        if dt_instance is None:
            for gc in gt_classes:
                self.matrix[0, gc + self.offset] += 1  # missed -> background
            return

        keep = dt_instance.scores > self.conf
        dt_boxes = dt_instance.boxes[keep]
        dt_classes = dt_instance.labels[keep].astype(np.int32)
        iou = batch_iou(gt_instance.boxes, dt_boxes)

        gt_idx, dt_idx = np.where(iou > self.iou_thres)
        if gt_idx.shape[0]:
            # Rows of (gt index, detection index, iou).
            matches = np.concatenate(
                (np.stack((gt_idx, dt_idx), 1),
                 iou[gt_idx, dt_idx][:, None]), 1)
            if gt_idx.shape[0] > 1:
                # Keep the best IoU per detection, then per ground truth.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(
                    matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(
                    matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        has_matches = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(int)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            # Asserting a unique match.
            if has_matches and j.sum() == 1:
                self.matrix[dt_classes[m1[j]] + self.offset,
                            gc + self.offset] += 1  # matched
            else:
                self.matrix[0, gc + self.offset] += 1  # missed -> background

        # A set gives O(1) membership instead of scanning m1 per detection.
        matched_detections = set(m1.tolist()) if has_matches else set()
        for i, dc in enumerate(dt_classes):
            if i not in matched_detections:
                # false positive (predicted something not matched to GT)
                self.matrix[dc + self.offset, 0] += 1

    def _tick_labels(self, names=()):
        """
        Return heatmap tick labels aligned with the matrix axes.

        When ``names`` has exactly ``nc`` entries (and fewer than 99),
        ``offset`` "background" labels are prepended so that index 0 of the
        matrix (background) is labelled correctly. Otherwise "auto".
        """
        if 0 < len(names) < 99 and len(names) == self.nc:
            return ["background"] * self.offset + list(names)
        return "auto"

    def plot(self, names=()) -> matplotlib.figure.Figure:
        """
        Plots the Confusion Matrix.

        Each column is normalized by its sum (ground-truth totals), and
        cells below 0.005 are left blank.

        Parameters
        ----------
        names: tuple
            All the unique labels in the dataset (without background).

        Returns
        -------
        matplotlib.figure.Figure
            The Confusion Matrix figure.
        """
        totals = self.matrix.sum(0).reshape(1, self.nc + self.offset)
        array = self.matrix / (totals + 1E-6)  # normalize
        array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

        tick_labels = self._tick_labels(names)
        font_scale = 1.0 if self.nc < 50 else 0.8  # for label size

        # Scoped seaborn styling: the previous global sn.set() permanently
        # changed rcParams and the 'm'/'y'/'c' color codes for all later
        # figures in the process.
        with sn.axes_style("darkgrid"), \
                sn.plotting_context("notebook", font_scale=font_scale):
            fig = plt.figure(figsize=(12, 9), tight_layout=True)
            ax = sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8},
                            cmap='Blues', fmt='.2f', square=True,
                            xticklabels=tick_labels,
                            yticklabels=tick_labels)
            ax.set_facecolor((1, 1, 1))
            ax.set_xlabel('True')
            ax.set_ylabel('Predicted')
        return fig

    def print(self):
        """
        Prints the Confusion Matrix, one space-separated row per line.
        """
        for row in self.matrix:
            print(' '.join(map(str, row)))