edgefirst-validator 4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. deepview/modelpack/utils/argmax.py +16 -0
  2. edgefirst/validator/__init__.py +1 -0
  3. edgefirst/validator/__main__.py +375 -0
  4. edgefirst/validator/datasets/__init__.py +118 -0
  5. edgefirst/validator/datasets/cache.py +296 -0
  6. edgefirst/validator/datasets/core.py +250 -0
  7. edgefirst/validator/datasets/darknet.py +446 -0
  8. edgefirst/validator/datasets/database.py +1067 -0
  9. edgefirst/validator/datasets/instance/__init__.py +4 -0
  10. edgefirst/validator/datasets/instance/core.py +222 -0
  11. edgefirst/validator/datasets/instance/detection.py +145 -0
  12. edgefirst/validator/datasets/instance/multitask.py +80 -0
  13. edgefirst/validator/datasets/instance/segmentation.py +120 -0
  14. edgefirst/validator/datasets/utils/fetch.py +682 -0
  15. edgefirst/validator/datasets/utils/readers.py +425 -0
  16. edgefirst/validator/datasets/utils/transformations.py +1695 -0
  17. edgefirst/validator/evaluators/__init__.py +17 -0
  18. edgefirst/validator/evaluators/callbacks/__init__.py +3 -0
  19. edgefirst/validator/evaluators/callbacks/core.py +192 -0
  20. edgefirst/validator/evaluators/callbacks/plots.py +900 -0
  21. edgefirst/validator/evaluators/callbacks/studio.py +234 -0
  22. edgefirst/validator/evaluators/core.py +257 -0
  23. edgefirst/validator/evaluators/detection.py +749 -0
  24. edgefirst/validator/evaluators/multitask.py +270 -0
  25. edgefirst/validator/evaluators/parameters/__init__.py +53 -0
  26. edgefirst/validator/evaluators/parameters/core.py +554 -0
  27. edgefirst/validator/evaluators/parameters/dataset.py +239 -0
  28. edgefirst/validator/evaluators/parameters/model.py +338 -0
  29. edgefirst/validator/evaluators/parameters/validation.py +528 -0
  30. edgefirst/validator/evaluators/segmentation.py +729 -0
  31. edgefirst/validator/evaluators/utils/__init__.py +3 -0
  32. edgefirst/validator/evaluators/utils/classify.py +292 -0
  33. edgefirst/validator/evaluators/utils/match.py +262 -0
  34. edgefirst/validator/evaluators/utils/timer.py +132 -0
  35. edgefirst/validator/metrics/__init__.py +9 -0
  36. edgefirst/validator/metrics/data/__init__.py +7 -0
  37. edgefirst/validator/metrics/data/label.py +668 -0
  38. edgefirst/validator/metrics/data/metrics.py +759 -0
  39. edgefirst/validator/metrics/data/plots.py +476 -0
  40. edgefirst/validator/metrics/data/stats.py +507 -0
  41. edgefirst/validator/metrics/detection.py +595 -0
  42. edgefirst/validator/metrics/segmentation.py +173 -0
  43. edgefirst/validator/metrics/utils/math.py +717 -0
  44. edgefirst/validator/publishers/__init__.py +3 -0
  45. edgefirst/validator/publishers/console.py +147 -0
  46. edgefirst/validator/publishers/studio.py +128 -0
  47. edgefirst/validator/publishers/tensorboard.py +119 -0
  48. edgefirst/validator/publishers/utils/logger.py +111 -0
  49. edgefirst/validator/publishers/utils/table.py +403 -0
  50. edgefirst/validator/runners/__init__.py +8 -0
  51. edgefirst/validator/runners/core.py +727 -0
  52. edgefirst/validator/runners/deepviewrt.py +177 -0
  53. edgefirst/validator/runners/hailo.py +263 -0
  54. edgefirst/validator/runners/keras.py +150 -0
  55. edgefirst/validator/runners/kinara.py +265 -0
  56. edgefirst/validator/runners/offline.py +228 -0
  57. edgefirst/validator/runners/onnx.py +241 -0
  58. edgefirst/validator/runners/processing/decode.py +320 -0
  59. edgefirst/validator/runners/processing/dvapi.py +4192 -0
  60. edgefirst/validator/runners/processing/nms.py +637 -0
  61. edgefirst/validator/runners/processing/outputs.py +507 -0
  62. edgefirst/validator/runners/tensorrt.py +321 -0
  63. edgefirst/validator/runners/tflite.py +221 -0
  64. edgefirst/validator/validate.py +843 -0
  65. edgefirst/validator/visualize/__init__.py +3 -0
  66. edgefirst/validator/visualize/detection.py +623 -0
  67. edgefirst/validator/visualize/segmentation.py +281 -0
  68. edgefirst/validator/visualize/utils/plots.py +635 -0
  69. edgefirst_validator-4.2.1.dist-info/METADATA +111 -0
  70. edgefirst_validator-4.2.1.dist-info/RECORD +73 -0
  71. edgefirst_validator-4.2.1.dist-info/WHEEL +5 -0
  72. edgefirst_validator-4.2.1.dist-info/entry_points.txt +2 -0
  73. edgefirst_validator-4.2.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,729 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from copy import deepcopy
