edgefirst-validator 4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73):
  1. deepview/modelpack/utils/argmax.py +16 -0
  2. edgefirst/validator/__init__.py +1 -0
  3. edgefirst/validator/__main__.py +375 -0
  4. edgefirst/validator/datasets/__init__.py +118 -0
  5. edgefirst/validator/datasets/cache.py +296 -0
  6. edgefirst/validator/datasets/core.py +250 -0
  7. edgefirst/validator/datasets/darknet.py +446 -0
  8. edgefirst/validator/datasets/database.py +1067 -0
  9. edgefirst/validator/datasets/instance/__init__.py +4 -0
  10. edgefirst/validator/datasets/instance/core.py +222 -0
  11. edgefirst/validator/datasets/instance/detection.py +145 -0
  12. edgefirst/validator/datasets/instance/multitask.py +80 -0
  13. edgefirst/validator/datasets/instance/segmentation.py +120 -0
  14. edgefirst/validator/datasets/utils/fetch.py +682 -0
  15. edgefirst/validator/datasets/utils/readers.py +425 -0
  16. edgefirst/validator/datasets/utils/transformations.py +1695 -0
  17. edgefirst/validator/evaluators/__init__.py +17 -0
  18. edgefirst/validator/evaluators/callbacks/__init__.py +3 -0
  19. edgefirst/validator/evaluators/callbacks/core.py +192 -0
  20. edgefirst/validator/evaluators/callbacks/plots.py +900 -0
  21. edgefirst/validator/evaluators/callbacks/studio.py +234 -0
  22. edgefirst/validator/evaluators/core.py +257 -0
  23. edgefirst/validator/evaluators/detection.py +749 -0
  24. edgefirst/validator/evaluators/multitask.py +270 -0
  25. edgefirst/validator/evaluators/parameters/__init__.py +53 -0
  26. edgefirst/validator/evaluators/parameters/core.py +554 -0
  27. edgefirst/validator/evaluators/parameters/dataset.py +239 -0
  28. edgefirst/validator/evaluators/parameters/model.py +338 -0
  29. edgefirst/validator/evaluators/parameters/validation.py +528 -0
  30. edgefirst/validator/evaluators/segmentation.py +729 -0
  31. edgefirst/validator/evaluators/utils/__init__.py +3 -0
  32. edgefirst/validator/evaluators/utils/classify.py +292 -0
  33. edgefirst/validator/evaluators/utils/match.py +262 -0
  34. edgefirst/validator/evaluators/utils/timer.py +132 -0
  35. edgefirst/validator/metrics/__init__.py +9 -0
  36. edgefirst/validator/metrics/data/__init__.py +7 -0
  37. edgefirst/validator/metrics/data/label.py +668 -0
  38. edgefirst/validator/metrics/data/metrics.py +759 -0
  39. edgefirst/validator/metrics/data/plots.py +476 -0
  40. edgefirst/validator/metrics/data/stats.py +507 -0
  41. edgefirst/validator/metrics/detection.py +595 -0
  42. edgefirst/validator/metrics/segmentation.py +173 -0
  43. edgefirst/validator/metrics/utils/math.py +717 -0
  44. edgefirst/validator/publishers/__init__.py +3 -0
  45. edgefirst/validator/publishers/console.py +147 -0
  46. edgefirst/validator/publishers/studio.py +128 -0
  47. edgefirst/validator/publishers/tensorboard.py +119 -0
  48. edgefirst/validator/publishers/utils/logger.py +111 -0
  49. edgefirst/validator/publishers/utils/table.py +403 -0
  50. edgefirst/validator/runners/__init__.py +8 -0
  51. edgefirst/validator/runners/core.py +727 -0
  52. edgefirst/validator/runners/deepviewrt.py +177 -0
  53. edgefirst/validator/runners/hailo.py +263 -0
  54. edgefirst/validator/runners/keras.py +150 -0
  55. edgefirst/validator/runners/kinara.py +265 -0
  56. edgefirst/validator/runners/offline.py +228 -0
  57. edgefirst/validator/runners/onnx.py +241 -0
  58. edgefirst/validator/runners/processing/decode.py +320 -0
  59. edgefirst/validator/runners/processing/dvapi.py +4192 -0
  60. edgefirst/validator/runners/processing/nms.py +637 -0
  61. edgefirst/validator/runners/processing/outputs.py +507 -0
  62. edgefirst/validator/runners/tensorrt.py +321 -0
  63. edgefirst/validator/runners/tflite.py +221 -0
  64. edgefirst/validator/validate.py +843 -0
  65. edgefirst/validator/visualize/__init__.py +3 -0
  66. edgefirst/validator/visualize/detection.py +623 -0
  67. edgefirst/validator/visualize/segmentation.py +281 -0
  68. edgefirst/validator/visualize/utils/plots.py +635 -0
  69. edgefirst_validator-4.2.1.dist-info/METADATA +111 -0
  70. edgefirst_validator-4.2.1.dist-info/RECORD +73 -0
  71. edgefirst_validator-4.2.1.dist-info/WHEEL +5 -0
  72. edgefirst_validator-4.2.1.dist-info/entry_points.txt +2 -0
  73. edgefirst_validator-4.2.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,900 @@
