supervisely 6.73.253__py3-none-any.whl → 6.73.255__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (60)
  1. supervisely/api/file_api.py +16 -5
  2. supervisely/api/task_api.py +4 -2
  3. supervisely/app/widgets/field/field.py +10 -7
  4. supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
  5. supervisely/convert/image/sly/sly_image_converter.py +1 -1
  6. supervisely/nn/benchmark/base_benchmark.py +33 -35
  7. supervisely/nn/benchmark/base_evaluator.py +27 -1
  8. supervisely/nn/benchmark/base_visualizer.py +8 -11
  9. supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
  10. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
  11. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
  12. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
  13. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
  14. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
  15. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
  16. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
  17. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
  18. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
  19. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
  20. supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
  21. supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
  22. supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
  23. supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
  24. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
  25. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
  26. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
  27. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
  28. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
  29. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
  30. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
  31. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
  32. supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
  33. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
  34. supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
  35. supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
  36. supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
  37. supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
  38. supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
  39. supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
  40. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
  42. supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
  43. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
  44. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
  45. supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
  46. supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
  47. supervisely/nn/benchmark/visualization/renderer.py +25 -10
  48. supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
  49. supervisely/nn/inference/inference.py +1 -0
  50. supervisely/nn/training/gui/gui.py +32 -10
  51. supervisely/nn/training/gui/training_artifacts.py +145 -0
  52. supervisely/nn/training/gui/training_process.py +3 -19
  53. supervisely/nn/training/train_app.py +179 -70
  54. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/METADATA +1 -1
  55. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/RECORD +59 -47
  56. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
  57. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/LICENSE +0 -0
  58. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/WHEEL +0 -0
  59. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/entry_points.txt +0 -0
  60. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,68 @@
1
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
2
+ from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
3
+
4
+
5
class IntersectionErrorOverUnion(BaseVisMetrics):
    """Comparison widget: one donut chart per evaluated model.

    Each donut shows the four values ``mp.iou``, ``mp.boundary_eou``,
    ``mp.extent_eou`` and ``mp.segment_eou`` of that model's metric provider,
    laid out on a grid of "domain" subplots with a short model caption
    centred in each hole.
    """

    @property
    def md(self) -> MarkdownWidget:
        """Markdown block introducing the IoU / EoU donut charts."""
        title = "Intersection & Error Over Union"
        return MarkdownWidget(
            "intersection_error_over_union",
            title,
            text=self.vis_texts.markdown_iou,
        )

    @property
    def chart(self) -> ChartWidget:
        """Chart widget wrapping the figure built by ``get_figure``."""
        figure = self.get_figure()
        return ChartWidget("intersection_error_over_union", figure)

    def get_figure(self):
        """Build the plotly figure with one pie (donut) subplot per model."""
        import plotly.graph_objects as go  # pylint: disable=import-error
        from plotly.subplots import make_subplots  # pylint: disable=import-error

        pie_labels = ["mIoU", "mBoundaryEoU", "mExtentEoU", "mSegmentEoU"]

        n_models = len(self.eval_results)
        # Grid width: 4 columns when the model count is a multiple of 4,
        # otherwise 3 columns for 3+ models and 2 columns for fewer.
        if n_models % 4 == 0:
            n_cols = 4
        elif n_models > 2:
            n_cols = 3
        else:
            n_cols = 2
        full_rows, leftover = divmod(n_models, n_cols)
        n_rows = full_rows + (1 if leftover else 0)

        fig = make_subplots(rows=n_rows, cols=n_cols, specs=[[{"type": "domain"}] * n_cols] * n_rows)

        annotations = []
        for position, eval_result in enumerate(self.eval_results):
            # Fill the grid row by row; plotly rows/cols are 1-based.
            row_offset, col_offset = divmod(position, n_cols)
            row = row_offset + 1
            col = col_offset + 1

            metrics = eval_result.mp
            fig.add_trace(
                go.Pie(
                    labels=pie_labels,
                    values=[
                        metrics.iou,
                        metrics.boundary_eou,
                        metrics.extent_eou,
                        metrics.segment_eou,
                    ],
                    hole=0.5,
                    textposition="outside",
                    textinfo="percent+label",
                    marker=dict(colors=["#8ACAA1", "#FFE4B5", "#F7ADAA", "#dd3f3f"]),
                ),
                row=row,
                col=col,
            )

            # Caption: "[n] " plus the first 7 characters of the model name,
            # with "..." appended when the name was truncated.
            full_name = eval_result.name
            suffix = "..." if len(full_name) > 7 else ""
            caption = f"[{position + 1}] {full_name[:7]}" + suffix

            cell = fig.get_subplot(row, col)
            annotations.append(
                dict(
                    text=caption,
                    x=sum(cell.x) / 2,
                    y=sum(cell.y) / 2,
                    showarrow=False,
                    xanchor="center",
                )
            )
        fig.update_layout(annotations=annotations)

        return fig