5
+ from typing import TYPE_CHECKING, List, Tuple
6
+
7
+ import numpy as np
8
+ import matplotlib.figure
9
+
10
+ from edgefirst.validator.metrics.utils.math import mask_iou
11
+ from edgefirst.validator.publishers.utils.logger import logger
12
+ from edgefirst.validator.visualize.utils.plots import (figure2numpy,
13
+ plot_pr_curve,
14
+ plot_mc_curve,
15
+ close_figures,
16
+ plot_classification_segmentation)
17
+ from edgefirst.validator.evaluators.utils.classify import classify_mask
18
+ from edgefirst.validator.datasets.utils.transformations import (labels2string,
19
+ create_mask_class,
20
+ create_mask_background)
21
+ from edgefirst.validator.evaluators import Evaluator, YOLOValidator
22
+ from edgefirst.validator.visualize import SegmentationDrawer, DetectionDrawer
23
+ from edgefirst.validator.metrics import (SegmentationStats, SegmentationMetrics,
24
+ MultitaskMetrics, MultitaskPlots,
25
+ YOLOStats, DetectionMetrics)
26
+ from edgefirst.validator.datasets import SegmentationInstance, MultitaskInstance
27
+
28
+ if TYPE_CHECKING:
29
+ from edgefirst.validator.evaluators import CombinedParameters
30
+ from edgefirst.validator.datasets import Dataset
31
+ from edgefirst.validator.runners import Runner
32
+
33
+
34
+ class YOLOSegmentationValidator(YOLOValidator):
35
+ """
36
+ Reproduce the validation methods implemented in Ultralytics for
37
+ segmentation.
38
+
39
+ Parameters
40
+ ----------
41
+ parameters: CombinedParameters
42
+ This is a container for the model, dataset, and validation parameters
43
+ set from the command line.
44
+ runner: Runner
45
+ A type of model runner object responsible for running the model
46
+ for inference provided with an input image to produce bounding boxes.
47
+ dataset: Dataset
48
+ A type of dataset object responsible for reading different types
49
+ of datasets such as Darknet, TFRecords, or EdgeFirst Datasets.
50
+ """
51
+
52
def __init__(
    self,
    parameters: CombinedParameters,
    runner: Runner = None,
    dataset: Dataset = None
):
    """
    Initialize the parent validator, then attach the segmentation
    statistics, metrics, drawers, and the multitask result containers.

    Parameters
    ----------
    parameters: CombinedParameters
        This is a container for the model, dataset, and validation
        parameters set from the command line.
    runner: Runner
        The model runner object forwarded to the parent validator.
    dataset: Dataset
        The dataset object forwarded to the parent validator.
    """
    super(YOLOSegmentationValidator, self).__init__(
        parameters=parameters, runner=runner, dataset=dataset)

    # Segmentation in Ultralytics uses base detection metrics, so the
    # mask statistics are stored in the same stats/metrics types.
    seg_stats = YOLOStats()
    self.segmentation_stats = seg_stats
    self.segmentation_metrics = DetectionMetrics(
        parameters=self.parameters.validation,
        detection_stats=seg_stats,
        model_name=self.model_name,
        dataset_name=self.dataset_name,
        save_path=self.save_path,
        labels=self.parameters.dataset.labels
    )

    self.detection_drawer = DetectionDrawer()
    self.segmentation_drawer = SegmentationDrawer()

    # Store both detection and segmentation metric results.
    box_results = self.metrics
    mask_results = self.segmentation_metrics
    self.multi_metrics = MultitaskMetrics(
        detection_metrics=box_results.metrics,
        segmentation_metrics=mask_results.metrics
    )
    self.multi_plots = MultitaskPlots(
        detection_plots=box_results.plots,
        segmentation_plots=mask_results.plots
    )
