edgefirst-validator 4.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepview/modelpack/utils/argmax.py +16 -0
- edgefirst/validator/__init__.py +1 -0
- edgefirst/validator/__main__.py +375 -0
- edgefirst/validator/datasets/__init__.py +118 -0
- edgefirst/validator/datasets/cache.py +296 -0
- edgefirst/validator/datasets/core.py +250 -0
- edgefirst/validator/datasets/darknet.py +446 -0
- edgefirst/validator/datasets/database.py +1067 -0
- edgefirst/validator/datasets/instance/__init__.py +4 -0
- edgefirst/validator/datasets/instance/core.py +222 -0
- edgefirst/validator/datasets/instance/detection.py +145 -0
- edgefirst/validator/datasets/instance/multitask.py +80 -0
- edgefirst/validator/datasets/instance/segmentation.py +120 -0
- edgefirst/validator/datasets/utils/fetch.py +682 -0
- edgefirst/validator/datasets/utils/readers.py +425 -0
- edgefirst/validator/datasets/utils/transformations.py +1695 -0
- edgefirst/validator/evaluators/__init__.py +17 -0
- edgefirst/validator/evaluators/callbacks/__init__.py +3 -0
- edgefirst/validator/evaluators/callbacks/core.py +192 -0
- edgefirst/validator/evaluators/callbacks/plots.py +900 -0
- edgefirst/validator/evaluators/callbacks/studio.py +234 -0
- edgefirst/validator/evaluators/core.py +257 -0
- edgefirst/validator/evaluators/detection.py +749 -0
- edgefirst/validator/evaluators/multitask.py +270 -0
- edgefirst/validator/evaluators/parameters/__init__.py +53 -0
- edgefirst/validator/evaluators/parameters/core.py +554 -0
- edgefirst/validator/evaluators/parameters/dataset.py +239 -0
- edgefirst/validator/evaluators/parameters/model.py +338 -0
- edgefirst/validator/evaluators/parameters/validation.py +528 -0
- edgefirst/validator/evaluators/segmentation.py +729 -0
- edgefirst/validator/evaluators/utils/__init__.py +3 -0
- edgefirst/validator/evaluators/utils/classify.py +292 -0
- edgefirst/validator/evaluators/utils/match.py +262 -0
- edgefirst/validator/evaluators/utils/timer.py +132 -0
- edgefirst/validator/metrics/__init__.py +9 -0
- edgefirst/validator/metrics/data/__init__.py +7 -0
- edgefirst/validator/metrics/data/label.py +668 -0
- edgefirst/validator/metrics/data/metrics.py +759 -0
- edgefirst/validator/metrics/data/plots.py +476 -0
- edgefirst/validator/metrics/data/stats.py +507 -0
- edgefirst/validator/metrics/detection.py +595 -0
- edgefirst/validator/metrics/segmentation.py +173 -0
- edgefirst/validator/metrics/utils/math.py +717 -0
- edgefirst/validator/publishers/__init__.py +3 -0
- edgefirst/validator/publishers/console.py +147 -0
- edgefirst/validator/publishers/studio.py +128 -0
- edgefirst/validator/publishers/tensorboard.py +119 -0
- edgefirst/validator/publishers/utils/logger.py +111 -0
- edgefirst/validator/publishers/utils/table.py +403 -0
- edgefirst/validator/runners/__init__.py +8 -0
- edgefirst/validator/runners/core.py +727 -0
- edgefirst/validator/runners/deepviewrt.py +177 -0
- edgefirst/validator/runners/hailo.py +263 -0
- edgefirst/validator/runners/keras.py +150 -0
- edgefirst/validator/runners/kinara.py +265 -0
- edgefirst/validator/runners/offline.py +228 -0
- edgefirst/validator/runners/onnx.py +241 -0
- edgefirst/validator/runners/processing/decode.py +320 -0
- edgefirst/validator/runners/processing/dvapi.py +4192 -0
- edgefirst/validator/runners/processing/nms.py +637 -0
- edgefirst/validator/runners/processing/outputs.py +507 -0
- edgefirst/validator/runners/tensorrt.py +321 -0
- edgefirst/validator/runners/tflite.py +221 -0
- edgefirst/validator/validate.py +843 -0
- edgefirst/validator/visualize/__init__.py +3 -0
- edgefirst/validator/visualize/detection.py +623 -0
- edgefirst/validator/visualize/segmentation.py +281 -0
- edgefirst/validator/visualize/utils/plots.py +635 -0
- edgefirst_validator-4.2.1.dist-info/METADATA +111 -0
- edgefirst_validator-4.2.1.dist-info/RECORD +73 -0
- edgefirst_validator-4.2.1.dist-info/WHEEL +5 -0
- edgefirst_validator-4.2.1.dist-info/entry_points.txt +2 -0
- edgefirst_validator-4.2.1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,635 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
from typing import TYPE_CHECKING, List
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import seaborn as sn
|
|
8
|
+
import matplotlib
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import matplotlib.figure
|
|
11
|
+
|
|
12
|
+
from edgefirst.validator.metrics.utils.math import batch_iou
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from edgefirst.validator.datasets import DetectionInstance
|
|
16
|
+
|
|
17
|
+
# Select matplotlib's non-interactive Agg backend for every figure created
# by this module.
matplotlib.use('Agg')
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def figure2numpy(figure: matplotlib.figure.Figure) -> np.ndarray:
    """
    Converts a matplotlib.figure.Figure into a NumPy
    array so that it can be published to Tensorboard.

    Parameters
    ----------
    figure: matplotlib.figure.Figure
        This is the figure to convert to a numpy array.

    Returns
    -------
    np.ndarray
        The figure rendered with ``savefig(format='raw')`` as a uint8
        array shaped (height, width, channels). Height and width come
        from ``figure.bbox.bounds``; the channel count is inferred from
        the size of the rendered buffer.
    """
    # The context manager releases the buffer even when rendering raises.
    with io.BytesIO() as io_buf:
        figure.savefig(io_buf, format='raw')
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no seek is required before reading it back.
        pixels = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)

    height = int(figure.bbox.bounds[3])
    width = int(figure.bbox.bounds[2])
    # The shape is passed positionally: the ``newshape`` keyword of
    # np.reshape is deprecated in recent NumPy releases.
    return np.reshape(pixels, (height, width, -1))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def plot_classification_detection(
    class_histogram_data: dict,
    model: str = "Model",
) -> matplotlib.figure.Figure:
    """
    Plots the bar charts showing the precision, recall, and accuracy per class.
    It also shows the number of true positives, false positives,
    and false negatives per class.

    Parameters
    ----------
    class_histogram_data: dict.
        This contains information about the metrics per class.
        Only the 'precision', 'recall', 'accuracy', 'tp', 'fn' and 'fp'
        entries are read by this function.

        .. code-block:: python

            {
                'label_1': {
                    'precision': "The calculated precision at
                                  IoU threshold 0.5 for the class",
                    'recall': "The calculated recall at
                               IoU threshold 0.5 for the class",
                    'accuracy': "The calculated accuracy at
                                 IoU threshold 0.5 for the class",
                    'tp': "The number of true positives for the class",
                    'fn': "The number of false negatives for the class",
                    'fp': "The number of localization and
                           classification false positives for the class",
                    'gt': "The number of ground truths for the class"
                },
                'label_2': ...
            }

    model: str
        The name of the model.

    Returns
    -------
    matplotlib.figure.Figure
        A figure with two bar charts: the left one compares
        the precision, recall, and accuracy (as percentages) and the
        right one compares the number of true positives, false
        negatives, and false positives for each class.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 10))
    bar_width = 0.25

    class_names = list(class_histogram_data.keys())
    per_class = list(class_histogram_data.values())
    X = np.arange(len(class_names))

    # (metric key, legend label, bar colour). Bars and legend swatches are
    # both built from these tables so their colours cannot drift apart.
    score_series = (
        ('precision', 'precision', 'm'),
        ('recall', 'recall', 'y'),
        ('accuracy', 'accuracy', 'c'),
    )
    count_series = (
        ('tp', 'true positives', 'LimeGreen'),
        ('fn', 'false negatives', 'RoyalBlue'),
        ('fp', 'false positives', 'OrangeRed'),
    )

    # Left chart: scores converted from fractions to percentages.
    for slot, (key, _, colour) in enumerate(score_series):
        scores = [round(entry.get(key) * 100, 2) for entry in per_class]
        ax1.bar(X + slot * bar_width, scores, color=colour, width=bar_width)

    # Right chart: raw counts.
    for slot, (key, _, colour) in enumerate(count_series):
        counts = [entry.get(key) for entry in per_class]
        ax2.bar(X + slot * bar_width, counts, color=colour, width=bar_width)

    ax1.set_ylim(0, 100)
    ax1.set_ylabel('Score (%)')
    ax2.set_ylabel("Total Number")
    fig.suptitle(f"{model} Evaluation Table")

    for axis, series in ((ax1, score_series), (ax2, count_series)):
        axis.xaxis.set_ticks(
            range(len(class_names)), class_names, rotation='vertical')
        # Proxy rectangles are used as legend handles so the legend is
        # also produced when there are no classes (and hence no bars).
        handles = [plt.Rectangle((0, 0), 1, 1, color=colour)
                   for _, _, colour in series]
        axis.legend(handles, [label for _, label, _ in series])
    return fig
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def plot_classification_segmentation(
    class_histogram_data: dict,
    model: str = "Model"
) -> matplotlib.figure.Figure:
    """
    Plots the bar charts showing the precision,
    recall, and accuracy per class.
    It also shows the number of true predictions
    and false predictions per class.

    Parameters
    ----------
    class_histogram_data: dict.
        This contains information about the metrics per class.
        Only the 'precision', 'recall', 'accuracy', 'true_predictions'
        and 'false_predictions' entries are read by this function.

        .. code-block:: python

            {
                'label_1': {
                    'precision': "The calculated precision for the class",
                    'recall': "The calculated recall for the class",
                    'accuracy': "The calculated accuracy for the class",
                    'true_predictions': "The number of true prediction
                                         pixels of the class",
                    'false_predictions': "The number of false prediction
                                          pixels of the class",
                    'gt': "The number of ground truths for the class"
                },
                'label_2': ...
            }

    model: str
        The name of the model.

    Returns
    -------
    matplotlib.figure.Figure
        A figure with two bar charts: the left one compares
        the precision, recall, and accuracy (as percentages) and the
        right one compares the number of true prediction and
        false prediction pixels for each class.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 10))
    bar_width = 0.25

    class_names = list(class_histogram_data.keys())
    per_class = list(class_histogram_data.values())
    X = np.arange(len(class_names))

    # (metric key, legend label, bar colour). Bars and legend swatches are
    # both built from these tables so their colours cannot drift apart.
    score_series = (
        ('precision', 'precision', 'm'),
        ('recall', 'recall', 'y'),
        ('accuracy', 'accuracy', 'c'),
    )
    count_series = (
        ('true_predictions', 'true predictions', 'LimeGreen'),
        ('false_predictions', 'false predictions', 'OrangeRed'),
    )

    # Left chart: scores converted from fractions to percentages.
    for slot, (key, _, colour) in enumerate(score_series):
        scores = [round(entry.get(key) * 100, 2) for entry in per_class]
        ax1.bar(X + slot * bar_width, scores, color=colour, width=bar_width)

    # Right chart: raw pixel counts.
    for slot, (key, _, colour) in enumerate(count_series):
        counts = [entry.get(key) for entry in per_class]
        ax2.bar(X + slot * bar_width, counts, color=colour, width=bar_width)

    ax1.set_ylim(0, 100)
    ax1.set_ylabel('Score (%)')
    ax2.set_ylabel("Total Number")
    fig.suptitle(f"{model} Evaluation Table")

    for axis, series in ((ax1, score_series), (ax2, count_series)):
        axis.xaxis.set_ticks(
            range(len(class_names)), class_names, rotation='vertical')
        # Proxy rectangles are used as legend handles so the legend is
        # also produced when there are no classes (and hence no bars).
        handles = [plt.Rectangle((0, 0), 1, 1, color=colour)
                   for _, _, colour in series]
        axis.legend(handles, [label for _, label, _ in series])
    return fig
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def plot_score_histogram(
    tp_scores: np.ndarray,
    fp_scores: np.ndarray,
    model: str = "Model",
    title: str = "Histogram of TP vs FP Scores",
    xlabel: str = "Score",
    ylabel: str = "Count"
) -> matplotlib.figure.Figure:
    """
    Create a score histogram to compare the number of true positives
    and false positives based on the scores. This provides insight
    on the optimal thresholds to use. Also draws count labels
    on each non-empty histogram bar.

    Parameters
    ----------
    tp_scores: np.ndarray
        All the scores for the true positives.
    fp_scores: np.ndarray
        All the scores for the false positives.
    model: str
        The name of the model evaluated.
    title: str
        Provide the title for the plot.
    xlabel: str
        The x-axis label.
    ylabel: str
        The y-axis label.

    Returns
    -------
    matplotlib.figure.Figure
        This shows the histogram comparing the scores of the
        true positives and false positives. Scores are binned into
        20 bins of width 0.05 covering 0.0 to 1.0.
    """
    # 21 edges -> 20 bins of width 0.05 on [0, 1]. linspace fixes the number
    # of edges explicitly; a float-step arange can gain or lose an element
    # through rounding.
    bins = np.linspace(0.0, 1.0, 21)

    # Compute histograms (counts only)
    tp_hist, _ = np.histogram(tp_scores, bins=bins)
    fp_hist, _ = np.histogram(fp_scores, bins=bins)

    fig, ax = plt.subplots(1, 1, figsize=(10, 5), tight_layout=True)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    # The two series are drawn side by side around each bin centre.
    tp_bars = ax.bar(bin_centers - 0.01,
                     tp_hist,
                     width=0.02,
                     label='True Positives',
                     alpha=0.7, color='green')
    fp_bars = ax.bar(bin_centers + 0.01,
                     fp_hist,
                     width=0.02,
                     label='False Positives',
                     alpha=0.7, color='red')

    # Annotate each non-empty bar with its count, just above the bar top.
    for bars, counts, colour in ((tp_bars, tp_hist, 'green'),
                                 (fp_bars, fp_hist, 'red')):
        for bar, count in zip(bars, counts):
            if count > 0:
                ax.text(bar.get_x() + bar.get_width() / 2,
                        bar.get_height() + 0.5,
                        str(count),
                        ha='center', va='bottom', fontsize=8, color=colour)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(f'{model} {title}')
    ax.set_xticks(bins)
    ax.legend()
    ax.grid(True)
    return fig
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def plot_pr_curve(
    precision: np.ndarray,
    recall: np.ndarray,
    ap: np.ndarray,
    names: list = None,
    model: str = "Model",
    iou_threshold: float = 0.50
) -> matplotlib.figure.Figure:
    """
    Version 2 Plotting precision and recall per class and the average metric.
    Use this method for YoloV5 implementation of precision recall
    curve.

    Parameters
    ----------
    precision: (NxM) np.ndarray
        N => number of classes and M is the number of precision values.
    recall: np.ndarray
        The recall values used as the x-coordinates of every curve.
        NOTE(review): earlier documentation described this as (NxM),
        but the same array is plotted against each class's M precision
        values, so it is presumably a 1-D array of M points - confirm
        against the caller.
    ap: (NxM) np.ndarray
        N => number of classes, M => 10 denoting each IoU threshold
        from (0.5 to 0.95 at 0.05 intervals). Only column 0 is used.
    names: list, optional
        This contains the unique string labels captured in the order
        that respects the data for precision and recall. Per-class
        curves with legend entries are drawn only when 1 to 20 names
        are given; otherwise the classes are drawn as unlabelled grey
        curves. ``None`` is treated as an empty list.
    model: str
        The name of the model evaluated.
    iou_threshold: float
        The iou threshold used for the mAP calculation; it is only
        used in the legend text.

    Returns
    -------
    matplotlib.figure.Figure
        The precision recall plot where recall is denoted
        on the x-axis and precision is denoted
        on the y-axis. When ``precision`` is empty, a labelled figure
        with no curves is returned.
    """
    # A None default avoids sharing one mutable list between calls.
    names = [] if names is None else names

    # Precision-recall curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(f'{model} Precision-Recall Curve')

    if len(precision) == 0:
        # Nothing to draw. No legend is requested because matplotlib warns
        # when a legend is created without any labelled artists.
        return fig

    # Stack the per-class rows as columns: the result is (M, N).
    precision = np.stack(precision, axis=1)
    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(precision.T):
            # plot(recall, precision)
            ax.plot(recall, y, linewidth=1,
                    label=f"{names[i]} {ap[i, 0]:.3f}")
    else:
        # plot(recall, precision)
        ax.plot(recall, precision, linewidth=1, color="grey")

    # Mean curve over all classes.
    ax.plot(
        recall,
        precision.mean(1),
        linewidth=3,
        color="blue",
        label=f"all classes {ap[:, 0].mean():.3f} mAP@{iou_threshold:.2f}")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    return fig
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def plot_mc_curve(
    px: np.ndarray,
    py: np.ndarray,
    names: list = None,
    model: str = "Model",
    xlabel: str = 'Confidence',
    ylabel: str = 'Metric'
) -> matplotlib.figure.Figure:
    """
    This function is used for plotting either the F1-curve or the
    precision/recall versus confidence curves.

    Parameters
    ----------
    px: np.ndarray
        The x-coordinates (e.g. confidence values) shared by all curves.
        NOTE(review): earlier documentation described this as (NxM),
        but a single element of it is formatted into the legend text,
        so it is presumably a 1-D array of M points - confirm against
        the caller.
    py: (NxM) np.ndarray
        N => number of classes.
        This could be values for the F1, precision, or recall.
    names: list, optional
        This contains the unique string labels captured in the order
        that respects the rows of ``py``. Per-class curves with legend
        entries are drawn only when 1 to 20 names are given; otherwise
        the classes are drawn as unlabelled grey curves. ``None`` is
        treated as an empty list.
    model: str
        The name of the model evaluated.
    xlabel: str
        The metric on the x-axis.
    ylabel: str
        The metric on the y-axis.

    Returns
    -------
    matplotlib.figure.Figure
        The plot where ``xlabel`` is denoted on the x-axis and
        ``ylabel`` is denoted on the y-axis, including the
        mean curve over all classes.
    """
    # A None default avoids sharing one mutable list between calls.
    names = [] if names is None else names

    # Metric-confidence curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            # plot(confidence, metric)
            ax.plot(px, y, linewidth=1, label=f'{names[i]}')
    else:
        # plot(confidence, metric)
        ax.plot(px, py.T, linewidth=1, color='grey')

    # Mean curve over all classes; the legend reports its maximum and the
    # x value at which that maximum occurs.
    y = py.mean(0)
    ax.plot(px, y, linewidth=3, color='blue',
            label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title(f'{model} {ylabel}-{xlabel} Curve')
    # Attach the legend to this figure's axes explicitly instead of relying
    # on pyplot's notion of the "current" axes.
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    return fig
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def plot_confusion_matrix(
    confusion_data: np.ndarray,
    labels: list,
    model: str = "Model"
) -> matplotlib.figure.Figure:
    """
    Plots the confusion matrix using the method defined below:
    https://stackoverflow.com/questions/5821125/how-to-plot-confusion-matrix-with-string-axis-rather-than-integer-in-python/74152927#74152927

    Each cell is coloured by its value divided by the sum of its row
    (rows that sum to zero are coloured as 0) and annotated with the
    raw integer count.

    Parameters
    ----------
    confusion_data: np.ndarray
        This is a square matrix representing the confusion matrix data
        where the rows are the predictions and the columns are the
        ground truth.
    labels: list
        This contains the unique string labels in the dataset.
    model: str
        The name of the model being validated.

    Returns
    --------
    matplotlib.figure.Figure
        The confusion matrix plot.
    """
    n_rows, n_cols = confusion_data.shape

    # Row-wise normalisation; rows whose total is zero stay all-zero
    # instead of dividing by zero.
    row_totals = confusion_data.sum(axis=1, keepdims=True).astype(float)
    norm_conf = np.divide(
        confusion_data,
        row_totals,
        out=np.zeros(confusion_data.shape, dtype=float),
        where=row_totals != 0)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(norm_conf, cmap=plt.cm.jet,
                    interpolation='nearest')

    # Write the raw count in the centre of every cell; imshow places
    # row r, column c at xy=(c, r).
    for row in range(n_rows):
        for col in range(n_cols):
            ax.annotate(str(int(confusion_data[row][col])), xy=(col, row),
                        horizontalalignment='center',
                        verticalalignment='center')
    fig.colorbar(res)

    # Ticks, labels and title are set on the matrix axes directly; the
    # colorbar call above adds a second axes to the figure, so pyplot's
    # "current axes" is deliberately not relied on here.
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(labels[:n_cols], rotation="vertical")
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(labels[:n_rows])
    ax.set_ylabel("Prediction")
    ax.set_xlabel("Ground Truth")
    ax.set_title(f"{model} Confusion Matrix")
    return fig
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def close_figures(figures: List[matplotlib.figure.Figure]):
    """
    Closes the matplotlib figures opened to prevent
    errors such as "Fail to allocate bitmap."

    Parameters
    ----------
    figures: List[matplotlib.figure.Figure]
        Contains matplotlib.pyplot figures to close. Any iterable of
        figures is accepted; ``None`` or an empty collection is a no-op.
    """
    # No length check is needed: iterating an empty collection does
    # nothing, and skipping len() lets iterators/generators be passed too.
    for figure in figures or ():
        plt.close(figure)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
class ConfusionMatrix:
    """
    This confusion matrix implementation was taken from YoloV7 to
    follow their validation implementation.

    The matrix is indexed as ``matrix[predicted, ground_truth]`` and
    index 0 is used as the background row/column.

    Parameters
    -----------
    nc: int
        The number of classes in the dataset.
    conf: float
        The confidence threshold for plotting. Detections with a
        score at or below this value are ignored.
    iou_thres: float
        The IoU threshold for plotting. A detection and a ground
        truth can only be matched when their IoU exceeds this value.
    offset: int
        If the dataset labels already contains background,
        then this offset is 0. Otherwise the offset is +1
        to include the background class.
    """
    # Updated version of
    # https://github.com/kaanakan/object_detection_confusion_matrix

    def __init__(
        self,
        nc: int,
        conf: float = 0.25,
        iou_thres: float = 0.45,
        offset: int = 0
    ):
        self.matrix = np.zeros((nc + offset, nc + offset), dtype=np.int32)
        self.nc = nc  # number of classes
        self.conf = conf
        self.iou_thres = iou_thres
        self.offset = offset

    def process_batch(self, dt_instance: DetectionInstance,
                      gt_instance: DetectionInstance):
        """
        Updates the confusion matrix with the detections and the
        ground truths of one sample.

        Detections above the confidence threshold are matched one-to-one
        to ground truths by descending IoU. Matched pairs increment
        ``matrix[predicted class, ground truth class]``, unmatched
        ground truths increment the background row (false negatives),
        and unmatched detections increment the background column
        (false positives).

        Both sets of boxes are expected to be in the format required by
        ``batch_iou``; earlier documentation states (x1, y1, x2, y2).

        Parameters
        ----------
        dt_instance: DetectionInstance
            A prediction instance container of the boxes, labels, and scores.
            When ``None``, every ground truth is counted as a false
            negative.
        gt_instance: DetectionInstance
            A ground truth instance container of the boxes and the labels.
        """
        gt_classes = gt_instance.labels.astype(np.int32)

        if dt_instance is None:
            for gc in gt_classes:
                self.matrix[0, gc + self.offset] += 1  # background FN
            return

        # Only detections above the confidence threshold take part.
        filt = dt_instance.scores > self.conf
        dt_boxes = dt_instance.boxes[filt]
        dt_classes = dt_instance.labels[filt].astype(np.int32)

        # iou[g, d]: rows index ground truths, columns index detections.
        iou = batch_iou(gt_instance.boxes, dt_boxes)

        x = np.where(iou > self.iou_thres)
        if x[0].shape[0]:
            # Columns: ground truth index, detection index, IoU.
            matches = np.concatenate(
                (np.stack(x, 1), iou[x[0], x[1]][:, None]), 1)
            if x[0].shape[0] > 1:
                # Keep the highest-IoU candidate per detection, then the
                # highest-IoU candidate per ground truth, so the matching
                # is one-to-one.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(
                    matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(
                    matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        has_matches = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(int)

        for i, gc in enumerate(gt_classes):
            j = m0 == i
            # Asserting a unique match.
            if has_matches and j.sum() == 1:
                self.matrix[dt_classes[m1[j]] + self.offset,
                            gc + self.offset] += 1  # matched pair
            else:
                # Ground truth without a detection: false negative.
                self.matrix[0, gc + self.offset] += 1

        # A set gives O(1) membership tests inside the loop below.
        matched_detections = set(m1.tolist()) if has_matches else set()
        for i, dc in enumerate(dt_classes):
            if i not in matched_detections:
                # false positive (predicted something not matched to GT)
                self.matrix[dc + self.offset, 0] += 1

    def _ticklabels(self, names):
        """
        Returns the tick labels for the heatmap: the class names,
        preceded by one "background" entry per offset slot so that the
        labels line up with the (nc + offset) rows/columns, or "auto"
        when ``names`` is empty, has 99 or more entries, or does not
        contain exactly ``nc`` names.
        """
        if 0 < len(names) < 99 and len(names) == self.nc:
            return ["background"] * self.offset + list(names)
        return "auto"

    def plot(self, names=()) -> matplotlib.figure.Figure:
        """
        Plots the Confusion Matrix.

        Each column (ground truth class) is normalized by its sum;
        cells below 0.005 are left blank.

        Parameters
        ----------
        names: tuple
            All the unique labels in the dataset. They are used as tick
            labels only when exactly ``nc`` names (fewer than 99) are
            given; otherwise automatic tick labels are used.

        Returns
        -------
        matplotlib.figure.Figure
            The Confusion Matrix figure.
        """
        size = self.nc + self.offset
        array = self.matrix / \
            (self.matrix.sum(0).reshape(1, size) + 1E-6)  # normalize
        array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

        fig = plt.figure(figsize=(12, 9), tight_layout=True)
        sn.set(font_scale=1.0 if self.nc < 50 else 0.8)  # for label size
        ticklabels = self._ticklabels(names)  # apply names to ticklabels
        ax = sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8},
                        cmap='Blues', fmt='.2f', square=True,
                        xticklabels=ticklabels,
                        yticklabels=ticklabels)
        ax.set_facecolor((1, 1, 1))
        ax.set_xlabel('True')
        ax.set_ylabel('Predicted')
        return fig

    def print(self):
        """
        Prints the Confusion Matrix, one space-separated row per line.
        """
        for i in range(self.nc + self.offset):
            print(' '.join(map(str, self.matrix[i])))
|