1
+ from __future__ import annotations
2
+
3
+ import collections.abc
4
+ from typing import TYPE_CHECKING, Tuple
5
+
6
+ import numpy as np
7
+
8
+ from edgefirst.validator.datasets.utils.transformations import convert_to_serializable
9
+ from edgefirst.validator.evaluators.callbacks import Callback
10
+
11
+ if TYPE_CHECKING:
12
+ from edgefirst.validator.publishers import StudioPublisher
13
+ from edgefirst.validator.evaluators import CombinedParameters
14
+ from edgefirst.validator.metrics import Metrics, Plots
15
+
16
+
17
class PlotsCallback(Callback):
    """
    Generates the plots compatible for ApexCharts
    and saves as JSON files to be published to EdgeFirst Studio.

    Parameters
    ----------
    studio_publisher: StudioPublisher
        Publishes metrics, timings, plots, and
        progress to EdgeFirst Studio.
    parameters: CombinedParameters
        These are the model, dataset, and validation parameters
        set from the command line.
    stage: str
        The current stage to update for the progress in Studio.
    """

    # Number of points every curve is resampled to before publishing.
    CURVE_POINTS = 100
    # Validation methods that report Ultralytics-style metrics and curves.
    ULTRALYTICS_METHODS = ("ultralytics", "yolov7")

    def __init__(
        self,
        studio_publisher: StudioPublisher,
        parameters: CombinedParameters,
        stage: str = "validate"
    ):
        super().__init__(studio_publisher=studio_publisher,
                         parameters=parameters,
                         stage=stage)

    # ------------------------------------------------------------------
    # ApexCharts configuration builders
    # ------------------------------------------------------------------

    def create_apexchart_bar(
        self,
        series: list,
        title: str,
        categories: list,
        xlabel: str | None = None,
        ylabel: str | None = None,
        enabled_labels: bool = True
    ) -> dict:
        """
        Create a bar chart config dictionary for ApexCharts.

        Parameters
        ----------
        series : list
            Data series for the bar chart.
        title : str
            Title of the chart.
        categories : list
            X-axis categories.
        xlabel : str, optional
            Label for the x-axis.
        ylabel : str, optional
            Label for the y-axis.
        enabled_labels : bool, default=True
            Whether to show data labels.

        Returns
        -------
        dict
            Configuration dictionary for an ApexCharts bar chart.
        """
        chart = {
            "series": series,
            "chart": {"type": "bar"},
            "title": {"text": title},
            "dataLabels": {
                "enabled": enabled_labels,
                "style": {
                    "colors": ['#000000']
                },
            },
        }

        xaxis = {"categories": categories}
        if xlabel is not None:
            xaxis["title"] = {"text": xlabel}
        chart["xaxis"] = xaxis

        if ylabel is not None:
            chart["yaxis"] = {"title": {"text": ylabel}}
        return chart

    def create_apexchart_pie(
        self,
        series: list,
        title: str,
        categories: list
    ) -> dict:
        """
        Creates a pie chart config dictionary for ApexCharts.

        Parameters
        ----------
        series: list
            A list of values to display as a pie chart. These
            values will automatically be converted into percentages.
        title: str
            Specify the title for the chart.
        categories: list
            Specify the categories for each value in the series.

        Returns
        -------
        dict
            Configuration dictionary for an ApexCharts pie chart.
        """
        return {
            "series": series,
            "chart": {"type": "pie"},
            "title": {"text": title},
            "labels": categories,
        }

    def create_apexchart_grid(
        self,
        data,
        labels: list,
        title: str
    ) -> dict:
        """
        Create a heatmap chart config for ApexCharts from grid data.

        Parameters
        ----------
        data : array-like
            2D grid (e.g. a confusion matrix). Each row becomes one
            heatmap series, named by the label at the same index.
            The axis titles label rows as "Predictions" and columns
            as "Ground Truth".
        labels : list
            Class labels for the rows/series and the axis categories.
        title : str
            Title of the chart.

        Returns
        -------
        dict
            Configuration dictionary for an ApexCharts heatmap chart.
        """
        return {
            # np.asarray lets plain nested lists be passed as well as
            # ndarray rows.
            "series": [{"name": label, "data": np.asarray(row).tolist()}
                       for row, label in zip(data, labels)],
            "chart": {"type": "heatmap"},
            "xaxis": {
                "categories": labels,
                "title": {"text": "Ground Truth"}
            },
            "yaxis": {
                "categories": labels,
                "title": {"text": "Predictions"}
            },
            "title": {"text": title}
        }

    @staticmethod
    def _xy_pairs(x, y) -> list:
        """
        Pair x and y values as ``[[x0, y0], [x1, y1], ...]``.

        Both sequences are rounded to 2 decimals, which is the format
        used for ApexCharts numeric line series.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        return np.concatenate([np.round(x[:, None], 2),
                               np.round(y[:, None], 2)],
                              axis=1).tolist()

    def create_apexchart_lines(
        self,
        x: np.ndarray,
        data: np.ndarray,
        labels: list,
        title: str,
        xlabel: str = "Recall",
        ylabel: str = "Precision"
    ) -> dict:
        """
        Create a line chart config for curve visualization.

        Parameters
        ----------
        x : np.ndarray
            X-axis values (e.g., recall) shared by every line.
        data : np.ndarray
            Y-axis values (e.g., precision), one row per label. Rows
            that are not iterable are skipped.
        labels : list
            List of label names for the lines, indexed like ``data``.
        title : str
            Title of the chart.
        xlabel : str
            The x-axis label.
        ylabel : str
            The y-axis label.

        Returns
        -------
        dict
            Configuration dictionary for an ApexCharts line chart.
            No lines are produced when ``x`` is empty.
        """
        labels = labels.tolist() if isinstance(labels, np.ndarray) else labels
        lines = []
        if len(x):
            for i, row in enumerate(data):
                if isinstance(row, collections.abc.Iterable):
                    lines.append({
                        "name": labels[i],
                        "data": self._xy_pairs(x, row)
                    })

        return {
            "series": lines,
            "chart": {"type": "line"},
            "xaxis": {
                "type": "numeric",
                "title": {"text": xlabel},
                "min": 0.0,
                "max": 1.0
            },
            "yaxis": {
                "type": "numeric",
                "title": {"text": ylabel},
                "min": 0.0,
                "max": 1.0
            },
            "title": {"text": title}
        }

    @staticmethod
    def create_histogram(
        data: np.ndarray,
        num_bins: int | None = None
    ) -> Tuple[list, list]:
        """
        Create histogram bin counts and edges from data.

        Parameters
        ----------
        data : np.ndarray
            1D array of numeric values.
        num_bins : int, optional
            Number of bins to use (default uses Sturges' formula).

        Returns
        -------
        counts: list
            List of bin counts. Empty when ``data`` is empty.
        edges: list
            List of bin edge values (floored to ints). Empty when
            ``data`` is empty.
        """
        data = np.asarray(data)
        if data.size == 0:
            # np.min/np.max raise on empty input and log2(0) is -inf.
            return [], []

        min_value = np.min(data)
        max_value = np.max(data)

        # Use Sturges' formula if number of bins is not provided.
        if num_bins is None:
            num_bins = int(np.ceil(np.log2(len(data)) + 1))

        bin_edges = np.linspace(min_value, max_value, num_bins + 1)
        bin_edges_int = np.floor(bin_edges).astype(int)
        counts, _ = np.histogram(data, bins=bin_edges)

        return counts.tolist(), bin_edges_int.tolist()

    # ------------------------------------------------------------------
    # Shared chart helpers
    # ------------------------------------------------------------------

    def _curve_grid(
        self,
        plots: Plots
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Build the sampling grids shared by every curve chart.

        Parameters
        ----------
        plots : Plots
            Provides ``px``, the x-values the curves were sampled on.

        Returns
        -------
        x : np.ndarray
            Normalized positions (0 to 1) of the original samples.
        x_downsampled : np.ndarray
            ``CURVE_POINTS`` evenly spaced positions between 0 and 1.
        r_downsampled : np.ndarray
            ``px`` interpolated at ``x_downsampled``. It is empty when
            ``px`` is empty, because ``np.interp`` raises on an empty
            sample grid.
        """
        recall = np.asarray(plots.px)
        x = np.linspace(0.0, 1.0, recall.shape[0])
        x_downsampled = np.linspace(0.0, 1.0, self.CURVE_POINTS)
        if len(recall):
            r_downsampled = np.interp(x_downsampled, x, recall)
        else:
            r_downsampled = np.empty(0)
        return x, x_downsampled, r_downsampled

    @staticmethod
    def _resample_curves(x_new: np.ndarray, x: np.ndarray,
                         curves) -> np.ndarray:
        """
        Linearly interpolate each curve in ``curves`` at ``x_new``.

        Returns an array with one resampled row per curve. The result
        is empty when either ``curves`` or the sample grid ``x`` is
        empty.
        """
        if len(x) == 0:
            return np.empty((0, len(x_new)))
        return np.array([np.interp(x_new, x, curve) for curve in curves])

    def _save_precision_recall(
        self,
        plots: Plots,
        task: str,
        prefix: str
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Save the precision-vs-recall chart, including an "all classes"
        mean curve.

        Parameters
        ----------
        plots : Plots
            Provides ``px`` (recall grid), ``py`` (per-class precision)
            and ``curve_labels``.
        task : str
            Task name shown in the chart title, e.g. "Detection".
        prefix : str
            File name prefix, e.g. "detection".

        Returns
        -------
        tuple
            The ``(x, x_downsampled, r_downsampled)`` grids from
            ``_curve_grid``, which are reused by the confidence curves.
        """
        grid = self._curve_grid(plots)
        x, x_downsampled, r_downsampled = grid
        p_downsampled = self._resample_curves(x_downsampled, x, plots.py)

        chart = self.create_apexchart_lines(
            r_downsampled,
            p_downsampled,
            plots.curve_labels,
            title=f"Precision vs. Recall [{task}]"
        )
        if len(p_downsampled):
            chart["series"].append({
                "name": "all classes",
                "data": self._xy_pairs(r_downsampled,
                                       p_downsampled.mean(0))
            })

        self.studio_publisher.save_json(
            filename=f"{prefix}_precision_recall.json",
            plot=chart
        )
        return grid

    def _save_ultralytics_metrics(
        self,
        metrics: Metrics,
        title: str,
        filename: str
    ):
        """Save the Ultralytics-style summary metrics as a bar chart."""
        categories = ["Mean Precision", "Mean Recall", "F1",
                      "mAP@0.50", "mAP@0.75", "mAP@0.50:0.95"]
        series = [{"name": "Score",
                   "data": [round(metrics.precision["mean"], 4),
                            round(metrics.recall["mean"], 4),
                            round(metrics.f1["mean"], 4),
                            round(metrics.precision["map"]["0.50"], 4),
                            round(metrics.precision["map"]["0.75"], 4),
                            round(metrics.precision["map"]["0.50:0.95"], 4)]}]

        chart = self.create_apexchart_bar(
            series=series,
            title=title,
            categories=categories
        )
        self.studio_publisher.save_json(filename=filename, plot=chart)

    def _save_confidence_curves(
        self,
        plots: Plots,
        grid: Tuple[np.ndarray, np.ndarray, np.ndarray],
        task: str,
        prefix: str
    ):
        """
        Save the F1, precision and recall vs. confidence line charts.

        Parameters
        ----------
        plots : Plots
            Provides the per-class ``f1``, ``precision`` and ``recall``
            curves and ``curve_labels``.
        grid : tuple
            ``(x, x_downsampled, r_downsampled)`` from ``_curve_grid``.
        task : str
            Task name shown in the chart titles.
        prefix : str
            File name prefix.
        """
        x, x_downsampled, r_downsampled = grid
        # Each entry is (axis/title label, file name key, curves).
        specs = (
            ("F1", "f1", plots.f1),
            ("Precision", "precision", plots.precision),
            ("Recall", "recall", plots.recall),
        )
        for label, key, curves in specs:
            resampled = self._resample_curves(x_downsampled, x, curves)
            # NOTE(review): the x-values are the resampled ``px`` grid.
            # This matches "Confidence" only if ``px`` is the uniform
            # 0-1 grid the curves were sampled on. Please confirm.
            chart = self.create_apexchart_lines(
                r_downsampled,
                resampled,
                plots.curve_labels,
                title=f"{label} vs. Confidence [{task}]",
                xlabel="Confidence",
                ylabel=label
            )
            self.studio_publisher.save_json(
                filename=f"{prefix}_{key}_curve.json",
                plot=chart
            )

    def _save_class_metrics(self, plots: Plots, title: str, filename: str):
        """
        Save per-class accuracy/precision/recall as a grouped bar chart.

        The chart is only saved when ``plots.class_histogram_data`` has
        more than one class.
        """
        if len(plots.class_histogram_data.keys()) <= 1:
            return

        # Categories must match the three values in each series. They
        # are defined here rather than reusing a caller's list.
        categories = ["accuracy", "precision", "recall"]
        series = [{"data": [round(item.get('accuracy', 0), 4),
                            round(item.get('precision', 0), 4),
                            round(item.get('recall', 0), 4),
                            ], "name": key}
                  for key, item in plots.class_histogram_data.items()]

        chart = self.create_apexchart_bar(
            series=series,
            title=title,
            categories=categories,
            enabled_labels=False
        )
        self.studio_publisher.save_json(filename=filename, plot=chart)

    @staticmethod
    def _concat(arrays) -> np.ndarray:
        """Concatenate per-sample arrays; an empty list gives an empty array."""
        if len(arrays) == 0:
            # np.concatenate raises "need at least one array" here.
            return np.empty(0)
        return np.concatenate(arrays, axis=0)

    def _save_tp_fp_histogram(
        self,
        tp_values: np.ndarray,
        fp_values: np.ndarray,
        bins: np.ndarray,
        title: str,
        xlabel: str,
        filename: str
    ):
        """Save a side-by-side TP (green) / FP (red) histogram bar chart."""
        tp_hist, _ = np.histogram(tp_values, bins=bins)
        fp_hist, _ = np.histogram(fp_values, bins=bins)

        # Convert bin ranges to readable category labels.
        categories = [f"{low:.2f}-{high:.2f}"
                      for low, high in zip(bins[:-1], bins[1:])]

        series = [
            {
                "name": "True Positives",
                "data": tp_hist.tolist(),
                "color": "#00FF00"  # Green
            },
            {
                "name": "False Positives",
                "data": fp_hist.tolist(),
                "color": "#FF0000"  # Red
            }
        ]

        chart = self.create_apexchart_bar(
            series=series,
            title=title,
            categories=categories,
            xlabel=xlabel,
            ylabel="Count",
            enabled_labels=True
        )
        self.studio_publisher.save_json(filename=filename, plot=chart)

    # ------------------------------------------------------------------
    # Task-level chart writers
    # ------------------------------------------------------------------

    def save_detection_metrics(self, metrics: Metrics, plots: Plots):
        """
        Save detection charts and metrics as ApexChart JSON files.

        Parameters
        ----------
        metrics : Metrics
            Detection evaluation metrics.
        plots : Plots
            Curves and confusion matrix data for plotting.
        """
        # Save Confusion Matrix
        chart = self.create_apexchart_grid(plots.confusion_matrix,
                                           plots.confusion_labels,
                                           title="Confusion Matrix [Detection]")
        self.studio_publisher.save_json(
            filename="detection_confusion_matrix.json",
            plot=chart
        )

        # The precision vs. recall curve is saved for every method.
        grid = self._save_precision_recall(plots, task="Detection",
                                           prefix="detection")

        if self.parameters.validation.method in self.ULTRALYTICS_METHODS:
            self._save_ultralytics_metrics(metrics,
                                           title="Detection Metrics",
                                           filename="detection_metrics.json")
            self._save_confidence_curves(plots, grid, task="Detection",
                                         prefix="detection")
        else:
            self._save_edgefirst_detection_metrics(metrics, plots)

    def _save_edgefirst_detection_metrics(self, metrics: Metrics,
                                          plots: Plots):
        """
        Save the detection charts produced for the EdgeFirst method.

        These are the overall metrics, the mean average metrics per IoU
        threshold, the raw prediction classifications, the per-class
        metrics, and the TP/FP score and IoU histograms.
        """
        # Save EdgeFirst Overall Metrics
        chart = self.create_apexchart_bar(
            series=[{"data": [
                round(metrics.accuracy["overall"], 4),
                round(metrics.precision["overall"], 4),
                round(metrics.recall["overall"], 4)
            ]}],
            title="Overall Metrics [Detection]",
            categories=["accuracy", "precision", "recall"]
        )
        self.studio_publisher.save_json(
            filename="detection_overall_metrics.json",
            plot=chart
        )

        # Save EdgeFirst Mean Average Metrics (one series per IoU threshold)
        series = [{"data": [
            round(metrics.accuracy["macc"].get(key, 0), 4),
            round(metrics.precision["map"].get(key, 0), 4),
            round(metrics.recall["mar"].get(key, 0), 4),
        ], "name": f"IoU threshold @ {key}"}
            for key in ("0.50", "0.75", "0.50:0.95")]

        chart = self.create_apexchart_bar(
            series=series,
            title="Mean Average Metrics [Detection]",
            categories=["mACC", "mAP", "mAR"]
        )
        self.studio_publisher.save_json(
            filename="detection_metrics.json",
            plot=chart
        )

        # NOTE(review): this code was recovered from a published diff that
        # lost its indentation. Placing the charts below in the
        # EdgeFirst-only branch is an inference; confirm upstream.

        # Save Raw Classifications
        chart = self.create_apexchart_bar(
            series=[{"data": [metrics.tp,
                              metrics.fn,
                              metrics.cfp,
                              metrics.lfp]}],
            title="Prediction Classifications",
            categories=["True Positives", "False Negatives",
                        "Classification False Positives",
                        "Localization False Positives"]
        )
        self.studio_publisher.save_json(
            filename="prediction_classifications.json",
            plot=chart
        )

        # Save Class Histogram (only when there are multiple classes).
        self._save_class_metrics(plots,
                                 title="Class Metrics [Detection]",
                                 filename="detection_class_metrics.json")

        # Histogram edges 0.0 to 1.0 with step 0.05.
        bins = np.arange(0, 1.05, 0.05)

        # Save TP and FP scores Histogram
        self._save_tp_fp_histogram(
            self._concat(plots.tp_scores),
            self._concat(plots.fp_scores),
            bins,
            title="Histogram of True Positive vs False Positive Scores",
            xlabel="Score",
            filename="tp_fp_scores.json"
        )

        # Save TP and FP IoU Histogram
        self._save_tp_fp_histogram(
            self._concat(plots.tp_ious),
            self._concat(plots.fp_ious),
            bins,
            title="Histogram of True Positive vs False Positive IoUs",
            xlabel="IoU",
            filename="tp_fp_ious.json"
        )

    def save_segmentation_metrics(self, metrics: Metrics, plots: Plots):
        """
        Save segmentation metrics and class-wise histogram charts.

        Instance segmentation validated with an Ultralytics-style method
        gets curve charts and summary metrics. Otherwise, semantic
        segmentation metrics and per-class metrics are saved.

        Parameters
        ----------
        metrics : Metrics
            Segmentation evaluation metrics.
        plots : Plots
            Class histogram and plot data.
        """
        if (not self.parameters.model.common.semantic and
                self.parameters.validation.method in self.ULTRALYTICS_METHODS):
            grid = self._save_precision_recall(plots, task="Segmentation",
                                               prefix="segmentation")
            self._save_ultralytics_metrics(
                metrics,
                title="Instance Segmentation Metrics",
                filename="instance_segmentation_metrics.json"
            )
            self._save_confidence_curves(plots, grid, task="Segmentation",
                                         prefix="segmentation")
            return

        # Save Segmentation Metrics
        series = [{"data": [round(metrics.accuracy["overall"], 4),
                            round(metrics.f1["overall"], 4),
                            round(metrics.iou["mean"], 4),
                            round(metrics.precision["mean"], 4),
                            round(metrics.recall["mean"], 4)]}]
        categories = ["Accuracy", "F1", "Mean IoU",
                      "Mean Precision", "Mean Recall"]

        chart = self.create_apexchart_bar(
            series=series,
            title='Semantic Segmentation Metrics',
            categories=categories
        )
        self.studio_publisher.save_json(
            filename="semantic_segmentation_metrics.json",
            plot=chart
        )

        # Save Class Histogram (only when there are multiple classes).
        # NOTE(review): its branch placement was inferred from a diff that
        # lost its indentation; confirm upstream.
        self._save_class_metrics(plots,
                                 title="Segmentation Class Metrics",
                                 filename="segmentation_class_metrics.json")

    def save_timings(self, timings: dict):
        """
        Save model timing metrics for input, inference, and output stages.

        Two charts are saved: a bar chart of min/max/avg per stage, and a
        pie chart of the average time per stage.

        Parameters
        ----------
        timings : dict
            Timing stats in milliseconds. Keys are of the form
            ``"{min,max,avg}_{input,inference,output}_time"``.
        """
        categories = ["Input Time", "Inference Time", "Output Time"]
        keys = ["input_time", "inference_time", "output_time"]

        # Create a bar chart of the timings.
        series = []
        for name in ["Min", "Max", "Avg"]:
            data = [round(float(timings.get(f"{name.lower()}_{key}")), 2)
                    for key in keys]
            series.append({"data": data, "name": name})

        chart = self.create_apexchart_bar(
            series=series,
            title='Timings (ms)',
            categories=categories
        )
        self.studio_publisher.save_json(
            filename="timings.json",
            plot=chart
        )

        # Create a pie chart of the average timings. Round to 2 decimals,
        # as in the bar chart. Rounding to whole milliseconds would show
        # sub-millisecond stages as 0.
        series = [round(float(timings.get(f"avg_{key}")), 2) for key in keys]

        chart = self.create_apexchart_pie(
            series=series,
            title="Distribution of the Average Timings",
            categories=categories
        )
        self.studio_publisher.save_json(
            filename="average_timings.json",
            plot=chart
        )

    # ------------------------------------------------------------------
    # Studio reporting / callback hooks
    # ------------------------------------------------------------------

    def post_metrics(self, logs=None):
        """
        Post the final metrics to EdgeFirst Studio.

        Parameters
        ----------
        logs: dict, optional
            This is a container of the final metrics. It is keyed by
            "multitask", "detection" or "segmentation". When none of
            these keys is present (or ``logs`` is None), only the
            parameters are posted.
        """
        logs = {} if logs is None else logs
        method = self.parameters.validation.method

        if "multitask" in logs:
            metrics = logs.get("multitask").to_dict(method=method)
        elif "detection" in logs:
            metrics = logs.get("detection").to_dict(with_boxes=True,
                                                    method=method)
        elif "segmentation" in logs:
            metrics = logs.get("segmentation").to_dict(with_boxes=False,
                                                       method=method)
        else:
            metrics = {}

        metrics["parameters"] = self.parameters.to_dict()
        self.studio_publisher.post_metrics(convert_to_serializable(metrics))

    @staticmethod
    def _progress(step: int, logs) -> int:
        """
        Return the integer completion percentage for ``step``.

        The total number of steps is read from ``logs["total"]``. The
        result is 0 when ``logs`` or the total is missing, or when the
        total is not positive.
        """
        if logs is None:
            return 0
        total = logs.get("total")
        if total is None or total <= 0:
            return 0
        return int((step / total) * 100)

    def on_test_batch_end(self, step: int, logs=None):
        """
        Update progress status at the end of a validation batch.

        Studio is only updated when the percentage is a multiple of 5.

        Parameters
        ----------
        step : int
            Current validation batch index.
        logs : dict, optional
            Contains total number of steps for percentage calculation.
        """
        percentage = self._progress(step, logs)
        if percentage % 5 == 0:
            self.studio_publisher.update_stage(
                stage=self.stage,
                status="running",
                message=self.message,
                percentage=percentage
            )

    def on_test_error(self, step: int, error, logs=None):
        """
        Report an error during validation and update the progress.

        Parameters
        ----------
        step : int
            Batch step at which the error occurred.
        error : Exception
            The exception raised during validation.
        logs : dict, optional
            Contains total number of steps for percentage calculation.
        """
        self.studio_publisher.update_stage(
            stage=self.stage,
            status="error",
            message=str(error),
            percentage=self._progress(step, logs)
        )

    def on_test_end(self, logs=None):
        """
        Report the final stages of validation
        and post the metrics.

        Parameters
        ----------
        logs : dict, optional
            Contains the metrics under "multitask", "detection" or
            "segmentation", together with "plots" and "timings".
        """
        logs = {} if logs is None else logs
        plots = logs.get("plots")
        timings = logs.get("timings")

        if "multitask" in logs:
            multitask = logs.get("multitask")
            self.save_detection_metrics(
                multitask.detection_metrics, plots.detection_plots)
            self.save_segmentation_metrics(
                multitask.segmentation_metrics, plots.segmentation_plots)
        elif "detection" in logs:
            self.save_detection_metrics(logs.get("detection"), plots)
        elif "segmentation" in logs:
            self.save_segmentation_metrics(logs.get("segmentation"), plots)

        # save_timings dereferences the dict, so skip it when missing.
        if timings is not None:
            self.save_timings(timings)
        self.post_metrics(logs)

        self.studio_publisher.update_stage(
            stage=self.stage,
            status="complete",
            message=self.message,
            percentage=100
        )
        self.studio_publisher.post_plots()