84
+
85
def instance_collector(self):
    """
    Collects the instances from the ground truth and runs
    model inference on a single image to collect the instance for
    the model predictions.

    Yields
    ------
    dict
        This yields one image instance from the ground truth
        and model predictions with keys "gt_instance" and "dt_instance".
        "dt_instance" is None when the runner returns no detections
        for the image.
    """
    gt_instance: MultitaskInstance
    for gt_instance in self.dataset:
        detections = self.runner.run_single_instance(
            image=gt_instance.image,
        )
        self.filter_gt(gt_instance)

        if detections is None:
            yield {
                "gt_instance": gt_instance,
                "dt_instance": None
            }
            # Nothing to unpack for this image. Without this the
            # generator would resume below and fail on unpacking None.
            continue

        boxes, labels, scores, mask = detections
        dt_instance = MultitaskInstance(gt_instance.image_path)
        dt_instance.height = gt_instance.height
        dt_instance.width = gt_instance.width
        dt_instance.boxes = boxes
        dt_instance.labels = labels
        dt_instance.scores = scores
        dt_instance.mask = mask
        self.filter_dt(dt_instance)

        yield {
            "gt_instance": gt_instance,
            "dt_instance": dt_instance,
        }
125
+
126
def process_seg_batch_v5(
    self,
    dt_instance: MultitaskInstance,
    gt_instance: MultitaskInstance,
) -> np.ndarray:
    """
    Match predicted masks against ground truth masks by mask IoU.

    Parameters
    ----------
    dt_instance : MultitaskInstance
        A prediction instance container with predicted
        labels and masks.
    gt_instance : MultitaskInstance
        A ground truth instance container with ground truth
        labels and masks.

    Returns
    -------
    np.ndarray
        Boolean array indicating correct matches per IoU threshold.
        When either side has no labels, this is an all-False array
        with one row per prediction.
    """
    def _as_rows(masks: np.ndarray) -> np.ndarray:
        # A 2D mask is treated as a single flattened row; otherwise
        # the first axis is kept and the rest is flattened.
        rows = 1 if masks.ndim == 2 else masks.shape[0]
        return masks.reshape(rows, -1).astype(np.float32)

    threshold_count = len(self.segmentation_stats.ious)
    true_classes = gt_instance.labels
    predicted_classes = dt_instance.labels
    true_masks = gt_instance.mask
    predicted_masks = dt_instance.mask

    if len(true_classes) == 0 or len(predicted_classes) == 0:
        return np.zeros((len(predicted_classes), threshold_count),
                        dtype=bool)

    overlap = mask_iou(_as_rows(true_masks), _as_rows(predicted_masks))
    return self.match_predictions(pred_classes=predicted_classes,
                                  true_classes=true_classes,
                                  iou=overlap)
