supervisely 6.73.214__py3-none-any.whl → 6.73.215__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in those registries.

Potentially problematic release.
Files changed (49)
  1. supervisely/app/widgets/report_thumbnail/report_thumbnail.py +17 -5
  2. supervisely/app/widgets/team_files_selector/team_files_selector.py +3 -0
  3. supervisely/nn/benchmark/comparison/__init__.py +0 -0
  4. supervisely/nn/benchmark/comparison/detection_visualization/__init__.py +0 -0
  5. supervisely/nn/benchmark/comparison/detection_visualization/text_templates.py +437 -0
  6. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +27 -0
  7. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +125 -0
  8. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +224 -0
  9. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/explore_predicttions.py +112 -0
  10. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +161 -0
  11. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +336 -0
  12. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +249 -0
  13. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +142 -0
  14. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +300 -0
  15. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +308 -0
  16. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +19 -0
  17. supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +298 -0
  18. supervisely/nn/benchmark/comparison/model_comparison.py +84 -0
  19. supervisely/nn/benchmark/evaluation/coco/metric_provider.py +9 -7
  20. supervisely/nn/benchmark/visualization/evaluation_result.py +266 -0
  21. supervisely/nn/benchmark/visualization/renderer.py +100 -0
  22. supervisely/nn/benchmark/visualization/report_template.html +46 -0
  23. supervisely/nn/benchmark/visualization/visualizer.py +1 -1
  24. supervisely/nn/benchmark/visualization/widgets/__init__.py +17 -0
  25. supervisely/nn/benchmark/visualization/widgets/chart/__init__.py +0 -0
  26. supervisely/nn/benchmark/visualization/widgets/chart/chart.py +72 -0
  27. supervisely/nn/benchmark/visualization/widgets/chart/template.html +16 -0
  28. supervisely/nn/benchmark/visualization/widgets/collapse/__init__.py +0 -0
  29. supervisely/nn/benchmark/visualization/widgets/collapse/collapse.py +33 -0
  30. supervisely/nn/benchmark/visualization/widgets/container/__init__.py +0 -0
  31. supervisely/nn/benchmark/visualization/widgets/container/container.py +54 -0
  32. supervisely/nn/benchmark/visualization/widgets/gallery/__init__.py +0 -0
  33. supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +125 -0
  34. supervisely/nn/benchmark/visualization/widgets/gallery/template.html +49 -0
  35. supervisely/nn/benchmark/visualization/widgets/markdown/__init__.py +0 -0
  36. supervisely/nn/benchmark/visualization/widgets/markdown/markdown.py +53 -0
  37. supervisely/nn/benchmark/visualization/widgets/notification/__init__.py +0 -0
  38. supervisely/nn/benchmark/visualization/widgets/notification/notification.py +38 -0
  39. supervisely/nn/benchmark/visualization/widgets/sidebar/__init__.py +0 -0
  40. supervisely/nn/benchmark/visualization/widgets/sidebar/sidebar.py +67 -0
  41. supervisely/nn/benchmark/visualization/widgets/table/__init__.py +0 -0
  42. supervisely/nn/benchmark/visualization/widgets/table/table.py +116 -0
  43. supervisely/nn/benchmark/visualization/widgets/widget.py +22 -0
  44. {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/METADATA +1 -1
  45. {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/RECORD +49 -10
  46. {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/LICENSE +0 -0
  47. {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/WHEEL +0 -0
  48. {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/entry_points.txt +0 -0
  49. {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/top_level.txt +0 -0
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py (new file)
@@ -0,0 +1,142 @@
+ import numpy as np
+
+ from supervisely.imaging.color import hex2rgb
+ from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
+     BaseVisMetric,
+ )
+ from supervisely.nn.benchmark.visualization.widgets import (
+     ChartWidget,
+     CollapseWidget,
+     MarkdownWidget,
+     NotificationWidget,
+     TableWidget,
+ )
+
+
+ class PrCurve(BaseVisMetric):
+     MARKDOWN_PR_CURVE = "markdown_pr_curve"
+     MARKDOWN_PR_TRADE_OFFS = "markdown_trade_offs"
+     MARKDOWN_WHAT_IS_PR_CURVE = "markdown_what_is_pr_curve"
+
+     @property
+     def markdown_widget(self) -> MarkdownWidget:
+         text: str = getattr(self.vis_texts, self.MARKDOWN_PR_CURVE).format(
+             self.vis_texts.definitions.f1_score
+         )
+         return MarkdownWidget(
+             name=self.MARKDOWN_PR_CURVE, title="mAP & Precision-Recall Curve", text=text
+         )
+
+     @property
+     def chart_widget(self) -> ChartWidget:
+         return ChartWidget(name="chart_pr_curve", figure=self.get_figure())
+
+     @property
+     def collapsed_widget(self) -> CollapseWidget:
+         text_pr_trade_offs = getattr(self.vis_texts, self.MARKDOWN_PR_TRADE_OFFS)
+         text_pr_curve = getattr(self.vis_texts, self.MARKDOWN_WHAT_IS_PR_CURVE).format(
+             self.vis_texts.definitions.confidence_score,
+             self.vis_texts.definitions.true_positives,
+             self.vis_texts.definitions.false_positives,
+         )
+         markdown_pr_trade_offs = MarkdownWidget(
+             name=self.MARKDOWN_PR_TRADE_OFFS,
+             title="About trade-offs between precision and recall",
+             text=text_pr_trade_offs,
+         )
+         markdown_whatis_pr_curve = MarkdownWidget(
+             name=self.MARKDOWN_WHAT_IS_PR_CURVE,
+             title="How is the PR curve built?",
+             text=text_pr_curve,
+         )
+         return CollapseWidget(widgets=[markdown_pr_trade_offs, markdown_whatis_pr_curve])
+
+     @property
+     def table_widget(self) -> TableWidget:
+         res = {}
+
+         columns = [" ", "mAP (0.5:0.95)", "mAP (0.75)"]
+         res["content"] = []
+         for i, eval_result in enumerate(self.eval_results, 1):
+             value_range = round(eval_result.mp.json_metrics()["mAP"], 2)
+             value_75 = eval_result.mp.json_metrics()["AP75"] or "-"
+             value_75 = round(value_75, 2) if isinstance(value_75, float) else value_75
+             model_name = f"[{i}] {eval_result.name}"
+             row = [model_name, value_range, value_75]
+             dct = {
+                 "row": row,
+                 "id": model_name,
+                 "items": row,
+             }
+             res["content"].append(dct)
+
+         # One options entry per column (the release listed only two entries
+         # for the three columns above).
+         columns_options = [
+             {"disableSort": True},
+             {"disableSort": True},
+             {"disableSort": True},
+         ]
+
+         res["columns"] = columns
+         res["columnsOptions"] = columns_options
+
+         return TableWidget(
+             name="table_pr_curve",
+             data=res,
+             show_header_controls=False,
+             # main_column=columns[0],
+             fix_columns=1,
+         )
+
+     def get_figure(self):  # -> Optional[go.Figure]:
+         import plotly.express as px  # pylint: disable=import-error
+         import plotly.graph_objects as go  # pylint: disable=import-error
+
+         fig = go.Figure()
+
+         rec_thr = self.eval_results[0].mp.recThrs
+         for i, eval_result in enumerate(self.eval_results, 1):
+             pr_curve = eval_result.mp.pr_curve().copy()
+             pr_curve[pr_curve == -1] = np.nan  # -1 marks undefined (threshold, class) cells
+             pr_curve = np.nanmean(pr_curve, axis=-1)  # mean precision over valid classes
+
+             name = f"[{i}] {eval_result.name}"
+             color = ",".join(map(str, hex2rgb(eval_result.color))) + ",0.1"
+             line = go.Scatter(
+                 x=eval_result.mp.recThrs,
+                 y=pr_curve,
+                 mode="lines",
+                 name=name,
+                 fill="tozeroy",
+                 fillcolor=f"rgba({color})",
+                 line=dict(color=eval_result.color),
+                 hovertemplate=name + "<br>Recall: %{x:.2f}<br>Precision: %{y:.2f}<extra></extra>",
+                 showlegend=True,
+             )
+             fig.add_trace(line)
+
+         fig.add_trace(
+             go.Scatter(
+                 x=rec_thr,
+                 y=[1] * len(rec_thr),
+                 name="Perfect",
+                 line=dict(color="orange", dash="dash"),
+                 showlegend=True,
+             )
+         )
+         fig.update_layout(
+             dragmode=False,
+             modebar=dict(
+                 remove=[
+                     "zoom2d",
+                     "pan2d",
+                     "select2d",
+                     "lasso2d",
+                     "zoomIn2d",
+                     "zoomOut2d",
+                     "autoScale2d",
+                     "resetScale2d",
+                 ]
+             ),
+             xaxis_title="Recall",
+             yaxis_title="Precision",
+         )
+         return fig
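
For context on the -1 masking in get_figure above: COCO-style precision arrays store -1 in (recall threshold, class) cells that were never evaluated, so those sentinels must be converted to NaN before averaging across classes, otherwise they drag the mean down. A minimal, self-contained sketch of that step with synthetic data (the 3x3 shape and the values are illustrative assumptions, not the library's actual array layout):

    import numpy as np

    # Synthetic precision matrix: rows = recall thresholds, columns = classes.
    # COCO-style evaluators write -1 where a (threshold, class) cell is undefined.
    pr = np.array(
        [
            [0.90, 0.80, -1.0],
            [0.85, -1.0, -1.0],
            [0.70, 0.60, -1.0],
        ]
    )

    pr[pr == -1] = np.nan              # mask sentinels instead of averaging them in
    mean_pr = np.nanmean(pr, axis=-1)  # per-threshold mean over valid classes only

    print(mean_pr)  # [0.85 0.85 0.65]

A plain mean over the same axis would have produced [0.23, -0.38, 0.1], which is why the masking matters.
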
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py (new file)
@@ -0,0 +1,300 @@
+ from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
+     BaseVisMetric,
+ )
+ from supervisely.nn.benchmark.visualization.widgets import (
+     ChartWidget,
+     CollapseWidget,
+     MarkdownWidget,
+     NotificationWidget,
+     TableWidget,
+ )
+
+
+ class PrecisionRecallF1(BaseVisMetric):
+     MARKDOWN = "markdown_PRF1"
+     MARKDOWN_PRECISION_TITLE = "markdown_precision_per_class_title"
+     MARKDOWN_RECALL_TITLE = "markdown_recall_per_class_title"
+     MARKDOWN_F1_TITLE = "markdown_f1_per_class_title"
+
+     @property
+     def markdown_widget(self) -> MarkdownWidget:
+         text: str = getattr(self.vis_texts, self.MARKDOWN).format(
+             self.vis_texts.definitions.f1_score
+         )
+         return MarkdownWidget(name=self.MARKDOWN, title="Precision, Recall, F1-score", text=text)
+
+     @property
+     def precision_per_class_title_md(self) -> MarkdownWidget:
+         text: str = getattr(self.vis_texts, self.MARKDOWN_PRECISION_TITLE)
+         text += self.vis_texts.clickable_label
+         return MarkdownWidget(
+             name="markdown_precision_per_class", title="Precision by Class", text=text
+         )
+
+     @property
+     def recall_per_class_title_md(self) -> MarkdownWidget:
+         text: str = getattr(self.vis_texts, self.MARKDOWN_RECALL_TITLE)
+         text += self.vis_texts.clickable_label
+         return MarkdownWidget(name="markdown_recall_per_class", title="Recall by Class", text=text)
+
+     @property
+     def f1_per_class_title_md(self) -> MarkdownWidget:
+         text: str = getattr(self.vis_texts, self.MARKDOWN_F1_TITLE)
+         text += self.vis_texts.clickable_label
+         return MarkdownWidget(name="markdown_f1_per_class", title="F1-score by Class", text=text)
+
+     @property
+     def chart_main_widget(self) -> ChartWidget:
+         chart = ChartWidget(name="chart_PRF1", figure=self.get_main_figure())
+         chart.set_click_data(
+             gallery_id=self.explore_modal_table.id,
+             click_data=self.get_click_data_main(),
+             chart_click_extra="'getKey': (payload) => `${payload.points[0].curveNumber}`,",
+         )
+         return chart
+
+     @property
+     def chart_recall_per_class_widget(self) -> ChartWidget:
+         chart = ChartWidget(
+             name="chart_recall_per_class",
+             figure=self.get_recall_per_class_figure(),
+         )
+         chart.set_click_data(
+             gallery_id=self.explore_modal_table.id,
+             click_data=self.get_per_class_click_data(),
+             chart_click_extra="'getKey': (payload) => `${payload.points[0].curveNumber}${'_'}${payload.points[0].label}`,",
+         )
+         return chart
+
+     @property
+     def chart_precision_per_class_widget(self) -> ChartWidget:
+         chart = ChartWidget(
+             name="chart_precision_per_class",
+             figure=self.get_precision_per_class_figure(),
+         )
+         chart.set_click_data(
+             gallery_id=self.explore_modal_table.id,
+             click_data=self.get_per_class_click_data(),
+             chart_click_extra="'getKey': (payload) => `${payload.points[0].curveNumber}${'_'}${payload.points[0].label}`,",
+         )
+         return chart
+
+     @property
+     def chart_f1_per_class_widget(self) -> ChartWidget:
+         chart = ChartWidget(name="chart_f1_per_class", figure=self.get_f1_per_class_figure())
+         chart.set_click_data(
+             gallery_id=self.explore_modal_table.id,
+             click_data=self.get_per_class_click_data(),
+             chart_click_extra="'getKey': (payload) => `${payload.points[0].curveNumber}${'_'}${payload.points[0].label}`,",
+         )
+         return chart
+
+     @property
+     def table_widget(self) -> TableWidget:
+         res = {}
+
+         columns = [" ", "Precision", "Recall", "F1-score"]
+         res["content"] = []
+         for i, eval_result in enumerate(self.eval_results, 1):
+             precision = round(eval_result.mp.json_metrics()["precision"], 2)
+             recall = round(eval_result.mp.json_metrics()["recall"], 2)
+             f1 = round(eval_result.mp.json_metrics()["f1"], 2)
+             model_name = f"[{i}] {eval_result.name}"
+             row = [model_name, precision, recall, f1]
+             dct = {
+                 "row": row,
+                 "id": model_name,
+                 "items": row,
+             }
+             res["content"].append(dct)
+
+         columns_options = [
+             {"disableSort": True},
+             {"disableSort": True},
+             {"disableSort": True},
+             {"disableSort": True},
+         ]
+
+         res["columns"] = columns
+         res["columnsOptions"] = columns_options
+
+         return TableWidget(
+             name="table_precision_recall_f1",
+             data=res,
+             show_header_controls=False,
+             # main_column=columns[0],
+             fix_columns=1,
+         )
+
+     def get_main_figure(self):  # -> Optional[go.Figure]:
+         import plotly.graph_objects as go  # pylint: disable=import-error
+
+         fig = go.Figure()
+
+         classes_cnt = len(self.eval_results[0].mp.cat_names)
+         for i, eval_result in enumerate(self.eval_results, 1):
+             precision = eval_result.mp.json_metrics()["precision"]
+             recall = eval_result.mp.json_metrics()["recall"]
+             f1 = eval_result.mp.json_metrics()["f1"]
+             model_name = f"[{i}] {eval_result.name}"
+             fig.add_trace(
+                 go.Bar(
+                     x=["Precision", "Recall", "F1-score"],
+                     y=[precision, recall, f1],
+                     name=model_name,
+                     width=0.2 if classes_cnt >= 5 else None,
+                     marker=dict(color=eval_result.color),
+                 )
+             )
+
+         fig.update_layout(
+             barmode="group",
+             xaxis_title="Metric",
+             yaxis_title="Value",
+             yaxis=dict(range=[0, 1.1]),
+             width=700 if classes_cnt < 5 else None,
+         )
+
+         return fig
+
+     def get_recall_per_class_figure(self):
+         import plotly.graph_objects as go  # pylint: disable=import-error
+
+         fig = go.Figure()
+         classes_cnt = len(self.eval_results[0].mp.cat_names)
+         for i, eval_result in enumerate(self.eval_results, 1):
+             model_name = f"[{i}] {eval_result.name}"
+             sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
+
+             fig.add_trace(
+                 go.Bar(
+                     y=sorted_by_f1["recall"],
+                     x=sorted_by_f1["category"],
+                     name=f"{model_name} Recall",
+                     width=0.2 if classes_cnt < 5 else None,
+                     marker=dict(color=eval_result.color),
+                 )
+             )
+
+         fig.update_layout(
+             barmode="group",
+             bargap=0.15,
+             bargroupgap=0.05,
+             width=700 if classes_cnt < 5 else None,
+         )
+         fig.update_xaxes(title_text="Class")
+         fig.update_yaxes(title_text="Recall", range=[0, 1])
+         return fig
+
+     def get_per_class_click_data(self):
+         res = {}
+         res["layoutTemplate"] = [None, None, None]
+         res["clickData"] = {}
+         for i, eval_result in enumerate(self.eval_results):
+             model_name = f"Model [{i + 1}] {eval_result.name}"
+             for key, v in eval_result.click_data.objects_by_class.items():
+                 # Keys follow the "<curve index>_<class name>" convention that the
+                 # chart's JS getKey callback reconstructs on click.
+                 click_data = res["clickData"].setdefault(f"{i}_{key}", {})
+                 img_ids, obj_ids = set(), set()
+                 title = f"{model_name}. Class {key}: {len(v)} object{'s' if len(v) > 1 else ''}"
+                 click_data["title"] = title
+
+                 for x in v:
+                     img_ids.add(x["dt_img_id"])
+                     obj_ids.add(x["dt_obj_id"])
+
+                 click_data["imagesIds"] = list(img_ids)
+                 click_data["filters"] = [
+                     {
+                         "type": "tag",
+                         "tagId": "confidence",
+                         "value": [eval_result.f1_optimal_conf, 1],
+                     },
+                     {"type": "tag", "tagId": "outcome", "value": "TP"},
+                     {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
+                 ]
+         return res
+
+     def get_precision_per_class_figure(self):
+         import plotly.graph_objects as go  # pylint: disable=import-error
+
+         fig = go.Figure()
+         classes_cnt = len(self.eval_results[0].mp.cat_names)
+         for i, eval_result in enumerate(self.eval_results, 1):
+             model_name = f"[{i}] {eval_result.name}"
+             sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
+
+             fig.add_trace(
+                 go.Bar(
+                     y=sorted_by_f1["precision"],
+                     x=sorted_by_f1["category"],
+                     name=f"{model_name} Precision",
+                     width=0.2 if classes_cnt < 5 else None,
+                     marker=dict(color=eval_result.color),
+                 )
+             )
+
+         fig.update_layout(
+             barmode="group",
+             bargap=0.15,
+             bargroupgap=0.05,
+             width=700 if classes_cnt < 5 else None,
+         )
+         fig.update_xaxes(title_text="Class")
+         fig.update_yaxes(title_text="Precision", range=[0, 1])
+         return fig
+
+     def get_f1_per_class_figure(self):
+         import plotly.graph_objects as go  # pylint: disable=import-error
+
+         fig = go.Figure()
+         classes_cnt = len(self.eval_results[0].mp.cat_names)
+         for i, eval_result in enumerate(self.eval_results, 1):
+             model_name = f"[{i}] {eval_result.name}"
+             sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
+
+             fig.add_trace(
+                 go.Bar(
+                     y=sorted_by_f1["f1"],
+                     x=sorted_by_f1["category"],
+                     name=f"{model_name} F1-score",
+                     width=0.2 if classes_cnt < 5 else None,
+                     marker=dict(color=eval_result.color),
+                 )
+             )
+
+         fig.update_layout(
+             barmode="group",
+             bargap=0.15,
+             bargroupgap=0.05,
+             width=700 if classes_cnt < 5 else None,
+         )
+         fig.update_xaxes(title_text="Class")
+         fig.update_yaxes(title_text="F1-score", range=[0, 1])
+         return fig
+
+     def get_click_data_main(self):
+         res = {}
+         res["layoutTemplate"] = [None, None, None]
+         res["clickData"] = {}
+
+         for i, eval_result in enumerate(self.eval_results):
+             model_name = f"Model [{i + 1}] {eval_result.name}"
+             click_data = res["clickData"].setdefault(i, {})
+             img_ids, obj_ids = set(), set()
+             objects_cnt = 0
+             for outcome, matched_obj in eval_result.click_data.outcome_counts.items():
+                 if outcome == "TP":  # TODO: check if this is correct
+                     objects_cnt += len(matched_obj)
+                     for x in matched_obj:
+                         img_ids.add(x["dt_img_id"])
+                         obj_ids.add(x["dt_obj_id"])
+
+             click_data["title"] = f"{model_name}, {objects_cnt} objects"
+             click_data["imagesIds"] = list(img_ids)
+             click_data["filters"] = [
+                 {"type": "tag", "tagId": "confidence", "value": [eval_result.f1_optimal_conf, 1]},
+                 {"type": "tag", "tagId": "outcome", "value": "TP"},
+                 {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
+             ]
+
+         return res
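
A note on how the click-through wiring in this file fits together: get_per_class_click_data stores entries under string keys of the form "<curve index>_<class name>", and the chart_click_extra snippets passed to set_click_data rebuild exactly that key in JavaScript from the clicked Plotly point (curveNumber and label). A minimal sketch of that contract, using hypothetical model and class names:

    # Python side: one clickData entry per (model index, class name),
    # keyed the same way the JS callback will reconstruct it.
    click_data = {}
    for i, model_name in enumerate(["model-a", "model-b"]):  # hypothetical models
        for cls in ["car", "person"]:                        # hypothetical classes
            click_data[f"{i}_{cls}"] = {
                "title": f"Model [{i + 1}] {model_name}. Class {cls}"
            }

    # JS side (from chart_click_extra above):
    #   'getKey': (payload) => `${payload.points[0].curveNumber}${'_'}${payload.points[0].label}`
    # A click on curve 1 (the second model's bar) at x-label "car" resolves to:
    print(click_data["1_car"]["title"])  # Model [2] model-b. Class car

The main PRF1 chart uses only the curve number as its key, which is why get_click_data_main keys its entries by the bare model index instead.
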