@@ -0,0 +1,223 @@
1
+ from typing import List
2
+
3
+ from supervisely._utils import abs_url
4
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
5
+ from supervisely.nn.benchmark.visualization.evaluation_result import EvalResult
6
+ from supervisely.nn.benchmark.visualization.widgets import (
7
+ ChartWidget,
8
+ MarkdownWidget,
9
+ TableWidget,
10
+ )
11
+
12
+
13
class Overview(BaseVisMetrics):
    """Overview section of the semantic segmentation model-comparison report.

    Provides the common header block, one info block per compared model,
    the key-metrics table and a radar chart of each model's key metrics.
    """

    MARKDOWN_OVERVIEW = "markdown_overview"
    MARKDOWN_OVERVIEW_INFO = "markdown_overview_info"
    MARKDOWN_COMMON_OVERVIEW = "markdown_common_overview"
    CHART = "chart_key_metrics"

    def __init__(self, vis_texts, eval_results: List[EvalResult]) -> None:
        super().__init__(vis_texts, eval_results)

    @staticmethod
    def _escape_md(text: str) -> str:
        """Backslash-escape underscores so Markdown shows them literally."""
        # "\\_" is a backslash followed by "_" -- the same string the former
        # "\_" literal produced, but without relying on an invalid escape
        # sequence (SyntaxWarning on Python 3.12+).
        return text.replace("_", "\\_")

    @property
    def overview_md(self) -> MarkdownWidget:
        """Common overview block shared by all compared models.

        The ``markdown_common_overview`` template is formatted with
        ``(model title, GT project id, GT project name, task type)``.
        The title is "Custom models" when no model has a name, the single
        name when all names are equal, otherwise the names joined by " vs. ".
        """
        model_names = [self._escape_md(r.name or "Custom") for r in self.eval_results]
        if all(name == "Custom" for name in model_names):
            title_name = "Custom models"
        elif all(name == model_names[0] for name in model_names):
            title_name = model_names[0]
        else:
            title_name = " vs. ".join(model_names)

        # Only the first result's project/task details are shown.
        # NOTE(review): presumably all compared models share one GT project — confirm.
        first = self.eval_results[0]
        info = [
            title_name,
            first.gt_project_info.id,
            first.gt_project_info.name,
            first.inference_info.get("task_type"),
        ]

        text_template: str = getattr(self.vis_texts, self.MARKDOWN_COMMON_OVERVIEW)
        return MarkdownWidget(
            name=self.MARKDOWN_COMMON_OVERVIEW,
            title="Overview",
            text=text_template.format(*info),
        )

    @property
    def overview_widgets(self) -> List[MarkdownWidget]:
        """One info block per compared model.

        The ``markdown_overview_info`` template is formatted with
        ``(checkpoint name, model name, escaped checkpoint name, architecture,
        runtime, checkpoint url, link text, report link)``.
        """
        all_formats = []
        for eval_result in self.eval_results:
            inference_info = eval_result.inference_info
            url = inference_info.get("checkpoint_url")
            # Prefer the custom checkpoint path as link text; fall back to the URL.
            link_text = inference_info.get("custom_checkpoint_path")
            if link_text is None:
                link_text = url
            if link_text is not None:
                # Both keys may be absent; escaping None used to raise AttributeError.
                link_text = self._escape_md(link_text)

            checkpoint_name = eval_result.checkpoint_name
            model_name = inference_info.get("model_name") or "Custom"

            report = eval_result.api.file.get_info_by_path(self.team_id, eval_result.report_path)
            report_link = abs_url(f"/model-benchmark?id={report.id}")

            all_formats.append(
                [
                    checkpoint_name,
                    self._escape_md(model_name),
                    self._escape_md(checkpoint_name),
                    inference_info.get("architecture"),
                    inference_info.get("runtime"),
                    url,
                    link_text,
                    report_link,
                ]
            )

        text_template: str = getattr(self.vis_texts, self.MARKDOWN_OVERVIEW_INFO)
        widgets = []
        for formats in all_formats:
            md = MarkdownWidget(
                name=self.MARKDOWN_OVERVIEW_INFO,
                title="Overview",
                text=text_template.format(*formats),
            )
            md.is_info_block = True
            widgets.append(md)
        return widgets

    def get_table_widget(self, latency, fps) -> TableWidget:
        """Key-metrics table: one row per metric plus latency and FPS rows.

        :param latency: one value per evaluation result, in ``eval_results`` order.
        :param fps: one value per evaluation result, in ``eval_results`` order.
        """
        res = {}

        columns = ["metrics"] + [f"[{i+1}] {r.name}" for i, r in enumerate(self.eval_results)]

        all_metrics = [eval_result.mp.key_metrics() for eval_result in self.eval_results]
        res["content"] = []

        # Row order follows the first model's key_metrics() keys.
        for metric in all_metrics[0].keys():
            values = [m[metric] for m in all_metrics]
            values = [v if v is not None else "―" for v in values]
            values = [round(v, 2) if isinstance(v, float) else v for v in values]
            row = [metric] + values
            dct = {"row": row, "id": metric, "items": row}
            res["content"].append(dct)

        latency_row = ["Latency (ms)"] + latency
        res["content"].append({"row": latency_row, "id": latency_row[0], "items": latency_row})

        fps_row = ["FPS"] + fps
        res["content"].append({"row": fps_row, "id": fps_row[0], "items": fps_row})

        columns_options = [{"disableSort": True} for _ in columns]

        res["columns"] = columns
        res["columnsOptions"] = columns_options

        return TableWidget(
            name="table_key_metrics",
            data=res,
            show_header_controls=False,
            fix_columns=1,
            page_size=len(res["content"]),
        )

    @property
    def chart_widget(self) -> ChartWidget:
        """Radar chart of key metrics (see ``get_figure``)."""
        return ChartWidget(name=self.CHART, figure=self.get_figure())

    def get_overview_info(self, eval_result: EvalResult):
        """Return ``(classes_str, images_str, train_session)`` for *eval_result*.

        ``classes_str`` is e.g. "3 classes". ``images_str`` summarizes train /
        validation image counts (when train info exists) and which part of the
        GT project was evaluated, with HTML links to datasets. ``train_session``
        is a Markdown bullet linking the training dashboard, or "" when no
        training session id is known.
        """
        classes_cnt = len(eval_result.classes_whitelist)
        classes_str = "classes" if classes_cnt > 1 else "class"
        classes_str = f"{classes_cnt} {classes_str}"

        train_session, images_str = "", ""
        gt_project_id = eval_result.gt_project_info.id
        gt_dataset_ids = eval_result.gt_dataset_ids
        gt_images_cnt = eval_result.val_images_cnt
        train_info = eval_result.train_info
        total_imgs_cnt = eval_result.gt_project_info.items_count
        if gt_images_cnt is not None:
            val_imgs_cnt = gt_images_cnt
        elif gt_dataset_ids is not None:
            datasets = eval_result.gt_dataset_infos
            val_imgs_cnt = sum(ds.items_count for ds in datasets)
        else:
            val_imgs_cnt = eval_result.gt_project_info.items_count

        if train_info:
            train_task_id = train_info.get("app_session_id")
            if train_task_id:
                task_info = eval_result.api.task.get_info_by_id(int(train_task_id))
                app_id = task_info["meta"]["app"]["id"]
                train_session = f'- **Training dashboard**: <a href="/apps/{app_id}/sessions/{train_task_id}" target="_blank">open</a>'

            train_imgs_cnt = train_info.get("images_count")
            images_str = f", {train_imgs_cnt} images in train, {val_imgs_cnt} images in validation"

        # The branches mirror the val_imgs_cnt selection above, so `datasets`
        # is always bound when the dataset branch runs.
        if gt_images_cnt is not None:
            images_str += (
                f", total {total_imgs_cnt} images. Evaluated using subset - {val_imgs_cnt} images"
            )
        elif gt_dataset_ids is not None:
            links = [
                f'<a href="/projects/{gt_project_id}/datasets/{ds.id}" target="_blank">{ds.name}</a>'
                for ds in datasets
            ]
            images_str += f", total {total_imgs_cnt} images. Evaluated on the dataset{'s' if len(links) > 1 else ''}: {', '.join(links)}"
        else:
            images_str += f", total {total_imgs_cnt} images. Evaluated on the whole project ({val_imgs_cnt} images)"

        return classes_str, images_str, train_session

    def get_figure(self):
        """Build a polar (radar) chart with one closed trace per model.

        "mPixel accuracy" is multiplied by 100 before plotting.
        NOTE(review): presumably the other key metrics are already on a
        0-100 scale (the radial axis spans 0-105) — confirm.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error

        # Overall Metrics
        fig = go.Figure()
        for i, eval_result in enumerate(self.eval_results):
            name = f"[{i + 1}] {eval_result.name}"
            base_metrics = eval_result.mp.key_metrics().copy()
            base_metrics["mPixel accuracy"] = round(base_metrics["mPixel accuracy"] * 100, 2)
            r = list(base_metrics.values())
            theta = list(base_metrics.keys())
            fig.add_trace(
                go.Scatterpolar(
                    # Repeat the first point to close the polygon.
                    r=r + [r[0]],
                    theta=theta + [theta[0]],
                    name=name,
                    marker=dict(color=eval_result.color),
                    hovertemplate=name + "<br>%{theta}: %{r:.2f}<extra></extra>",
                )
            )
        fig.update_layout(
            polar=dict(
                radialaxis=dict(
                    range=[0, 105],
                    ticks="outside",
                ),
                angularaxis=dict(rotation=90, direction="clockwise"),
            ),
            dragmode=False,
            height=500,
            margin=dict(l=25, r=25, t=25, b=25),
            modebar=dict(
                remove=[
                    "zoom2d",
                    "pan2d",
                    "select2d",
                    "lasso2d",
                    "zoomIn2d",
                    "zoomOut2d",
                    "autoScale2d",
                    "resetScale2d",
                ]
            ),
        )
        return fig
@@ -0,0 +1,57 @@
1
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
2
+ from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
3
+
4
+
5
class RenormalizedErrorOverUnion(BaseVisMetrics):
    """Grouped bar chart of renormalized Boundary / Extent / Segment EoU per model."""

    @property
    def md(self) -> MarkdownWidget:
        """Markdown block introducing the renormalized EoU chart."""
        return MarkdownWidget(
            "renormalized_error_over_union",
            "Renormalized Error over Union",
            text=self.vis_texts.markdown_renormalized_error_ou,
        )

    @property
    def chart(self) -> ChartWidget:
        """Chart widget wrapping the figure built by ``get_figure``."""
        # Named after this metric, like its markdown block. It previously
        # reused "intersection_error_over_union", the name of the IoU/EoU
        # donut-chart widgets.
        return ChartWidget("renormalized_error_over_union", self.get_figure())

    def get_figure(self):
        """Build a grouped bar chart with one bar series per model."""
        import plotly.graph_objects as go  # pylint: disable=import-error

        fig = go.Figure()

        labels = ["Boundary EoU", "Extent EoU", "Segment EoU"]

        for idx, eval_result in enumerate(self.eval_results, 1):
            mp = eval_result.mp
            values = [
                mp.boundary_renormed_eou,
                mp.extent_renormed_eou,
                mp.segment_renormed_eou,
            ]

            fig.add_trace(
                go.Bar(
                    x=labels,
                    y=values,
                    name=f"[{idx}] {eval_result.short_name}",
                    # Show each value above its bar.
                    text=values,
                    textposition="outside",
                    marker=dict(color=eval_result.color, line=dict(width=0.7)),
                    width=0.4,
                )
            )

        fig.update_traces(hovertemplate="%{x}: %{y:.2f}<extra></extra>")
        fig.update_layout(
            barmode="group",
            bargap=0.15,
            bargroupgap=0.05,
            # Widen the chart when many models are compared.
            width=700 if len(self.eval_results) < 4 else 1000,
        )

        return fig
@@ -0,0 +1,314 @@
1
+ from typing import List, Union
2
+
3
+ from supervisely.imaging.color import hex2rgb
4
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
5
+ from supervisely.nn.benchmark.visualization.widgets import (
6
+ ChartWidget,
7
+ MarkdownWidget,
8
+ TableWidget,
9
+ )
10
+
11
+
12
class Speedtest(BaseVisMetrics):
    """Inference-speed section of the model comparison report.

    Each evaluation result may carry ``speedtest_info`` whose ``"speedtest"``
    entry is a list of runs. Each run provides ``"batch_size"``,
    ``["benchmark"]["total"]`` and ``["benchmark_std"]["total"]`` (in
    milliseconds, per the column/axis labels below). Results without
    ``speedtest_info`` are rendered as "N/A" or "―".
    """

    def is_empty(self) -> bool:
        """Return True when no evaluation result has speed-test data."""
        return not any(eval_result.speedtest_info for eval_result in self.eval_results)

    def multiple_batche_sizes(self) -> bool:
        """Return True if any model has more than one speed-test run.

        Presumably each run uses a different batch size, hence the name.
        """
        for eval_result in self.eval_results:
            if eval_result.speedtest_info is None:
                continue
            if len(eval_result.speedtest_info["speedtest"]) > 1:
                return True
        return False

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    multiple_batch_sizes = multiple_batche_sizes

    @staticmethod
    def _batch1_total_ms(eval_result):
        """Return ``benchmark["total"]`` of the first batch-size-1 run, or None."""
        if eval_result.speedtest_info is None:
            return None
        for test in eval_result.speedtest_info["speedtest"]:
            if test["batch_size"] == 1:
                return test["benchmark"]["total"]
        return None

    @property
    def latency(self) -> List[Union[float, str]]:
        """Batch-size-1 time (ms, rounded to 2 decimals) per model, or "N/A"."""
        latency = []
        for eval_result in self.eval_results:
            total_ms = self._batch1_total_ms(eval_result)
            latency.append("N/A" if total_ms is None else round(total_ms, 2))
        return latency

    @property
    def fps(self) -> List[Union[float, str]]:
        """Batch-size-1 throughput (1000 / ms, rounded to 2 decimals) per model, or "N/A"."""
        fps = []
        for eval_result in self.eval_results:
            total_ms = self._batch1_total_ms(eval_result)
            # No batch-1 run, or a zero time, cannot give an FPS value
            # (a zero time would otherwise raise ZeroDivisionError).
            fps.append("N/A" if not total_ms else round(1000 / total_ms, 2))
        return fps

    @property
    def md_intro(self) -> MarkdownWidget:
        """Markdown block introducing the inference-speed section."""
        return MarkdownWidget(
            name="speedtest_intro",
            title="Inference Speed",
            text=self.vis_texts.markdown_speedtest_intro,
        )

    @property
    def intro_table(self) -> TableWidget:
        """Table with the device, hardware and runtime of each model's speed test."""
        columns = ["Model", "Device", "Hardware", "Runtime"]
        columns_options = [{"disableSort": True} for _ in columns]
        content = []
        for i, eval_result in enumerate(self.eval_results, 1):
            name = f"[{i}] {eval_result.name}"
            if eval_result.speedtest_info is None:
                row = [name, "N/A", "N/A", "N/A"]
            else:
                model_info = eval_result.speedtest_info.get("model_info", {})
                row = [
                    name,
                    model_info.get("device", "N/A"),
                    model_info.get("hardware", "N/A"),
                    model_info.get("runtime", "N/A"),
                ]
            content.append({"row": row, "id": name, "items": row})

        data = {
            "columns": columns,
            "columnsOptions": columns_options,
            "content": content,
        }
        return TableWidget(
            name="speedtest_intro_table",
            data=data,
            show_header_controls=False,
            fix_columns=1,
        )

    @property
    def inference_time_md(self) -> MarkdownWidget:
        """Markdown block introducing the inference-time table."""
        # NOTE(review): 100 is hard-coded into the template (presumably the
        # number of benchmark iterations) — confirm against the speed test.
        text = self.vis_texts.markdown_speedtest_overview_ms.format(100)
        return MarkdownWidget(
            name="inference_time_md",
            title="Overview",
            text=text,
        )

    @property
    def fps_md(self) -> MarkdownWidget:
        """Markdown block introducing the FPS table."""
        text = self.vis_texts.markdown_speedtest_overview_fps.format(100)
        return MarkdownWidget(
            name="fps_md",
            title="FPS Table",
            text=text,
        )

    def _per_batch_table(self, widget_name: str, cell_value, column_option: dict) -> TableWidget:
        """Build a table with one row per model and one column per batch size.

        :param widget_name: name of the resulting TableWidget.
        :param cell_value: callable mapping one speed-test run to its cell value.
        :param column_option: options copied into every batch-size column.
        Cells for batch sizes a model was not tested with show "―".
        """
        values = {}
        batch_sizes = set()
        for i, eval_result in enumerate(self.eval_results, 1):
            values[i] = {}
            if eval_result.speedtest_info is None:
                continue
            for test in eval_result.speedtest_info["speedtest"]:
                batch_size = test["batch_size"]
                batch_sizes.add(batch_size)
                values[i][batch_size] = cell_value(test)

        batch_sizes = sorted(batch_sizes)
        columns = ["Model"] + [f"Batch size {batch_size}" for batch_size in batch_sizes]
        columns_options = [{"disableSort": True}] + [dict(column_option) for _ in batch_sizes]

        content = []
        for i, eval_result in enumerate(self.eval_results, 1):
            name = f"[{i}] {eval_result.name}"
            row = [name] + [values[i].get(batch_size, "―") for batch_size in batch_sizes]
            content.append({"row": row, "id": name, "items": row})

        data = {
            "columns": columns,
            "columnsOptions": columns_options,
            "content": content,
        }
        return TableWidget(
            name=widget_name,
            data=data,
            show_header_controls=False,
            fix_columns=1,
        )

    @property
    def fps_table(self) -> TableWidget:
        """Throughput table: images per second for each model and batch size."""
        return self._per_batch_table(
            "fps_table",
            # Images per second for the whole batch, rounded to an integer.
            lambda test: round(1000 / test["benchmark"]["total"] * test["batch_size"]),
            {
                "subtitle": "imgs/sec",
                "tooltip": "Frames (images) per second",
                "postfix": "fps",
            },
        )

    @property
    def inference_time_table(self) -> TableWidget:
        """Latency table: milliseconds per batch for each model and batch size."""
        # NOTE(review): this table shares the "inference_time_md" name with the
        # markdown widget above; kept unchanged in case consumers rely on it.
        return self._per_batch_table(
            "inference_time_md",
            lambda test: round(test["benchmark"]["total"], 2),
            {"subtitle": "ms", "tooltip": "Milliseconds for batch images", "postfix": "ms"},
        )

    @property
    def batch_inference_md(self):
        """Markdown block introducing the batch-inference charts."""
        return MarkdownWidget(
            name="batch_inference",
            title="Batch Inference",
            text=self.vis_texts.markdown_batch_inference,
        )

    @property
    def chart(self) -> ChartWidget:
        """Chart widget wrapping the figure built by ``get_figure``."""
        return ChartWidget(name="speed_charts", figure=self.get_figure())

    def get_figure(self):
        """Build two side-by-side line charts: time (ms, with std error bars) and FPS vs batch size.

        Models without speed-test data, or with an empty run list, are skipped.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error
        from plotly.subplots import make_subplots  # pylint: disable=import-error

        fig = make_subplots(cols=2)

        for idx, eval_result in enumerate(self.eval_results, 1):
            if eval_result.speedtest_info is None:
                continue
            ms_by_bs, fps_by_bs, std_by_bs = {}, {}, {}
            for test in eval_result.speedtest_info["speedtest"]:
                batch_size = test["batch_size"]
                std = test["benchmark_std"]["total"]
                total_ms = test["benchmark"]["total"]
                ms_by_bs[batch_size] = total_ms
                fps_by_bs[batch_size] = round(1000 / total_ms * batch_size)
                std_by_bs[batch_size] = round(std, 2)
            if not ms_by_bs:
                # Empty run list: nothing to plot (previously raised KeyError).
                continue

            # Plot points in batch-size order so the lines do not zigzag.
            batch_sizes = sorted(ms_by_bs)
            ms_values = [ms_by_bs[bs] for bs in batch_sizes]
            fps_values = [fps_by_bs[bs] for bs in batch_sizes]
            std_values = [std_by_bs[bs] for bs in batch_sizes]

            error_color = "rgba(" + ",".join(map(str, hex2rgb(eval_result.color))) + ", 0.5)"
            fig.add_trace(
                go.Scatter(
                    x=batch_sizes,
                    y=ms_values,
                    name=f"[{idx}] {eval_result.name} (ms)",
                    line=dict(color=eval_result.color),
                    customdata=std_values,
                    error_y=dict(
                        type="data",
                        array=std_values,
                        visible=True,
                        color=error_color,
                    ),
                    hovertemplate="Batch Size: %{x}<br>Time: %{y:.2f} ms<br> Standard deviation: %{customdata:.2f} ms<extra></extra>",
                ),
                col=1,
                row=1,
            )
            fig.add_trace(
                go.Scatter(
                    x=batch_sizes,
                    y=fps_values,
                    name=f"[{idx}] {eval_result.name} (fps)",
                    line=dict(color=eval_result.color),
                    hovertemplate="Batch Size: %{x}<br>FPS: %{y:.2f}<extra></extra>",
                ),
                col=2,
                row=1,
            )

        fig.update_xaxes(title_text="Batch size", col=1, dtick=1)
        fig.update_xaxes(title_text="Batch size", col=2, dtick=1)

        fig.update_yaxes(title_text="Time (ms)", col=1)
        fig.update_yaxes(title_text="FPS", col=2)
        fig.update_layout(height=400)

        return fig