173
+
174
def process_seg_batch_v7(
    self,
    dt_instance: MultitaskInstance,
    gt_instance: MultitaskInstance
) -> np.ndarray:
    """
    Placeholder for YOLOv7 segmentation evaluation support.

    Logs a warning and delegates to ``process_seg_batch_v5``.

    Parameters
    ----------
    dt_instance : MultitaskInstance
        A prediction instance container with predicted
        bounding boxes and masks.
    gt_instance : MultitaskInstance
        A ground truth instance container with ground truth
        bounding boxes and masks.

    Returns
    -------
    np.ndarray
        Boolean array indicating correct matches per IoU threshold.
    """
    warning = ("Validation with YOLOv7 is not yet supported for segmentation. "
               "Falling back to use Ultralytics.")
    logger(warning, code="WARNING")
    fallback = self.process_seg_batch_v5
    return fallback(dt_instance=dt_instance, gt_instance=gt_instance)
202
+
203
def evaluate(self, instance: dict):
    """
    Run the parent evaluation, then score the masks of the same
    instance and record the segmentation statistics.

    Parameters
    ----------
    instance: dict
        This contains the ground truth and model prediction instances
        with keys "gt_instance", "dt_instance".
    """
    super().evaluate(instance=instance)

    gt_instance: MultitaskInstance = instance.get("gt_instance")
    dt_instance: MultitaskInstance = instance.get("dt_instance")

    method = self.parameters.validation.method
    if method == "ultralytics":
        correct = self.process_seg_batch_v5(dt_instance=dt_instance,
                                            gt_instance=gt_instance)
    elif method == "yolov7":
        correct = self.process_seg_batch_v7(dt_instance=dt_instance,
                                            gt_instance=gt_instance)
    else:
        # Any other method records an empty match table.
        threshold_count = len(self.segmentation_stats.ious)
        correct = np.zeros((0, threshold_count), dtype=bool)

    collected = self.segmentation_stats.stats
    collected["tp"].append(correct)
    collected["conf"].append(dt_instance.scores)
    collected["pred_cls"].append(dt_instance.labels)
    collected["target_cls"].append(gt_instance.labels)
232
+
233
def visualize(
    self,
    gt_instance: MultitaskInstance,
    dt_instance: MultitaskInstance,
    epoch: int = 0
):
    """
    Visualizes predicted and ground truth bounding
    boxes and masks on an image.

    The prediction instance is modified in place: boxes, labels and
    scores are reduced to the entries with a score of at least 0.25,
    and both instances have their labels replaced by integer indices
    into the dataset labels.

    Parameters
    ----------
    gt_instance: MultitaskInstance
        This is the ground truth instance which contains bounding
        boxes, masks, and labels to draw.
    dt_instance: MultitaskInstance
        This is the model prediction instance which contains the
        bounding boxes, masks, labels, and confidence scores to draw.
    epoch: int
        This is the training epoch number. This
        parameter is internal for ModelPack usage.
        Standalone validation does not use this parameter.
    """
    dataset_labels = self.parameters.dataset.labels
    # Mask indices reserve 0 for background, so shift the label indices
    # when the dataset labels do not already contain it.
    label_offset = 0 if "background" in dataset_labels else 1

    # Separate results for the ground truth and detection.
    dt_instance.visual_image = gt_instance.visual_image.copy()

    # Draw the ground truth boxes on the image.
    image = self.detection_drawer.draw_2d_gt_boxes(
        image=gt_instance.visual_image,
        gt_instance=gt_instance,
        method="ultralytics",
        labels=dataset_labels,
    )
    gt_instance.visual_image = np.asarray(image)

    # Filter to visualize only confident scores.
    filt = dt_instance.scores >= 0.25
    dt_instance.boxes = dt_instance.boxes[filt]
    dt_instance.labels = dt_instance.labels[filt]
    dt_instance.scores = dt_instance.scores[filt]

    # Keep per-instance masks aligned with the filtered boxes. The
    # filter only applies to a stack with one mask per detection; the
    # previous guard compared the mask rank against the number of
    # detections, which skipped the filter for more than three
    # detections and left the masks out of step with the labels.
    mask = dt_instance.mask
    if (not self.parameters.model.common.semantic and
            mask is not None and
            mask.ndim == 3 and
            mask.shape[0] == len(filt)):
        dt_instance.mask = mask[filt]

    # Draw the prediction boxes on the image.
    image = self.detection_drawer.draw_2d_dt_boxes(
        image=dt_instance.visual_image,
        dt_instance=dt_instance,
        method="ultralytics",
        labels=dataset_labels,
    )
    dt_instance.visual_image = np.asarray(image)

    # Convert the labels to integer indices for the mask drawer.
    gt_instance.labels = np.array([
        dataset_labels.index(label)
        for label in gt_instance.labels], dtype=np.int32) + label_offset
    dt_instance.labels = np.array([
        dataset_labels.index(label)
        for label in dt_instance.labels], dtype=np.int32) + label_offset

    # Draw the masks prediction and ground truth on the image.
    image = self.segmentation_drawer.mask2maskimage(
        gt_instance=gt_instance,
        dt_instance=dt_instance,
        semantic=self.parameters.model.common.semantic
    )

    if self.parameters.validation.visualize:
        image.save(os.path.join(self.parameters.validation.visualize,
                                os.path.basename(gt_instance.image_path)))
    elif self.tensorboard_writer:
        self.tensorboard_writer(
            np.asarray(image), gt_instance.image_path, step=epoch)
