supervisely-6.73.253-py3-none-any.whl → supervisely-6.73.255-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of supervisely might be problematic.
- supervisely/api/file_api.py +16 -5
- supervisely/api/task_api.py +4 -2
- supervisely/app/widgets/field/field.py +10 -7
- supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
- supervisely/convert/image/sly/sly_image_converter.py +1 -1
- supervisely/nn/benchmark/base_benchmark.py +33 -35
- supervisely/nn/benchmark/base_evaluator.py +27 -1
- supervisely/nn/benchmark/base_visualizer.py +8 -11
- supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
- supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
- supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
- supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
- supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
- supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
- supervisely/nn/benchmark/visualization/renderer.py +25 -10
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
- supervisely/nn/inference/inference.py +1 -0
- supervisely/nn/training/gui/gui.py +32 -10
- supervisely/nn/training/gui/training_artifacts.py +145 -0
- supervisely/nn/training/gui/training_process.py +3 -19
- supervisely/nn/training/train_app.py +179 -70
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/METADATA +1 -1
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/RECORD +59 -47
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/LICENSE +0 -0
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/WHEEL +0 -0
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/top_level.txt +0 -0
supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py
ADDED
@@ -0,0 +1,68 @@
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
+from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
+
+
+class IntersectionErrorOverUnion(BaseVisMetrics):
+
+    @property
+    def md(self) -> MarkdownWidget:
+        return MarkdownWidget(
+            "intersection_error_over_union",
+            "Intersection & Error Over Union",
+            text=self.vis_texts.markdown_iou,
+        )
+
+    @property
+    def chart(self) -> ChartWidget:
+        return ChartWidget("intersection_error_over_union", self.get_figure())
+
+    def get_figure(self):
+        import plotly.graph_objects as go  # pylint: disable=import-error
+        from plotly.subplots import make_subplots  # pylint: disable=import-error
+
+        labels = ["mIoU", "mBoundaryEoU", "mExtentEoU", "mSegmentEoU"]
+
+        length = len(self.eval_results)
+        cols = 3 if length > 2 else 2
+        cols = 4 if length % 4 == 0 else cols
+        rows = length // cols + (1 if length % cols != 0 else 0)
+
+        fig = make_subplots(rows=rows, cols=cols, specs=[[{"type": "domain"}] * cols] * rows)
+
+        annotations = []
+        for idx, eval_result in enumerate(self.eval_results, start=1):
+            col = idx % cols + (cols if idx % cols == 0 else 0)
+            row = idx // cols + (1 if idx % cols != 0 else 0)
+
+            fig.add_trace(
+                go.Pie(
+                    labels=labels,
+                    values=[
+                        eval_result.mp.iou,
+                        eval_result.mp.boundary_eou,
+                        eval_result.mp.extent_eou,
+                        eval_result.mp.segment_eou,
+                    ],
+                    hole=0.5,
+                    textposition="outside",
+                    textinfo="percent+label",
+                    marker=dict(colors=["#8ACAA1", "#FFE4B5", "#F7ADAA", "#dd3f3f"]),
+                ),
+                row=row,
+                col=col,
+            )
+
+            text = f"[{idx}] {eval_result.name[:7]}"
+            text += "..." if len(eval_result.name) > 7 else ""
+            annotations.append(
+                dict(
+                    text=text,
+                    x=sum(fig.get_subplot(row, col).x) / 2,
+                    y=sum(fig.get_subplot(row, col).y) / 2,
+                    showarrow=False,
+                    xanchor="center",
+                )
+            )
+        fig.update_layout(annotations=annotations)
+
+        return fig
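The subplot arithmetic in `get_figure` above decides the pie-grid shape from the number of compared models and maps each 1-based model index to a (row, col) cell. A standalone sketch of that arithmetic (plain Python, no Supervisely or Plotly objects; the model counts are made up):

```python
# Mirrors the grid layout used by IntersectionErrorOverUnion.get_figure above.
def grid_layout(length: int):
    cols = 3 if length > 2 else 2
    cols = 4 if length % 4 == 0 else cols
    rows = length // cols + (1 if length % cols != 0 else 0)
    cells = []
    for idx in range(1, length + 1):  # enumerate(..., start=1) in the original
        col = idx % cols + (cols if idx % cols == 0 else 0)
        row = idx // cols + (1 if idx % cols != 0 else 0)
        cells.append((row, col))
    return rows, cols, cells

for n in (2, 3, 4, 5):
    print(n, grid_layout(n))
# 2 (1, 2, [(1, 1), (1, 2)])
# 3 (1, 3, [(1, 1), (1, 2), (1, 3)])
# 4 (1, 4, [(1, 1), (1, 2), (1, 3), (1, 4)])
# 5 (2, 3, [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2)])
```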
supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py
ADDED
@@ -0,0 +1,223 @@
+from typing import List
+
+from supervisely._utils import abs_url
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
+from supervisely.nn.benchmark.visualization.evaluation_result import EvalResult
+from supervisely.nn.benchmark.visualization.widgets import (
+    ChartWidget,
+    MarkdownWidget,
+    TableWidget,
+)
+
+
+class Overview(BaseVisMetrics):
+
+    MARKDOWN_OVERVIEW = "markdown_overview"
+    MARKDOWN_OVERVIEW_INFO = "markdown_overview_info"
+    MARKDOWN_COMMON_OVERVIEW = "markdown_common_overview"
+    CHART = "chart_key_metrics"
+
+    def __init__(self, vis_texts, eval_results: List[EvalResult]) -> None:
+        super().__init__(vis_texts, eval_results)
+
+    @property
+    def overview_md(self) -> List[MarkdownWidget]:
+        info = []
+        model_names = []
+        for eval_result in self.eval_results:
+            model_name = eval_result.name or "Custom"
+            model_name = model_name.replace("_", "\_")
+            model_names.append(model_name)
+
+            info.append(
+                [
+                    eval_result.gt_project_info.id,
+                    eval_result.gt_project_info.name,
+                    eval_result.inference_info.get("task_type"),
+                ]
+            )
+        if all([model_name == "Custom" for model_name in model_names]):
+            model_name = "Custom models"
+        elif all([model_name == model_names[0] for model_name in model_names]):
+            model_name = model_names[0]
+        else:
+            model_name = " vs. ".join(model_names)
+
+        info = [model_name] + info[0]
+
+        text_template: str = getattr(self.vis_texts, self.MARKDOWN_COMMON_OVERVIEW)
+        return MarkdownWidget(
+            name=self.MARKDOWN_COMMON_OVERVIEW,
+            title="Overview",
+            text=text_template.format(*info),
+        )
+
+    @property
+    def overview_widgets(self) -> List[MarkdownWidget]:
+        all_formats = []
+        for eval_result in self.eval_results:
+
+            url = eval_result.inference_info.get("checkpoint_url")
+            link_text = eval_result.inference_info.get("custom_checkpoint_path")
+            if link_text is None:
+                link_text = url
+            link_text = link_text.replace("_", "\_")
+
+            checkpoint_name = eval_result.checkpoint_name
+            model_name = eval_result.inference_info.get("model_name") or "Custom"
+
+            report = eval_result.api.file.get_info_by_path(self.team_id, eval_result.report_path)
+            report_link = abs_url(f"/model-benchmark?id={report.id}")
+
+            formats = [
+                checkpoint_name,
+                model_name.replace("_", "\_"),
+                checkpoint_name.replace("_", "\_"),
+                eval_result.inference_info.get("architecture"),
+                eval_result.inference_info.get("runtime"),
+                url,
+                link_text,
+                report_link,
+            ]
+            all_formats.append(formats)
+
+        text_template: str = getattr(self.vis_texts, self.MARKDOWN_OVERVIEW_INFO)
+        widgets = []
+        for formats in all_formats:
+            md = MarkdownWidget(
+                name=self.MARKDOWN_OVERVIEW_INFO,
+                title="Overview",
+                text=text_template.format(*formats),
+            )
+            md.is_info_block = True
+            widgets.append(md)
+        return widgets
+
+    def get_table_widget(self, latency, fps) -> TableWidget:
+        res = {}
+
+        columns = ["metrics"] + [f"[{i+1}] {r.name}" for i, r in enumerate(self.eval_results)]
+
+        all_metrics = [eval_result.mp.key_metrics() for eval_result in self.eval_results]
+        res["content"] = []
+
+        for metric in all_metrics[0].keys():
+            values = [m[metric] for m in all_metrics]
+            values = [v if v is not None else "―" for v in values]
+            values = [round(v, 2) if isinstance(v, float) else v for v in values]
+            row = [metric] + values
+            dct = {"row": row, "id": metric, "items": row}
+            res["content"].append(dct)
+
+        latency_row = ["Latency (ms)"] + latency
+        res["content"].append({"row": latency_row, "id": latency_row[0], "items": latency_row})
+
+        fps_row = ["FPS"] + fps
+        res["content"].append({"row": fps_row, "id": fps_row[0], "items": fps_row})
+
+        columns_options = [{"disableSort": True} for _ in columns]
+
+        res["columns"] = columns
+        res["columnsOptions"] = columns_options
+
+        return TableWidget(
+            name="table_key_metrics",
+            data=res,
+            show_header_controls=False,
+            fix_columns=1,
+            page_size=len(res["content"]),
+        )
+
+    @property
+    def chart_widget(self) -> ChartWidget:
+        return ChartWidget(name=self.CHART, figure=self.get_figure())
+
+    def get_overview_info(self, eval_result: EvalResult):
+        classes_cnt = len(eval_result.classes_whitelist)
+        classes_str = "classes" if classes_cnt > 1 else "class"
+        classes_str = f"{classes_cnt} {classes_str}"
+
+        train_session, images_str = "", ""
+        gt_project_id = eval_result.gt_project_info.id
+        gt_dataset_ids = eval_result.gt_dataset_ids
+        gt_images_cnt = eval_result.val_images_cnt
+        train_info = eval_result.train_info
+        total_imgs_cnt = eval_result.gt_project_info.items_count
+        if gt_images_cnt is not None:
+            val_imgs_cnt = gt_images_cnt
+        elif gt_dataset_ids is not None:
+            datasets = eval_result.gt_dataset_infos
+            val_imgs_cnt = sum(ds.items_count for ds in datasets)
+        else:
+            val_imgs_cnt = eval_result.gt_project_info.items_count
+
+        if train_info:
+            train_task_id = train_info.get("app_session_id")
+            if train_task_id:
+                task_info = eval_result.api.task.get_info_by_id(int(train_task_id))
+                app_id = task_info["meta"]["app"]["id"]
+                train_session = f'- **Training dashboard**: <a href="/apps/{app_id}/sessions/{train_task_id}" target="_blank">open</a>'
+
+            train_imgs_cnt = train_info.get("images_count")
+            images_str = f", {train_imgs_cnt} images in train, {val_imgs_cnt} images in validation"
+
+        if gt_images_cnt is not None:
+            images_str += (
+                f", total {total_imgs_cnt} images. Evaluated using subset - {val_imgs_cnt} images"
+            )
+        elif gt_dataset_ids is not None:
+            links = [
+                f'<a href="/projects/{gt_project_id}/datasets/{ds.id}" target="_blank">{ds.name}</a>'
+                for ds in datasets
+            ]
+            images_str += f", total {total_imgs_cnt} images. Evaluated on the dataset{'s' if len(links) > 1 else ''}: {', '.join(links)}"
+        else:
+            images_str += f", total {total_imgs_cnt} images. Evaluated on the whole project ({val_imgs_cnt} images)"
+
+        return classes_str, images_str, train_session
+
+    def get_figure(self):  # -> Optional[go.Figure]
+        import plotly.graph_objects as go  # pylint: disable=import-error
+
+        # Overall Metrics
+        fig = go.Figure()
+        for i, eval_result in enumerate(self.eval_results):
+            name = f"[{i + 1}] {eval_result.name}"
+            base_metrics = eval_result.mp.key_metrics().copy()
+            base_metrics["mPixel accuracy"] = round(base_metrics["mPixel accuracy"] * 100, 2)
+            r = list(base_metrics.values())
+            theta = list(base_metrics.keys())
+            fig.add_trace(
+                go.Scatterpolar(
+                    r=r + [r[0]],
+                    theta=theta + [theta[0]],
+                    name=name,
+                    marker=dict(color=eval_result.color),
+                    hovertemplate=name + "<br>%{theta}: %{r:.2f}<extra></extra>",
+                )
+            )
+        fig.update_layout(
+            polar=dict(
+                radialaxis=dict(
+                    range=[0, 105],
+                    ticks="outside",
+                ),
+                angularaxis=dict(rotation=90, direction="clockwise"),
+            ),
+            dragmode=False,
+            height=500,
+            margin=dict(l=25, r=25, t=25, b=25),
+            modebar=dict(
+                remove=[
+                    "zoom2d",
+                    "pan2d",
+                    "select2d",
+                    "lasso2d",
+                    "zoomIn2d",
+                    "zoomOut2d",
+                    "autoScale2d",
+                    "resetScale2d",
+                ]
+            ),
+        )
+        return fig
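In `get_figure` above, each `Scatterpolar` trace repeats its first point (`r + [r[0]]`, `theta + [theta[0]]`) so the radar line closes into a polygon. A minimal sketch of the same trick with invented metric values (the metric names below are illustrative, not the exact output of `key_metrics()`):

```python
import plotly.graph_objects as go

# Invented key metrics for two models; real values come from eval_result.mp.key_metrics().
models = {
    "[1] model-a": {"mIoU": 81.2, "mBoundaryIoU": 64.5, "mPixel accuracy": 93.0},
    "[2] model-b": {"mIoU": 78.9, "mBoundaryIoU": 61.7, "mPixel accuracy": 91.4},
}

fig = go.Figure()
for name, metrics in models.items():
    r = list(metrics.values())
    theta = list(metrics.keys())
    # Repeat the first point so the polar trace closes into a polygon.
    fig.add_trace(go.Scatterpolar(r=r + [r[0]], theta=theta + [theta[0]], name=name))
fig.update_layout(polar=dict(radialaxis=dict(range=[0, 105])))
fig.show()
```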
supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py
ADDED
@@ -0,0 +1,57 @@
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
+from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
+
+
+class RenormalizedErrorOverUnion(BaseVisMetrics):
+
+    @property
+    def md(self) -> MarkdownWidget:
+        return MarkdownWidget(
+            "renormalized_error_over_union",
+            "Renormalized Error over Union",
+            text=self.vis_texts.markdown_renormalized_error_ou,
+        )
+
+    @property
+    def chart(self) -> ChartWidget:
+        return ChartWidget("intersection_error_over_union", self.get_figure())
+
+    def get_figure(self):
+        import plotly.graph_objects as go  # pylint: disable=import-error
+
+        fig = go.Figure()
+
+        labels = ["Boundary EoU", "Extent EoU", "Segment EoU"]
+
+        for idx, eval_result in enumerate(self.eval_results, 1):
+            model_name = f"[{idx}] {eval_result.short_name}"
+
+            fig.add_trace(
+                go.Bar(
+                    x=labels,
+                    y=[
+                        eval_result.mp.boundary_renormed_eou,
+                        eval_result.mp.extent_renormed_eou,
+                        eval_result.mp.segment_renormed_eou,
+                    ],
+                    name=model_name,
+                    text=[
+                        eval_result.mp.boundary_renormed_eou,
+                        eval_result.mp.extent_renormed_eou,
+                        eval_result.mp.segment_renormed_eou,
+                    ],
+                    textposition="outside",
+                    marker=dict(color=eval_result.color, line=dict(width=0.7)),
+                    width=0.4,
+                )
+            )
+
+        fig.update_traces(hovertemplate="%{x}: %{y:.2f}<extra></extra>")
+        fig.update_layout(
+            barmode="group",
+            bargap=0.15,
+            bargroupgap=0.05,
+            width=700 if len(self.eval_results) < 4 else 1000,
+        )
+
+        return fig
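Both of the new EoU vis metrics expose the same small surface on top of `BaseVisMetrics`: an `md` markdown intro and a `chart` widget. A hedged wiring sketch, assuming `vis_texts` and `eval_results` are prepared by the semantic segmentation comparison visualizer the same way they are for the other vis metrics in this release (the helper function below is hypothetical):

```python
from supervisely.nn.benchmark.comparison.semantic_segmentation.vis_metrics.renormalized_error_ou import (
    RenormalizedErrorOverUnion,
)

def build_renormalized_eou_widgets(vis_texts, eval_results):
    # Hypothetical helper: BaseVisMetrics is assumed to accept (vis_texts, eval_results),
    # since Overview.__init__ above forwards exactly these two arguments to super().
    vis = RenormalizedErrorOverUnion(vis_texts, eval_results)
    return [vis.md, vis.chart]  # markdown intro followed by the grouped bar chart
```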
supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py
ADDED
@@ -0,0 +1,314 @@
+from typing import List, Union
+
+from supervisely.imaging.color import hex2rgb
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
+from supervisely.nn.benchmark.visualization.widgets import (
+    ChartWidget,
+    MarkdownWidget,
+    TableWidget,
+)
+
+
+class Speedtest(BaseVisMetrics):
+
+    def is_empty(self) -> bool:
+        return not any(eval_result.speedtest_info for eval_result in self.eval_results)
+
+    def multiple_batche_sizes(self) -> bool:
+        for eval_result in self.eval_results:
+            if eval_result.speedtest_info is None:
+                continue
+            if len(eval_result.speedtest_info["speedtest"]) > 1:
+                return True
+        return False
+
+    @property
+    def latency(self) -> List[Union[int, str]]:
+        latency = []
+        for eval_result in self.eval_results:
+            if eval_result.speedtest_info is None:
+                latency.append("N/A")
+            else:
+                added = False
+                for test in eval_result.speedtest_info["speedtest"]:
+                    if test["batch_size"] == 1:
+                        latency.append(round(test["benchmark"]["total"], 2))
+                        added = True
+                        break
+                if not added:
+                    latency.append("N/A")
+        return latency
+
+    @property
+    def fps(self) -> List[Union[int, str]]:
+        fps = []
+        for eval_result in self.eval_results:
+            if eval_result.speedtest_info is None:
+                fps.append("N/A")
+            else:
+                added = False
+                for test in eval_result.speedtest_info["speedtest"]:
+                    if test["batch_size"] == 1:
+                        fps.append(round(1000 / test["benchmark"]["total"], 2))
+                        added = True
+                        break
+                if not added:
+                    fps.append("N/A")
+        return fps
+
+    @property
+    def md_intro(self) -> MarkdownWidget:
+        return MarkdownWidget(
+            name="speedtest_intro",
+            title="Inference Speed",
+            text=self.vis_texts.markdown_speedtest_intro,
+        )
+
+    @property
+    def intro_table(self) -> TableWidget:
+        columns = ["Model", "Device", "Hardware", "Runtime"]
+        columns_options = [{"disableSort": True} for _ in columns]
+        content = []
+        for i, eval_result in enumerate(self.eval_results, 1):
+            name = f"[{i}] {eval_result.name}"
+            if eval_result.speedtest_info is None:
+                row = [name, "N/A", "N/A", "N/A"]
+                dct = {
+                    "row": row,
+                    "id": name,
+                    "items": row,
+                }
+                content.append(dct)
+                continue
+            model_info = eval_result.speedtest_info.get("model_info", {})
+            device = model_info.get("device", "N/A")
+            hardware = model_info.get("hardware", "N/A")
+            runtime = model_info.get("runtime", "N/A")
+            row = [name, device, hardware, runtime]
+            dct = {
+                "row": row,
+                "id": name,
+                "items": row,
+            }
+            content.append(dct)
+
+        data = {
+            "columns": columns,
+            "columnsOptions": columns_options,
+            "content": content,
+        }
+        return TableWidget(
+            name="speedtest_intro_table",
+            data=data,
+            show_header_controls=False,
+            fix_columns=1,
+        )
+
+    @property
+    def inference_time_md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_speedtest_overview_ms.format(100)
+        return MarkdownWidget(
+            name="inference_time_md",
+            title="Overview",
+            text=text,
+        )
+
+    @property
+    def fps_md(self) -> MarkdownWidget:
+        text = self.vis_texts.markdown_speedtest_overview_fps.format(100)
+        return MarkdownWidget(
+            name="fps_md",
+            title="FPS Table",
+            text=text,
+        )
+
+    @property
+    def fps_table(self) -> TableWidget:
+        data = {}
+        batch_sizes = set()
+        max_fps = 0
+        for i, eval_result in enumerate(self.eval_results, 1):
+            data[i] = {}
+            if eval_result.speedtest_info is None:
+                continue
+            speedtests = eval_result.speedtest_info["speedtest"]
+            for test in speedtests:
+                batch_size = test["batch_size"]
+                fps = round(1000 / test["benchmark"]["total"] * batch_size)
+                batch_sizes.add(batch_size)
+                max_fps = max(max_fps, fps)
+                data[i][batch_size] = fps
+
+        batch_sizes = sorted(batch_sizes)
+        columns = ["Model"]
+        columns_options = [{"disableSort": True}]
+        for batch_size in batch_sizes:
+            columns.append(f"Batch size {batch_size}")
+            columns_options.append(
+                {
+                    "subtitle": "imgs/sec",
+                    "tooltip": "Frames (images) per second",
+                    "postfix": "fps",
+                    # "maxValue": max_fps,
+                }
+            )
+
+        content = []
+        for i, eval_result in enumerate(self.eval_results, 1):
+            name = f"[{i}] {eval_result.name}"
+            row = [name]
+            for batch_size in batch_sizes:
+                if batch_size in data[i]:
+                    row.append(data[i][batch_size])
+                else:
+                    row.append("―")
+            content.append(
+                {
+                    "row": row,
+                    "id": name,
+                    "items": row,
+                }
+            )
+        data = {
+            "columns": columns,
+            "columnsOptions": columns_options,
+            "content": content,
+        }
+        return TableWidget(
+            name="fps_table",
+            data=data,
+            show_header_controls=False,
+            fix_columns=1,
+        )
+
+    @property
+    def inference_time_table(self) -> TableWidget:
+        data = {}
+        batch_sizes = set()
+        for i, eval_result in enumerate(self.eval_results, 1):
+            data[i] = {}
+            if eval_result.speedtest_info is None:
+                continue
+            speedtests = eval_result.speedtest_info["speedtest"]
+            for test in speedtests:
+                batch_size = test["batch_size"]
+                ms = round(test["benchmark"]["total"], 2)
+                batch_sizes.add(batch_size)
+                data[i][batch_size] = ms
+
+        batch_sizes = sorted(batch_sizes)
+        columns = ["Model"]
+        columns_options = [{"disableSort": True}]
+        for batch_size in batch_sizes:
+            columns.extend([f"Batch size {batch_size}"])
+            columns_options.extend(
+                [
+                    {"subtitle": "ms", "tooltip": "Milliseconds for batch images", "postfix": "ms"},
+                ]
+            )
+
+        content = []
+        for i, eval_result in enumerate(self.eval_results, 1):
+            name = f"[{i}] {eval_result.name}"
+            row = [name]
+            for batch_size in batch_sizes:
+                if batch_size in data[i]:
+                    row.append(data[i][batch_size])
+                else:
+                    row.append("―")
+            content.append(
+                {
+                    "row": row,
+                    "id": name,
+                    "items": row,
+                }
+            )
+
+        data = {
+            "columns": columns,
+            "columnsOptions": columns_options,
+            "content": content,
+        }
+        return TableWidget(
+            name="inference_time_md",
+            data=data,
+            show_header_controls=False,
+            fix_columns=1,
+        )
+
+    @property
+    def batch_inference_md(self):
+        return MarkdownWidget(
+            name="batch_inference",
+            title="Batch Inference",
+            text=self.vis_texts.markdown_batch_inference,
+        )
+
+    @property
+    def chart(self) -> ChartWidget:
+        return ChartWidget(name="speed_charts", figure=self.get_figure())
+
+    def get_figure(self):  # -> Optional[go.Figure]
+        import plotly.graph_objects as go  # pylint: disable=import-error
+        from plotly.subplots import make_subplots  # pylint: disable=import-error
+
+        fig = make_subplots(cols=2)
+
+        for idx, eval_result in enumerate(self.eval_results, 1):
+            if eval_result.speedtest_info is None:
+                continue
+            temp_res = {}
+            for test in eval_result.speedtest_info["speedtest"]:
+                batch_size = test["batch_size"]
+
+                std = test["benchmark_std"]["total"]
+                ms = test["benchmark"]["total"]
+                fps = round(1000 / test["benchmark"]["total"] * batch_size)
+
+                ms_line = temp_res.setdefault("ms", {})
+                fps_line = temp_res.setdefault("fps", {})
+                ms_std_line = temp_res.setdefault("ms_std", {})
+
+                ms_line[batch_size] = ms
+                fps_line[batch_size] = fps
+                ms_std_line[batch_size] = round(std, 2)
+
+            error_color = "rgba(" + ",".join(map(str, hex2rgb(eval_result.color))) + ", 0.5)"
+            fig.add_trace(
+                go.Scatter(
+                    x=list(temp_res["ms"].keys()),
+                    y=list(temp_res["ms"].values()),
+                    name=f"[{idx}] {eval_result.name} (ms)",
+                    line=dict(color=eval_result.color),
+                    customdata=list(temp_res["ms_std"].values()),
+                    error_y=dict(
+                        type="data",
+                        array=list(temp_res["ms_std"].values()),
+                        visible=True,
+                        color=error_color,
+                    ),
+                    hovertemplate="Batch Size: %{x}<br>Time: %{y:.2f} ms<br> Standard deviation: %{customdata:.2f} ms<extra></extra>",
+                ),
+                col=1,
+                row=1,
+            )
+            fig.add_trace(
+                go.Scatter(
+                    x=list(temp_res["fps"].keys()),
+                    y=list(temp_res["fps"].values()),
+                    name=f"[{idx}] {eval_result.name} (fps)",
+                    line=dict(color=eval_result.color),
+                    hovertemplate="Batch Size: %{x}<br>FPS: %{y:.2f}<extra></extra>",  # <br> Standard deviation: %{customdata:.2f}<extra></extra>",
+                ),
+                col=2,
+                row=1,
+            )
+
+        fig.update_xaxes(title_text="Batch size", col=1, dtick=1)
+        fig.update_xaxes(title_text="Batch size", col=2, dtick=1)
+
+        fig.update_yaxes(title_text="Time (ms)", col=1)
+        fig.update_yaxes(title_text="FPS", col=2)
+        fig.update_layout(height=400)
+
+        return fig
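The FPS figures in `fps_table` and `get_figure` above are derived from the measured per-batch time as `fps = 1000 / total_ms * batch_size`, while the `latency` property reports the batch-size-1 time directly. A standalone check with invented speedtest entries (the dict shape mirrors `speedtest_info["speedtest"]`; the numbers are made up):

```python
# Invented entries shaped like eval_result.speedtest_info["speedtest"].
speedtest = [
    {"batch_size": 1, "benchmark": {"total": 25.0}},   # 25 ms per single image
    {"batch_size": 8, "benchmark": {"total": 120.0}},  # 120 ms per batch of 8 images
]

for test in speedtest:
    batch_size = test["batch_size"]
    total_ms = test["benchmark"]["total"]
    fps = round(1000 / total_ms * batch_size)  # same formula as fps_table / get_figure
    print(f"batch {batch_size}: {total_ms} ms -> {fps} imgs/sec")
# batch 1: 25.0 ms -> 40 imgs/sec
# batch 8: 120.0 ms -> 67 imgs/sec
```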