supervisely 6.73.238-py3-none-any.whl → 6.73.239-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of supervisely might be problematic.
- supervisely/annotation/annotation.py +2 -2
- supervisely/api/entity_annotation/tag_api.py +11 -4
- supervisely/nn/__init__.py +1 -0
- supervisely/nn/benchmark/__init__.py +14 -2
- supervisely/nn/benchmark/base_benchmark.py +84 -37
- supervisely/nn/benchmark/base_evaluator.py +120 -0
- supervisely/nn/benchmark/base_visualizer.py +265 -0
- supervisely/nn/benchmark/comparison/detection_visualization/text_templates.py +5 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +2 -2
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/explore_predicttions.py +39 -16
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +1 -1
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +4 -4
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +12 -11
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +1 -1
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +6 -6
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +3 -3
- supervisely/nn/benchmark/{instance_segmentation_benchmark.py → instance_segmentation/benchmark.py} +9 -3
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +58 -0
- supervisely/nn/benchmark/{visualization/text_templates/instance_segmentation_text.py → instance_segmentation/text_templates.py} +53 -69
- supervisely/nn/benchmark/instance_segmentation/visualizer.py +18 -0
- supervisely/nn/benchmark/object_detection/__init__.py +0 -0
- supervisely/nn/benchmark/object_detection/base_vis_metric.py +51 -0
- supervisely/nn/benchmark/{object_detection_benchmark.py → object_detection/benchmark.py} +4 -2
- supervisely/nn/benchmark/object_detection/evaluation_params.yaml +2 -0
- supervisely/nn/benchmark/{evaluation/object_detection_evaluator.py → object_detection/evaluator.py} +67 -9
- supervisely/nn/benchmark/{evaluation/coco → object_detection}/metric_provider.py +13 -14
- supervisely/nn/benchmark/{visualization/text_templates/object_detection_text.py → object_detection/text_templates.py} +49 -41
- supervisely/nn/benchmark/object_detection/vis_metrics/__init__.py +48 -0
- supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/confidence_distribution.py +20 -24
- supervisely/nn/benchmark/object_detection/vis_metrics/confidence_score.py +119 -0
- supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/confusion_matrix.py +34 -22
- supervisely/nn/benchmark/object_detection/vis_metrics/explore_predictions.py +129 -0
- supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/f1_score_at_different_iou.py +21 -26
- supervisely/nn/benchmark/object_detection/vis_metrics/frequently_confused.py +137 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/iou_distribution.py +106 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/key_metrics.py +136 -0
- supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/model_predictions.py +53 -49
- supervisely/nn/benchmark/object_detection/vis_metrics/outcome_counts.py +188 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/outcome_counts_per_class.py +191 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +116 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +106 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve_by_class.py +49 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +72 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/precision_avg_per_class.py +59 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +71 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +56 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/reliability_diagram.py +110 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/speedtest.py +151 -0
- supervisely/nn/benchmark/object_detection/visualizer.py +697 -0
- supervisely/nn/benchmark/semantic_segmentation/__init__.py +9 -0
- supervisely/nn/benchmark/semantic_segmentation/base_vis_metric.py +55 -0
- supervisely/nn/benchmark/semantic_segmentation/benchmark.py +32 -0
- supervisely/nn/benchmark/semantic_segmentation/evaluation_params.yaml +0 -0
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +162 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +153 -0
- supervisely/nn/benchmark/semantic_segmentation/text_templates.py +130 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/__init__.py +0 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/acknowledgement.py +15 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/classwise_error_analysis.py +57 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +92 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/explore_predictions.py +84 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +101 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/iou_eou.py +45 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +60 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/model_predictions.py +107 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +112 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/renormalized_error_ou.py +48 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/speedtest.py +178 -0
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/vis_texts.py +21 -0
- supervisely/nn/benchmark/semantic_segmentation/visualizer.py +304 -0
- supervisely/nn/benchmark/utils/__init__.py +12 -0
- supervisely/nn/benchmark/utils/detection/__init__.py +2 -0
- supervisely/nn/benchmark/{evaluation/coco → utils/detection}/calculate_metrics.py +6 -4
- supervisely/nn/benchmark/utils/detection/metric_provider.py +533 -0
- supervisely/nn/benchmark/{coco_utils → utils/detection}/sly2coco.py +4 -4
- supervisely/nn/benchmark/{coco_utils/utils.py → utils/detection/utlis.py} +11 -0
- supervisely/nn/benchmark/utils/semantic_segmentation/__init__.py +0 -0
- supervisely/nn/benchmark/utils/semantic_segmentation/calculate_metrics.py +35 -0
- supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +804 -0
- supervisely/nn/benchmark/utils/semantic_segmentation/loader.py +65 -0
- supervisely/nn/benchmark/utils/semantic_segmentation/utils.py +109 -0
- supervisely/nn/benchmark/visualization/evaluation_result.py +17 -3
- supervisely/nn/benchmark/visualization/vis_click_data.py +1 -1
- supervisely/nn/benchmark/visualization/widgets/__init__.py +3 -0
- supervisely/nn/benchmark/visualization/widgets/chart/chart.py +12 -4
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +35 -8
- supervisely/nn/benchmark/visualization/widgets/gallery/template.html +8 -4
- supervisely/nn/benchmark/visualization/widgets/markdown/markdown.py +1 -1
- supervisely/nn/benchmark/visualization/widgets/notification/notification.py +11 -7
- supervisely/nn/benchmark/visualization/widgets/radio_group/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/radio_group/radio_group.py +34 -0
- supervisely/nn/benchmark/visualization/widgets/table/table.py +9 -3
- supervisely/nn/benchmark/visualization/widgets/widget.py +4 -0
- supervisely/project/project.py +18 -6
- {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/METADATA +3 -1
- {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/RECORD +103 -81
- supervisely/nn/benchmark/coco_utils/__init__.py +0 -2
- supervisely/nn/benchmark/evaluation/__init__.py +0 -3
- supervisely/nn/benchmark/evaluation/base_evaluator.py +0 -64
- supervisely/nn/benchmark/evaluation/coco/__init__.py +0 -2
- supervisely/nn/benchmark/evaluation/instance_segmentation_evaluator.py +0 -88
- supervisely/nn/benchmark/utils.py +0 -13
- supervisely/nn/benchmark/visualization/inference_speed/__init__.py +0 -19
- supervisely/nn/benchmark/visualization/inference_speed/speedtest_batch.py +0 -161
- supervisely/nn/benchmark/visualization/inference_speed/speedtest_intro.py +0 -28
- supervisely/nn/benchmark/visualization/inference_speed/speedtest_overview.py +0 -141
- supervisely/nn/benchmark/visualization/inference_speed/speedtest_real_time.py +0 -63
- supervisely/nn/benchmark/visualization/text_templates/inference_speed_text.py +0 -23
- supervisely/nn/benchmark/visualization/vis_metric_base.py +0 -337
- supervisely/nn/benchmark/visualization/vis_metrics/__init__.py +0 -67
- supervisely/nn/benchmark/visualization/vis_metrics/classwise_error_analysis.py +0 -55
- supervisely/nn/benchmark/visualization/vis_metrics/confidence_score.py +0 -93
- supervisely/nn/benchmark/visualization/vis_metrics/explorer_grid.py +0 -144
- supervisely/nn/benchmark/visualization/vis_metrics/frequently_confused.py +0 -115
- supervisely/nn/benchmark/visualization/vis_metrics/iou_distribution.py +0 -86
- supervisely/nn/benchmark/visualization/vis_metrics/outcome_counts.py +0 -119
- supervisely/nn/benchmark/visualization/vis_metrics/outcome_counts_per_class.py +0 -148
- supervisely/nn/benchmark/visualization/vis_metrics/overall_error_analysis.py +0 -109
- supervisely/nn/benchmark/visualization/vis_metrics/overview.py +0 -189
- supervisely/nn/benchmark/visualization/vis_metrics/percision_avg_per_class.py +0 -57
- supervisely/nn/benchmark/visualization/vis_metrics/pr_curve.py +0 -101
- supervisely/nn/benchmark/visualization/vis_metrics/pr_curve_by_class.py +0 -46
- supervisely/nn/benchmark/visualization/vis_metrics/precision.py +0 -56
- supervisely/nn/benchmark/visualization/vis_metrics/recall.py +0 -54
- supervisely/nn/benchmark/visualization/vis_metrics/recall_vs_precision.py +0 -57
- supervisely/nn/benchmark/visualization/vis_metrics/reliability_diagram.py +0 -88
- supervisely/nn/benchmark/visualization/vis_metrics/what_is.py +0 -23
- supervisely/nn/benchmark/visualization/vis_templates.py +0 -241
- supervisely/nn/benchmark/visualization/vis_widgets.py +0 -128
- supervisely/nn/benchmark/visualization/visualizer.py +0 -729
- /supervisely/nn/benchmark/{visualization/text_templates → instance_segmentation}/__init__.py +0 -0
- /supervisely/nn/benchmark/{evaluation/coco → instance_segmentation}/evaluation_params.yaml +0 -0
- /supervisely/nn/benchmark/{evaluation/coco → utils/detection}/metrics.py +0 -0
- {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/LICENSE +0 -0
- {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/WHEEL +0 -0
- {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/supervisely/nn/benchmark/object_detection/vis_metrics/precision.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from supervisely.nn.benchmark.object_detection.base_vis_metric import DetectionVisMetric
+from supervisely.nn.benchmark.visualization.widgets import (
+    ChartWidget,
+    MarkdownWidget,
+    NotificationWidget,
+)
+
+
+class Precision(DetectionVisMetric):
+    MARKDOWN = "precision"
+    MARKDOWN_PER_CLASS = "precision_per_class"
+    NOTIFICATION = "precision"
+    CHART = "precision"
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.clickable = True
+
+    @property
+    def md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_P
+        return MarkdownWidget(self.MARKDOWN, "Precision", text)
+
+    @property
+    def notification(self) -> NotificationWidget:
+        title, desc = self.vis_texts.notification_precision.values()
+        tp_plus_fp = self.eval_result.mp.TP_count + self.eval_result.mp.FP_count
+        return NotificationWidget(
+            self.NOTIFICATION,
+            title.format(self.eval_result.mp.base_metrics()["precision"].round(2)),
+            desc.format(self.eval_result.mp.TP_count, tp_plus_fp),
+        )
+
+    @property
+    def per_class_md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_P_perclass.format(self.vis_texts.definitions.f1_score)
+        return MarkdownWidget(self.MARKDOWN_PER_CLASS, "Precision per class", text)
+
+    @property
+    def chart(self) -> ChartWidget:
+        chart = ChartWidget(self.CHART, self._get_figure())
+        chart.set_click_data(
+            self.explore_modal_table.id,
+            self.get_click_data(),
+            chart_click_extra="'getKey': (payload) => `${payload.points[0].label}`,",
+        )
+        return chart
+
+    def _get_figure(self):  # -> go.Figure
+        import plotly.express as px  # pylint: disable=import-error
+
+        sorted_by_precision = self.eval_result.mp.per_class_metrics().sort_values(by="precision")
+        fig = px.bar(
+            sorted_by_precision,
+            x="category",
+            y="precision",
+            # title="Per-class Precision (Sorted by F1)",
+            color="precision",
+            range_color=[0, 1],
+            color_continuous_scale="Plasma",
+        )
+        fig.update_traces(hovertemplate="Class: %{x}<br>Precision: %{y:.2f}<extra></extra>")
+        if len(sorted_by_precision) <= 20:
+            fig.update_traces(
+                text=sorted_by_precision["precision"].round(2),  # label bars with the precision column
+                textposition="outside",
+            )
+        fig.update_xaxes(title_text="Class")
+        fig.update_yaxes(title_text="Precision", range=[0, 1])
+        return fig
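For orientation: the notification property above reports precision as TP / (TP + FP). A minimal sketch of that arithmetic with made-up counts (not from any real evaluation):

# Hypothetical counts standing in for eval_result.mp.TP_count / FP_count
TP_count, FP_count = 880, 120
precision = TP_count / (TP_count + FP_count)  # 880 / 1000 = 0.88
print(f"Precision {precision:.2f}: {TP_count} of {TP_count + FP_count} predictions matched ground truth")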
--- /dev/null
+++ b/supervisely/nn/benchmark/object_detection/vis_metrics/precision_avg_per_class.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from supervisely.nn.benchmark.object_detection.base_vis_metric import DetectionVisMetric
+from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
+
+
+class PerClassAvgPrecision(DetectionVisMetric):
+    MARKDOWN = "per_class_avg_precision"
+    CHART = "per_class_avg_precision"
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.clickable = True
+
+    @property
+    def md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_class_ap
+        text = text.format(self.vis_texts.definitions.average_precision)
+        return MarkdownWidget(self.MARKDOWN, "Average Precision by Class", text)
+
+    @property
+    def chart(self) -> ChartWidget:
+        chart = ChartWidget(self.CHART, self._get_figure())
+        chart.set_click_data(
+            self.explore_modal_table.id,
+            self.get_click_data(),
+            chart_click_extra="'getKey': (payload) => `${payload.points[0].theta}`,",
+        )
+        return chart
+
+    def _get_figure(self):  # -> go.Figure:
+        import plotly.express as px  # pylint: disable=import-error
+
+        # AP per-class
+        ap_per_class = self.eval_result.mp.coco_precision[:, :, :, 0, 2].mean(axis=(0, 1))
+        ap_per_class[ap_per_class == -1] = 0  # -1 is a placeholder for no GT
+        labels = dict(r="Average Precision", theta="Class")
+        fig = px.scatter_polar(
+            r=ap_per_class,
+            theta=self.eval_result.mp.cat_names,
+            # title="Per-class Average Precision (AP)",
+            labels=labels,
+            width=800,
+            height=800,
+            range_r=[0, 1],
+        )
+        fig.update_traces(fill="toself")
+        fig.update_layout(
+            modebar_add=["resetScale"],
+            margin=dict(l=80, r=80, t=0, b=0),
+        )
+        fig.update_traces(
+            hovertemplate=labels["theta"]
+            + ": %{theta}<br>"
+            + labels["r"]
+            + ": %{r:.2f}<br>"
+            + "<extra></extra>"
+        )
+        return fig
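Note on the coco_precision slice above: assuming the array follows the pycocotools layout [IoU thresholds, recall points, classes, area ranges, max detections], the slice [:, :, :, 0, 2] selects area range "all" and the largest maxDets setting, and averaging over the first two axes yields one AP value per class. A self-contained sketch of that reduction with a hypothetical array (not Supervisely's data):

import numpy as np

# Hypothetical precision array in the pycocotools layout:
# [T IoU thresholds, R recall points, K classes, A area ranges, M maxDets settings]
T, R, K, A, M = 10, 101, 3, 4, 3
coco_precision = np.random.rand(T, R, K, A, M)
coco_precision[:, :, 2, :, :] = -1  # class 2 has no ground truth anywhere

ap_per_class = coco_precision[:, :, :, 0, 2].mean(axis=(0, 1))  # area="all", largest maxDets
ap_per_class[ap_per_class == -1] = 0  # a class with no GT averages to exactly -1
print(ap_per_class)  # e.g. [0.51, 0.49, 0.0]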
--- /dev/null
+++ b/supervisely/nn/benchmark/object_detection/vis_metrics/recall.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from supervisely.nn.benchmark.object_detection.base_vis_metric import DetectionVisMetric
+from supervisely.nn.benchmark.visualization.widgets import (
+    ChartWidget,
+    MarkdownWidget,
+    NotificationWidget,
+)
+
+
+class Recall(DetectionVisMetric):
+    MARKDOWN = "recall"
+    MARKDOWN_PER_CLASS = "recall_per_class"
+    NOTIFICATION = "recall"
+    CHART = "recall"
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.clickable = True
+
+    @property
+    def md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_R
+        return MarkdownWidget(self.MARKDOWN, "Recall", text)
+
+    @property
+    def notification(self) -> NotificationWidget:
+        title, desc = self.vis_texts.notification_recall.values()
+        tp_plus_fn = self.eval_result.mp.TP_count + self.eval_result.mp.FN_count
+        return NotificationWidget(
+            self.NOTIFICATION,
+            title.format(self.eval_result.mp.base_metrics()["recall"].round(2)),
+            desc.format(self.eval_result.mp.TP_count, tp_plus_fn),
+        )
+
+    @property
+    def per_class_md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_R_perclass.format(self.vis_texts.definitions.f1_score)
+        return MarkdownWidget(self.MARKDOWN_PER_CLASS, "Recall per class", text)
+
+    @property
+    def chart(self) -> ChartWidget:
+        chart = ChartWidget(self.CHART, self._get_figure())
+        chart.set_click_data(
+            self.explore_modal_table.id,
+            self.get_click_data(),
+            chart_click_extra="'getKey': (payload) => `${payload.points[0].label}`,",
+        )
+        return chart
+
+    def _get_figure(self):  # -> go.Figure
+        import plotly.express as px  # pylint: disable=import-error
+
+        sorted_by_f1 = self.eval_result.mp.per_class_metrics().sort_values(by="f1")
+        fig = px.bar(
+            sorted_by_f1,
+            x="category",
+            y="recall",
+            color="recall",
+            range_color=[0, 1],
+            color_continuous_scale="Plasma",
+        )
+        fig.update_traces(hovertemplate="Class: %{x}<br>Recall: %{y:.2f}<extra></extra>")
+        if len(sorted_by_f1) <= 20:
+            fig.update_traces(
+                text=sorted_by_f1["recall"].round(2),
+                textposition="outside",
+            )
+        fig.update_xaxes(title_text="Class")
+        fig.update_yaxes(title_text="Recall", range=[0, 1])
+        return fig
--- /dev/null
+++ b/supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from supervisely.nn.benchmark.object_detection.base_vis_metric import DetectionVisMetric
+from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
+
+
+class RecallVsPrecision(DetectionVisMetric):
+    MARKDOWN = "recall_vs_precision"
+    CHART = "recall_vs_precision"
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.clickable = True
+
+    @property
+    def md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_PR.format(self.vis_texts.definitions.f1_score)
+        return MarkdownWidget(self.MARKDOWN, "Recall vs Precision", text)
+
+    @property
+    def chart(self) -> ChartWidget:
+        chart = ChartWidget(self.CHART, self._get_figure())
+        chart.set_click_data(
+            self.explore_modal_table.id,
+            self.get_click_data(),
+            chart_click_extra="'getKey': (payload) => `${payload.points[0].label}`,",
+        )
+        return chart
+
+    def _get_figure(self):  # -> Optional[go.Figure]
+        import plotly.graph_objects as go  # pylint: disable=import-error
+
+        blue_color = "#1f77b4"
+        orange_color = "#ff7f0e"
+        sorted_by_f1 = self.eval_result.mp.per_class_metrics().sort_values(by="f1")
+        fig = go.Figure()
+        fig.add_trace(
+            go.Bar(
+                y=sorted_by_f1["precision"],
+                x=sorted_by_f1["category"],
+                name="Precision",
+                marker=dict(color=blue_color),
+            )
+        )
+        fig.add_trace(
+            go.Bar(
+                y=sorted_by_f1["recall"],
+                x=sorted_by_f1["category"],
+                name="Recall",
+                marker=dict(color=orange_color),
+            )
+        )
+        fig.update_layout(barmode="group")
+        fig.update_xaxes(title_text="Class")
+        fig.update_yaxes(title_text="Value", range=[0, 1])
+        return fig
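The grouped bar chart above only assumes that per_class_metrics() returns a frame with category, precision, recall and f1 columns. A standalone sketch of the same plot with a hypothetical stand-in frame:

import pandas as pd
import plotly.graph_objects as go

# Stand-in for eval_result.mp.per_class_metrics() — made-up values,
# but with the same columns the widget code relies on.
df = pd.DataFrame({
    "category": ["cat", "dog", "bird"],
    "precision": [0.91, 0.84, 0.62],
    "recall": [0.88, 0.79, 0.70],
    "f1": [0.89, 0.81, 0.66],
}).sort_values(by="f1")

fig = go.Figure()
fig.add_trace(go.Bar(x=df["category"], y=df["precision"], name="Precision"))
fig.add_trace(go.Bar(x=df["category"], y=df["recall"], name="Recall"))
fig.update_layout(barmode="group")  # side-by-side bars per class
fig.show()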
--- /dev/null
+++ b/supervisely/nn/benchmark/object_detection/vis_metrics/reliability_diagram.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from supervisely.nn.benchmark.object_detection.base_vis_metric import DetectionVisMetric
+from supervisely.nn.benchmark.visualization.widgets import (
+    ChartWidget,
+    CollapseWidget,
+    MarkdownWidget,
+    NotificationWidget,
+)
+
+
+class ReliabilityDiagram(DetectionVisMetric):
+    MARKDOWN_CALIBRATION_SCORE = "calibration_score"
+    MARKDOWN_CALIBRATION_SCORE_2 = "calibration_score_2"
+    MARKDOWN_RELIABILITY_DIAGRAM = "reliability_diagram"
+    NOTIFICATION = "reliability_diagram"
+    CHART = "reliability_diagram"
+
+    @property
+    def md_calibration_score(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_calibration_score_1
+        text = text.format(self.vis_texts.definitions.confidence_score)
+        return MarkdownWidget(self.MARKDOWN_CALIBRATION_SCORE, "Calibration Score", text)
+
+    @property
+    def collapse_tip(self) -> CollapseWidget:
+        md = MarkdownWidget(
+            "what_is_calibration",
+            "What is calibration?",
+            self.vis_texts.markdown_what_is_calibration,
+        )
+        return CollapseWidget([md])
+
+    @property
+    def md_calibration_score_2(self) -> MarkdownWidget:
+        return MarkdownWidget(
+            self.MARKDOWN_CALIBRATION_SCORE_2,
+            "",
+            self.vis_texts.markdown_calibration_score_2,
+        )
+
+    @property
+    def md_reliability_diagram(self) -> MarkdownWidget:
+        return MarkdownWidget(
+            self.MARKDOWN_RELIABILITY_DIAGRAM,
+            "Reliability Diagram",
+            self.vis_texts.markdown_reliability_diagram,
+        )
+
+    @property
+    def notification(self) -> NotificationWidget:
+        title, _ = self.vis_texts.notification_ece.values()
+        return NotificationWidget(
+            self.NOTIFICATION,
+            title.format(self.eval_result.mp.m_full.expected_calibration_error().round(4)),
+        )
+
+    @property
+    def chart(self) -> ChartWidget:
+        return ChartWidget(self.CHART, self._get_figure())
+
+    @property
+    def collapse(self) -> CollapseWidget:
+        md = MarkdownWidget(
+            "markdown_calibration_curve_interpretation",
+            "How to interpret the Calibration curve",
+            self.vis_texts.markdown_calibration_curve_interpretation,
+        )
+        return CollapseWidget([md])
+
+    def _get_figure(self):  # -> go.Figure:
+        import plotly.graph_objects as go  # pylint: disable=import-error
+
+        # Calibration curve (only positive predictions)
+        true_probs, pred_probs = self.eval_result.mp.m_full.calibration_curve()
+
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=pred_probs,
+                y=true_probs,
+                mode="lines+markers",
+                name="Calibration plot (Model)",
+                line=dict(color="blue"),
+                marker=dict(color="blue"),
+            )
+        )
+        fig.add_trace(
+            go.Scatter(
+                x=[0, 1],
+                y=[0, 1],
+                mode="lines",
+                name="Perfectly calibrated",
+                line=dict(color="orange", dash="dash"),
+            )
+        )
+
+        fig.update_layout(
+            xaxis_title="Confidence Score",
+            yaxis_title="Fraction of True Positives",
+            legend=dict(x=0.6, y=0.1),
+            xaxis=dict(range=[0, 1]),
+            yaxis=dict(range=[0, 1]),
+            width=700,
+            height=500,
+        )
+        fig.update_traces(
+            hovertemplate="Confidence Score: %{x:.2f}<br>Fraction of True Positives: %{y:.2f}<extra></extra>"
+        )
+        return fig
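The expected_calibration_error() and calibration_curve() calls above come from the metric provider, whose implementation is not shown in this diff. For reference, reliability diagrams are built on the textbook binned ECE; a generic sketch of that computation (not necessarily the library's exact code):

import numpy as np

def expected_calibration_error_sketch(confidences, is_true_positive, n_bins=10):
    # Weighted gap between mean confidence and observed TP fraction per bin.
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gap = abs(is_true_positive[mask].mean() - confidences[mask].mean())
            ece += mask.mean() * gap  # weight by the fraction of samples in the bin
    return ece

conf = np.array([0.95, 0.9, 0.8, 0.6, 0.55, 0.3])  # made-up prediction confidences
tp = np.array([1, 1, 0, 1, 0, 0])                  # whether each prediction was a TP
print(round(expected_calibration_error_sketch(conf, tp), 4))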
--- /dev/null
+++ b/supervisely/nn/benchmark/object_detection/vis_metrics/speedtest.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+from supervisely.nn.benchmark.object_detection.base_vis_metric import DetectionVisMetric
+from supervisely.nn.benchmark.visualization.widgets import (
+    ChartWidget,
+    MarkdownWidget,
+    TableWidget,
+)
+
+
+class Speedtest(DetectionVisMetric):
+    MARKDOWN_INTRO = "speedtest_intro"
+    TABLE_MARKDOWN = "speedtest_table"
+    TABLE = "speedtest"
+    CHART_MARKDOWN = "speedtest_chart"
+    CHART = "speedtest"
+
+    def is_empty(self) -> bool:
+        return not self.eval_result.speedtest_info
+
+    @property
+    def num_batche_sizes(self) -> int:
+        if self.is_empty():
+            return 0
+        return len(self.eval_result.speedtest_info.get("speedtest", []))
+
+    @property
+    def intro_md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_speedtest_intro
+        text = text.format(
+            self.eval_result.speedtest_info["model_info"]["device"],
+            self.eval_result.speedtest_info["model_info"]["hardware"],
+            self.eval_result.speedtest_info["model_info"]["runtime"],
+        )
+        return MarkdownWidget(name=self.MARKDOWN_INTRO, title="Inference Speed", text=text)
+
+    @property
+    def table_md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_speedtest_table
+        text = text.format(self.eval_result.speedtest_info["speedtest"][0]["num_iterations"])
+        return MarkdownWidget(name=self.TABLE_MARKDOWN, title="Speedtest Table", text=text)
+
+    @property
+    def table(self) -> TableWidget:
+        columns = [" ", "Inference time", "FPS"]
+        content = []
+        temp_res = {}
+        max_fps = 0
+        for test in self.eval_result.speedtest_info["speedtest"]:
+            batch_size = test["batch_size"]
+
+            ms = round(test["benchmark"]["total"], 2)
+            fps = round(1000 / test["benchmark"]["total"] * batch_size)
+            row = [f"Batch size {batch_size}", ms, fps]
+            temp_res[batch_size] = row
+            max_fps = max(max_fps, fps)
+
+        # sort by batch size
+        temp_res = dict(sorted(temp_res.items()))
+        for row in temp_res.values():
+            content.append({"row": row, "id": row[0], "items": row})
+
+        columns_options = [
+            {"disableSort": True},  # , "customCell": True},
+            {"subtitle": "ms", "tooltip": "Milliseconds for batch images", "postfix": "ms"},
+            {
+                "subtitle": "imgs/sec",
+                "tooltip": "Frames (images) per second",
+                "postfix": "fps",
+                "maxValue": max_fps,
+            },
+        ]
+        data = {"columns": columns, "columnsOptions": columns_options, "content": content}
+        table = TableWidget(name=self.TABLE, data=data)
+        table.main_column = "Batch size"
+        table.fixed_columns = 1
+        table.show_header_controls = False
+        return table
+
+    @property
+    def chart_md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_speedtest_chart
+        return MarkdownWidget(name=self.CHART_MARKDOWN, title="Speedtest Chart", text=text)
+
+    @property
+    def chart(self) -> ChartWidget:
+        return ChartWidget(name=self.CHART, figure=self._get_figure())
+
+    def _get_figure(self):  # -> go.Figure:
+        import plotly.graph_objects as go  # pylint: disable=import-error
+        from plotly.subplots import make_subplots  # pylint: disable=import-error
+
+        fig = make_subplots(cols=2)
+
+        ms_color = "#e377c2"
+        fps_color = "#17becf"
+
+        temp_res = {}
+        for test in self.eval_result.speedtest_info["speedtest"]:
+            batch_size = test["batch_size"]
+
+            std = test["benchmark_std"]["total"]
+            ms = test["benchmark"]["total"]
+            fps = round(1000 / test["benchmark"]["total"] * batch_size)
+
+            ms_line = temp_res.setdefault("ms", {})
+            fps_line = temp_res.setdefault("fps", {})
+            ms_std_line = temp_res.setdefault("ms_std", {})
+
+            ms_line[batch_size] = ms
+            fps_line[batch_size] = fps
+            ms_std_line[batch_size] = round(std, 2)
+
+        fig.add_trace(
+            go.Scatter(
+                x=list(temp_res["ms"].keys()),
+                y=list(temp_res["ms"].values()),
+                name="Inference time (ms)",
+                line=dict(color=ms_color),
+                customdata=list(temp_res["ms_std"].values()),
+                error_y=dict(
+                    type="data",
+                    array=list(temp_res["ms_std"].values()),
+                    visible=True,
+                    color="rgba(227, 119, 194, 0.7)",
+                ),
+                hovertemplate="Batch Size: %{x}<br>Time: %{y:.2f} ms<br>Standard deviation: %{customdata:.2f} ms<extra></extra>",
+            ),
+            col=1,
+            row=1,
+        )
+        fig.add_trace(
+            go.Scatter(
+                x=list(temp_res["fps"].keys()),
+                y=list(temp_res["fps"].values()),
+                name="FPS",
+                line=dict(color=fps_color),
+                hovertemplate="Batch Size: %{x}<br>FPS: %{y:.2f}<extra></extra>",
+            ),
+            col=2,
+            row=1,
+        )
+
+        fig.update_xaxes(title_text="Batch size", col=1, dtick=1)
+        fig.update_xaxes(title_text="Batch size", col=2, dtick=1)
+
+        fig.update_yaxes(title_text="Time (ms)", col=1)
+        fig.update_yaxes(title_text="FPS", col=2)
+        fig.update_layout(height=400)
+
+        return fig
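Both the table and the chart above derive FPS from the measured per-batch latency as 1000 / total_ms * batch_size. A worked example with made-up benchmark numbers:

# benchmark["total"] is the time to infer one batch, in milliseconds
batch_size = 8
total_ms = 40.0
fps = round(1000 / total_ms * batch_size)  # 25 batches/sec * 8 imgs/batch = 200 imgs/sec
print(fps)  # 200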