312
+
313
def end(
    self,
    epoch: int = 0,
    reset: bool = True
) -> Tuple[MultitaskMetrics, MultitaskPlots]:
    """
    Finalize detection and segmentation metrics, then save or publish
    the validation plots and metrics.

    Parameters
    ----------
    epoch: int
        This is the training epoch number. This
        parameter is internal for ModelPack usage.
        Standalone validation does not use this parameter.
    reset: bool
        This is an optional parameter that controls the reset state.
        By default, it will reset at the end of validation to erase
        the data in the containers.

    Returns
    -------
    metrics: MultitaskMetrics
        This is a container for the detection and segmentation metrics.
    plots: MultitaskPlots
        This is a container for the validation data for plotting.
    """
    # NOTE(review): the nesting of the plotting section below was
    # reconstructed from data flow; confirm against the packaged file.
    # The parent finalizes detection without publishing; publishing
    # of the combined results happens further down.
    detection_metrics, detection_plots = super().end(
        epoch=epoch, reset=reset, publish=False)

    combined_metrics = self.multi_metrics
    combined_plots = self.multi_plots
    combined_metrics.detection_metrics = detection_metrics
    combined_plots.detection_plots = detection_plots
    combined_metrics.timings = detection_metrics.timings

    seg = self.segmentation_metrics
    seg.run_metrics()
    # Snapshot the results so a later reset does not clear them.
    combined_metrics.segmentation_metrics = deepcopy(seg.metrics)
    combined_plots.segmentation_plots = deepcopy(seg.plots)

    validation = self.parameters.validation

    # Plot Operations
    if validation.plots:
        seg.plots.curve_labels = labels2string(
            seg.plots.curve_labels,
            self.parameters.dataset.labels
        )
        seg.plots.confusion_matrix = self.confusion_matrix.matrix

        if validation.visualize or self.tensorboard_writer:
            figures = self.get_seg_plots()

            if validation.visualize:
                self.save_plots(figures)
            elif self.tensorboard_writer:
                self.publish_plots(figures, epoch)
            close_figures(figures)

    if self.tensorboard_writer:
        self.tensorboard_writer.publish_metrics(
            metrics=combined_metrics,
            parameters=self.parameters,
            step=epoch,
        )
    else:
        table = self.console_writer(metrics=combined_metrics,
                                    parameters=self.parameters)

        if validation.visualize:
            self.console_writer.save_metrics(table)

    if reset:
        seg.reset()
    return self.multi_metrics, self.multi_plots
387
+
388
def get_seg_plots(self) -> List[matplotlib.figure.Figure]:
    """
    Build the segmentation validation charts as Matplotlib figures.

    Returns
    -------
    List[matplotlib.figure.Figure]
        The figures in this order: confusion matrix,
        precision-recall curve, F1 curve, precision curve,
        and recall curve.
    """
    seg_plots = self.segmentation_metrics.plots
    class_names = self.parameters.dataset.labels
    model_name = self.segmentation_metrics.metrics.model

    def _metric_curve(values, ylabel):
        # All metric curves share the same x-axis values, class
        # names, and model name.
        return plot_mc_curve(
            px=seg_plots.px,
            py=values,
            names=class_names,
            model=model_name,
            ylabel=ylabel
        )

    figures = [
        self.confusion_matrix.plot(names=seg_plots.confusion_labels)
    ]
    figures.append(plot_pr_curve(
        precision=seg_plots.py,
        recall=seg_plots.px,
        ap=seg_plots.average_precision,
        names=class_names,
        model=model_name,
        iou_threshold=self.parameters.validation.iou_threshold
    ))
    figures.append(_metric_curve(seg_plots.f1, 'F1'))
    figures.append(_metric_curve(seg_plots.precision, 'Precision'))
    figures.append(_metric_curve(seg_plots.recall, 'Recall'))
    return figures
435
+
436
+
437
+ class SegmentationValidator(Evaluator):
438
+
439
+ """
440
+ Define the validation methods for EdgeFirst. Reproduces EdgeFirst
441
+ metrics for segmentation::
442
+
443
+ 1. Grab the ground truth and the model prediction instances per image.
444
+ 2. Create masks for both ground truth and model prediction.
445
+ 3. Classify the mask pixels as either true predictions or false predictions.
446
+ 4. Overlay the ground truth and predictions masks on the image.
447
+ 5. Calculate the metrics.
448
+
449
+ Parameters
450
+ ----------
451
+ parameters: CombinedParameters
452
+ This is a container for the model, dataset, and validation parameters
453
+ set from the command line.
454
+ runner: Runner
455
+ A type of model runner object responsible for running the model
456
+ for inference provided with an input image to produce bounding boxes.
457
+ dataset: Dataset
458
+ A type of dataset object responsible for reading different types
459
+ of datasets such as Darknet, TFRecords, or EdgeFirst Datasets.
460
+ """
461
+
462
def __init__(
    self,
    parameters: CombinedParameters,
    runner: Runner = None,
    dataset: Dataset = None,
):
    """
    Initialize the parent evaluator, then create the segmentation
    statistics, metrics, and drawer.

    Parameters
    ----------
    parameters: CombinedParameters
        This is a container for the model, dataset, and validation
        parameters set from the command line.
    runner: Runner
        The model runner object forwarded to the parent evaluator.
    dataset: Dataset
        The dataset object forwarded to the parent evaluator.
    """
    super(SegmentationValidator, self).__init__(
        parameters=parameters, runner=runner, dataset=dataset)

    stats = SegmentationStats()
    self.segmentation_stats = stats
    # The metrics object is handed the same stats container that
    # evaluate() fills.
    self.metrics = SegmentationMetrics(
        parameters=self.parameters.validation,
        segmentation_stats=stats,
        model_name=self.model_name,
        dataset_name=self.dataset_name,
        save_path=self.save_path
    )
    self.drawer = SegmentationDrawer()
480
+
481
def instance_collector(self):
    """
    Collects the instances from the ground truth and runs
    model inference on a single image to collect the instance for
    the model predictions.

    Yields
    ------
    dict
        This yields one image instance from the ground truth
        and model predictions with keys "gt_instance" and "dt_instance".
        "dt_instance" is None when the runner returns no mask
        for the image.

    Raises
    ------
    ValueError
        Raised if the model labels and the
        dataset labels are not matching.
    """
    gt_instance: SegmentationInstance
    for gt_instance in self.dataset:
        mask = self.runner.run_single_instance(
            image=gt_instance.image
        )

        if mask is None:
            yield {
                'gt_instance': gt_instance,
                'dt_instance': None
            }
            # No prediction for this image. Without this the generator
            # would resume below and try to calibrate a None mask.
            continue

        dt_instance = SegmentationInstance(gt_instance.image_path)
        dt_instance.height = gt_instance.height
        dt_instance.width = gt_instance.width
        dt_instance.mask = self.calibrate_mask(
            mask, dt_labels=dt_instance.labels)

        yield {
            'gt_instance': gt_instance,
            'dt_instance': dt_instance
        }
522
+
523
def calibrate_mask(self, mask: np.ndarray,
                   dt_labels: np.ndarray) -> np.ndarray:
    """
    Re-index a prediction mask so that its class indices follow the
    dataset label order, making it comparable with the ground truth
    mask.

    Parameters
    ----------
    mask: np.ndarray
        The prediction mask output from the model. Singleton axes are
        squeezed out; a 3D result is treated as a stack of instance
        masks and collapsed into one semantic mask.
    dt_labels: np.ndarray
        The class label of each instance mask, used only when the
        mask is 3D.

    Returns
    -------
    np.ndarray
        The calibrated prediction mask.

    Raises
    ------
    ValueError
        Raised if a model label is not among the dataset labels.
    """
    mask = np.squeeze(mask)

    # For segmentation, the background class should exist to properly
    # map mask indices, so it is prepended wherever it is missing.
    model_names = self.parameters.model.labels
    if "background" in model_names:
        shift = 0
    else:
        model_names = ["background"] + model_names
        shift = 1

    dataset_names = self.parameters.dataset.labels
    if "background" not in dataset_names:
        dataset_names = ["background"] + dataset_names

    # If the model is instance segmentation, convert to semantic.
    # Later instance masks overwrite earlier ones where they overlap.
    if mask.ndim == 3:
        semantic = np.zeros(mask.shape[1:], dtype=np.int32)
        shifted_labels = dt_labels + shift
        for instance_mask, class_index in zip(mask, shifted_labels):
            semantic[instance_mask > 0] = class_index
        mask = semantic

    # If the label orders do not match between prediction and dataset,
    # map the prediction indices to the dataset indices.
    if model_names != dataset_names:
        # -1 means unmapped/missing
        lookup = np.full(len(model_names), -1, dtype=int)
        for position, name in enumerate(model_names):
            if name not in dataset_names:
                raise ValueError(
                    f"Label '{name}' not found in dataset labels.")
            lookup[position] = dataset_names.index(name)
        mask = lookup[mask]
    return mask
581
+
582
def evaluate(self, instance: dict):
    """
    Score one ground truth / prediction mask pair and accumulate
    the overall and per-class segmentation statistics.

    Parameters
    ----------
    instance: dict
        This contains the ground truth and model predictions instances
        with keys "gt_instance" and "dt_instance".
    """
    gt_instance: SegmentationInstance = instance.get("gt_instance")
    dt_instance: SegmentationInstance = instance.get("dt_instance")
    gt_mask = gt_instance.mask
    dt_mask = dt_instance.mask

    # Every class index present in either mask.
    class_labels = np.unique(
        np.append(np.unique(gt_mask), np.unique(dt_mask)))

    flat_dt = dt_mask.reshape(1, -1).astype(np.float32)
    flat_gt = gt_mask.reshape(1, -1).astype(np.float32)
    self.segmentation_stats.ious.append(mask_iou(flat_dt, flat_gt))

    predictions = dt_mask.flatten()
    ground_truths = gt_mask.flatten()

    if self.parameters.validation.include_background:
        true_predictions, false_predictions, union_gt_dt = classify_mask(
            gt_mask, dt_mask, False)
    else:
        # Index 0 is handled as background: drop it from the class
        # list and from the pixel counts.
        class_labels = class_labels[class_labels != 0]
        predictions = predictions[predictions != 0]
        ground_truths = ground_truths[ground_truths != 0]
        true_predictions, false_predictions, union_gt_dt = classify_mask(
            gt_mask, dt_mask)

    dataset_labels = self.parameters.dataset.labels
    self.segmentation_stats.capture_class(class_labels, dataset_labels)

    overall = self.metrics.metrics
    overall.add_ground_truths(len(ground_truths))
    overall.add_predictions(len(predictions))
    overall.add_true_predictions(true_predictions)
    overall.add_false_predictions(false_predictions)
    overall.add_union(union_gt_dt)

    for cl in class_labels:
        gt_class_mask = create_mask_class(gt_mask, cl)
        dt_class_mask = create_mask_class(dt_mask, cl)

        # Evaluate background class
        if cl == 0:
            gt_class_mask = create_mask_background(gt_mask)
            dt_class_mask = create_mask_background(dt_mask)

        class_ground_truths = np.sum(gt_mask == cl)
        class_predictions = np.sum(dt_mask == cl)

        # Under classify_mask always exclude background because we are
        # only concerned with this class.
        class_true, class_false, class_union = classify_mask(
            gt_class_mask, dt_class_mask)

        per_class = self.segmentation_stats.get_label_data(
            dataset_labels[cl])
        per_class.add_true_predictions(class_true)
        per_class.add_false_predictions(class_false)
        per_class.add_ground_truths(class_ground_truths)
        per_class.add_predictions(class_predictions)
        per_class.add_union(class_union)
650
+
651
def visualize(
    self,
    gt_instance: SegmentationInstance,
    dt_instance: SegmentationInstance,
    epoch: int = 0
):
    """
    Render the ground truth and prediction masks into one image, then
    save it to disk or publish it to Tensorboard.

    Parameters
    ----------
    gt_instance: SegmentationInstance
        This is the ground truth instance which contains masks
        and labels to draw.
    dt_instance: SegmentationInstance
        This is the model prediction instance which contains the
        masks and labels to draw.
    epoch: int
        This is the training epoch number. This
        parameter is internal for ModelPack usage.
        Standalone validation does not use this parameter.
    """
    rendered = self.drawer.mask2maskimage(gt_instance, dt_instance)

    output_dir = self.parameters.validation.visualize
    if output_dir:
        # Saving to disk takes priority over Tensorboard publishing.
        file_name = os.path.basename(gt_instance.image_path)
        rendered.save(os.path.join(output_dir, file_name))
        return

    if self.tensorboard_writer:
        self.tensorboard_writer(
            np.asarray(rendered), gt_instance.image_path, step=epoch)
682
+
683
def get_plots(self) -> List[matplotlib.figure.Figure]:
    """
    Generate EdgeFirst validation plots.

    Returns
    -------
    List[matplotlib.figure.Figure]
        A single-element list holding the per-class
        segmentation figure.
    """
    histogram_data = self.metrics.plots.class_histogram_data
    model_name = self.metrics.metrics.model
    class_figure = plot_classification_segmentation(
        class_histogram_data=histogram_data,
        model=model_name
    )
    return [class_figure]
697
+
698
def save_plots(self, plots: List[matplotlib.figure.Figure]):
    """
    Write the first figure to ``class_scores.png`` inside the
    directory given by the validation ``visualize`` parameter.

    Parameters
    ----------
    plots: List[matplotlib.figure.Figure]
        This is the list of matplotlib figures to save. Only the
        first figure is written.
    """
    class_figure = plots[0]
    destination = f'{self.parameters.validation.visualize}/class_scores.png'
    class_figure.savefig(destination, bbox_inches="tight")
711
+
712
def publish_plots(
        self, plots: List[matplotlib.figure.Figure], epoch: int = 0):
    """
    Convert the first figure to an array and publish it to Tensorboard.

    Parameters
    ----------
    plots: List[matplotlib.figure.Figure]
        This is the list of matplotlib figures to publish. Only the
        first figure is used.
    epoch: int
        The training epoch number used for ModelPack training usage.
    """
    class_scores = figure2numpy(plots[0])
    tag = f"{self.metrics.metrics.model}_scores.png"
    self.tensorboard_writer(class_scores, tag, step=epoch)