supervisely 6.73.214__py3-none-any.whl → 6.73.216__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/app/fastapi/templating.py +1 -1
- supervisely/app/widgets/report_thumbnail/report_thumbnail.py +17 -5
- supervisely/app/widgets/team_files_selector/team_files_selector.py +3 -0
- supervisely/nn/artifacts/__init__.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +32 -0
- supervisely/nn/benchmark/comparison/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/detection_visualization/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/detection_visualization/text_templates.py +437 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +27 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +125 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +224 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/explore_predicttions.py +112 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +161 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +336 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +249 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +142 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +300 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +308 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +19 -0
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +298 -0
- supervisely/nn/benchmark/comparison/model_comparison.py +84 -0
- supervisely/nn/benchmark/evaluation/coco/metric_provider.py +9 -7
- supervisely/nn/benchmark/visualization/evaluation_result.py +266 -0
- supervisely/nn/benchmark/visualization/renderer.py +100 -0
- supervisely/nn/benchmark/visualization/report_template.html +46 -0
- supervisely/nn/benchmark/visualization/visualizer.py +1 -1
- supervisely/nn/benchmark/visualization/widgets/__init__.py +17 -0
- supervisely/nn/benchmark/visualization/widgets/chart/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/chart/chart.py +72 -0
- supervisely/nn/benchmark/visualization/widgets/chart/template.html +16 -0
- supervisely/nn/benchmark/visualization/widgets/collapse/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/collapse/collapse.py +33 -0
- supervisely/nn/benchmark/visualization/widgets/container/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/container/container.py +54 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +125 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/template.html +49 -0
- supervisely/nn/benchmark/visualization/widgets/markdown/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/markdown/markdown.py +53 -0
- supervisely/nn/benchmark/visualization/widgets/notification/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/notification/notification.py +38 -0
- supervisely/nn/benchmark/visualization/widgets/sidebar/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/sidebar/sidebar.py +67 -0
- supervisely/nn/benchmark/visualization/widgets/table/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/table/table.py +116 -0
- supervisely/nn/benchmark/visualization/widgets/widget.py +22 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.216.dist-info}/METADATA +1 -1
- {supervisely-6.73.214.dist-info → supervisely-6.73.216.dist-info}/RECORD +52 -12
- {supervisely-6.73.214.dist-info → supervisely-6.73.216.dist-info}/LICENSE +0 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.216.dist-info}/WHEEL +0 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.216.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.216.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
|
|
7
|
+
BaseVisMetric,
|
|
8
|
+
)
|
|
9
|
+
from supervisely.nn.benchmark.visualization.widgets import ChartWidget
|
|
10
|
+
from supervisely.nn.task_type import TaskType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OutcomeCounts(BaseVisMetric):
    """Outcome (TP / FN / FP) count charts and click data for several evaluation results."""

    # Name given to the ChartWidget built by `chart_widget_main`.
    CHART_MAIN = "chart_outcome_counts"
    # Name given to the ChartWidget built by `chart_widget_comparison`.
    CHART_COMPARISON = "chart_outcome_counts_comparison"
|
|
16
|
+
|
|
17
|
+
def __init__(self, *args, **kwargs):
    """Set up the id lookup tables and pre-compute common/different outcomes.

    All positional and keyword arguments are forwarded unchanged to the
    parent initializer.
    """
    super().__init__(*args, **kwargs)

    def _three_level_mapping(leaf_factory):
        # model index -> outcome name -> annotation id -> leaf value
        return defaultdict(lambda: defaultdict(lambda: defaultdict(leaf_factory)))

    self.imgIds_to_anns = _three_level_mapping(list)
    self.coco_to_sly_ids = _three_level_mapping(tuple)

    # The lookup tables must be filled first: the FP comparison reads them.
    self._initialize_ids_mapping()

    self.common_and_diff_tp = self._find_common_and_diff_tp()
    self.common_and_diff_fn = self._find_common_and_diff_fn()
    self.common_and_diff_fp = self._find_common_and_diff_fp()
|
|
28
|
+
|
|
29
|
+
def _initialize_ids_mapping(self):
|
|
30
|
+
for idx, r in enumerate(self.eval_results):
|
|
31
|
+
l = {
|
|
32
|
+
"TP": (r.mp.cocoDt.anns, r.mp.m.tp_matches, r.click_data.outcome_counts),
|
|
33
|
+
"FN": (r.mp.cocoGt.anns, r.mp.m.fn_matches, r.click_data.outcome_counts),
|
|
34
|
+
"FP": (r.mp.cocoDt.anns, r.mp.m.fp_matches, r.click_data.outcome_counts),
|
|
35
|
+
}
|
|
36
|
+
for outcome, (coco_anns, matches_data, sly_data) in l.items():
|
|
37
|
+
for m, sly_m in zip(matches_data, sly_data[outcome]):
|
|
38
|
+
key = m["dt_id"] if outcome != "FN" else m["gt_id"]
|
|
39
|
+
gt_key = m["gt_id"] if outcome != "FP" else m["dt_id"]
|
|
40
|
+
ann = coco_anns[key]
|
|
41
|
+
self.imgIds_to_anns[idx][outcome][key].append(ann)
|
|
42
|
+
self.coco_to_sly_ids[idx][outcome][gt_key] = sly_m
|
|
43
|
+
|
|
44
|
+
@property
def chart_widget_main(self) -> ChartWidget:
    """Chart widget with the per-model outcome counts and its click data attached."""
    # JS snippet that builds the click-data key as "<bar label>_<trace name>".
    key_builder = "'getKey': (payload) => `${payload.points[0].y}${'_'}${payload.points[0].data.name}`,"

    widget = ChartWidget(name=self.CHART_MAIN, figure=self.get_main_figure())
    widget.set_click_data(
        gallery_id=self.explore_modal_table.id,
        click_data=self.get_main_click_data(),
        chart_click_extra=key_builder,
    )
    return widget
|
|
53
|
+
|
|
54
|
+
@property
def chart_widget_comparison(self) -> ChartWidget:
    """Chart widget with common vs. model-specific outcomes and its click data attached."""
    # JS snippet that builds the click-data key as "<bar label>_<trace name>".
    key_builder = "'getKey': (payload) => `${payload.points[0].y}${'_'}${payload.points[0].data.name}`,"

    widget = ChartWidget(name=self.CHART_COMPARISON, figure=self.get_comparison_figure())
    widget.set_click_data(
        gallery_id=self.explore_modal_table.id,
        click_data=self.get_comparison_click_data(),
        chart_click_extra=key_builder,
    )
    return widget
|
|
66
|
+
|
|
67
|
+
def _update_figure_layout(self, fig):
|
|
68
|
+
fig.update_layout(
|
|
69
|
+
barmode="stack",
|
|
70
|
+
width=600,
|
|
71
|
+
height=300,
|
|
72
|
+
)
|
|
73
|
+
fig.update_xaxes(title_text="Count (objects)")
|
|
74
|
+
fig.update_yaxes(tickangle=-90)
|
|
75
|
+
|
|
76
|
+
fig.update_layout(
|
|
77
|
+
dragmode=False,
|
|
78
|
+
modebar=dict(
|
|
79
|
+
remove=[
|
|
80
|
+
"zoom2d",
|
|
81
|
+
"pan2d",
|
|
82
|
+
"select2d",
|
|
83
|
+
"lasso2d",
|
|
84
|
+
"zoomIn2d",
|
|
85
|
+
"zoomOut2d",
|
|
86
|
+
"autoScale2d",
|
|
87
|
+
"resetScale2d",
|
|
88
|
+
]
|
|
89
|
+
),
|
|
90
|
+
)
|
|
91
|
+
return fig
|
|
92
|
+
|
|
93
|
+
def get_main_figure(self):  # -> Optional[go.Figure]
    """Build a horizontal stacked bar chart of TP / FN / FP counts per model.

    Models are labelled "Model 1" .. "Model N" and both labels and counts
    are passed to the bars in reversed order.
    """
    import plotly.graph_objects as go  # pylint: disable=import-error

    fig = go.Figure()

    results = self.eval_results
    counts_by_outcome = {
        "TP": [result.mp.TP_count for result in results],
        "FN": [result.mp.FN_count for result in results],
        "FP": [result.mp.FP_count for result in results],
    }
    bar_colors = {"TP": "#8ACAA1", "FN": "#dd3f3f", "FP": "#F7ADAA"}
    model_labels = [f"Model {number}" for number in range(len(results), 0, -1)]

    for outcome, counts in counts_by_outcome.items():
        bar = go.Bar(
            x=counts[::-1],
            y=model_labels,
            name=outcome,
            orientation="h",
            marker=dict(color=bar_colors[outcome]),
            hovertemplate=f"{outcome}: %{{x}} objects<extra></extra>",
        )
        fig.add_trace(bar)

    return self._update_figure_layout(fig)
|
|
119
|
+
|
|
120
|
+
def get_comparison_figure(self):  # -> Optional[go.Figure]
    """Build a horizontal stacked bar chart of model-specific and common outcomes.

    One bar per model (reversed order) holds the outcomes found only for
    that model; a final "Common" bar holds the outcomes shared by all models.
    """
    import plotly.graph_objects as go  # pylint: disable=import-error

    fig = go.Figure()

    bar_labels = [f"Model {number}" for number in range(len(self.eval_results), 0, -1)]
    bar_labels.append("Common")

    def _bar_lengths(diff_and_common):
        # Per-model counts in reversed order, followed by the common count.
        per_model, common = diff_and_common
        return [len(ids) for ids in per_model[::-1]] + [len(common)]

    series = [
        ("TP", _bar_lengths(self.common_and_diff_tp), "#8ACAA1"),
        ("FN", _bar_lengths(self.common_and_diff_fn), "#dd3f3f"),
        ("FP", _bar_lengths(self.common_and_diff_fp), "#F7ADAA"),
    ]

    for outcome, lengths, color in series:
        bar = go.Bar(
            x=lengths,
            y=bar_labels,
            name=outcome,
            orientation="h",
            marker=dict(color=color),
            hovertemplate=f"{outcome}: %{{x}} objects<extra></extra>",
        )
        fig.add_trace(bar)

    return self._update_figure_layout(fig)
|
|
150
|
+
|
|
151
|
+
def _find_common_and_diff_fn(self) -> List[int]:
|
|
152
|
+
ids = [
|
|
153
|
+
dict([(x["gt_obj_id"], x["gt_img_id"]) for x in r.click_data.outcome_counts["FN"]])
|
|
154
|
+
for r in self.eval_results
|
|
155
|
+
]
|
|
156
|
+
same = set.intersection(*[set(x.keys()) for x in ids])
|
|
157
|
+
diffs = [set(x.keys()) - same for x in ids]
|
|
158
|
+
|
|
159
|
+
same = {i: ids[0][i] for i in same}
|
|
160
|
+
diffs = [{i: s[i] for i in d} for s, d in zip(ids, diffs)]
|
|
161
|
+
|
|
162
|
+
return diffs, same
|
|
163
|
+
|
|
164
|
+
def _get_coco_key_name(self):
    """Return the COCO annotation field that holds the geometry for the task type.

    The task type is read from the first evaluation result.

    Raises:
        NotImplementedError: if the task type has no known geometry field.
    """
    geometry_field_by_task = {
        TaskType.OBJECT_DETECTION: "bbox",
        TaskType.INSTANCE_SEGMENTATION: "segmentation",
        TaskType.SEMANTIC_SEGMENTATION: "segmentation",
    }
    geometry_field = geometry_field_by_task.get(self.eval_results[0].cv_task)
    if geometry_field is not None:
        return geometry_field
    raise NotImplementedError("Not implemented for this task type")
|
|
174
|
+
|
|
175
|
+
def _find_common_and_diff_fp(self) -> tuple:
    """Split false positives into model-specific ones and ones shared by all models.

    False positives of the first model are compared with the false positives
    of every other model by IoU of the geometry field returned by
    ``_get_coco_key_name``; pairs above ``iouThr`` are match candidates.

    Returns:
        A ``(diff_fp_matches_dicts, same_fp_matches_dict)`` tuple. The first
        item is a list with one dict per evaluation result mapping
        ``dt_obj_id`` to ``dt_img_id``; the second is a dict of the same
        shape built from the first model's matched annotations.
    """
    from pycocotools import mask as maskUtils  # pylint: disable=import-error

    iouThr = 0.75
    key_name = self._get_coco_key_name()

    # One {key: [annotations]} mapping per model, restricted to the FP outcome.
    imgIds_to_anns = [self.imgIds_to_anns[idx]["FP"] for idx in range(len(self.eval_results))]
    # Per model: dt_obj_id -> dt_img_id taken from the FP click data.
    sly_ids_list = [
        {x["dt_obj_id"]: x["dt_img_id"] for x in r.click_data.outcome_counts["FP"]}
        for r in self.eval_results
    ]

    same_fp_matches = []
    # NOTE(review): `_initialize_ids_mapping` stores FP annotations under the
    # match's "dt_id", so `img_id` here appears to be an annotation id rather
    # than an image id — confirm whether grouping by image was intended.
    for img_id in imgIds_to_anns[0]:
        # Indexing the other models' defaultdicts inserts an empty list for missing keys.
        anns_list = [imgIds[img_id] for imgIds in imgIds_to_anns]
        geoms_list = [[x[key_name] for x in anns] for anns in anns_list]

        # Skip keys for which at least one model has no annotations.
        if any(len(geoms) == 0 for geoms in geoms_list):
            continue

        # IoU of the first model's geometries against each other model's geometries.
        ious_list = [
            maskUtils.iou(geoms_list[0], geoms, [0] * len(geoms)) for geoms in geoms_list[1:]
        ]
        if any(len(ious) == 0 for ious in ious_list):
            continue

        # Index pairs whose IoU exceeds the threshold; skip if any model has none.
        indxs_list = [np.nonzero(ious > iouThr) for ious in ious_list]
        if any(len(indxs[0]) == 0 for indxs in indxs_list):
            continue

        # Turn the index arrays into (i, j) pairs sorted by descending IoU.
        indxs_list = [list(zip(*indxs)) for indxs in indxs_list]
        indxs_list = [
            sorted(indxs, key=lambda x: ious[x[0], x[1]], reverse=True)
            for indxs, ious in zip(indxs_list, ious_list)
        ]

        # NOTE(review): `idxs[0]` is the single best (i, j) pair, so each set
        # holds the two indices of that pair, not all first-model indices —
        # confirm this is the intended candidate set.
        id_sets = [set(idxs[0]) for idxs in indxs_list]
        common_ids = set.intersection(*id_sets)
        if not common_ids:
            continue

        for i, j in indxs_list[0]:
            if i in common_ids:
                same_fp_matches.append((anns_list[0][i], [anns[j] for anns in anns_list[1:]]))
                # Each first-model index is recorded at most once.
                common_ids.remove(i)

    # Find different FP matches for each model
    # NOTE(review): `same_fp_ids` holds annotation ids of the first model only,
    # yet it is subtracted from every model's dt ids — verify for other models.
    same_fp_ids = set(x[0]["id"] for x in same_fp_matches)
    diff_fp_matches = [
        set([x["dt_id"] for x in eval_result.mp.m.fp_matches]) - same_fp_ids
        for eval_result in self.eval_results
    ]

    # Translate the remaining COCO dt ids into {dt_obj_id: dt_img_id} per model.
    diff_fp_matches_dicts = []
    for idx, diff_fp in enumerate(diff_fp_matches):
        diff_fp_dict = {}
        for x in diff_fp:
            obj_id = self.coco_to_sly_ids[idx]["FP"][x]["dt_obj_id"]
            img_id = sly_ids_list[idx][obj_id]
            diff_fp_dict[obj_id] = img_id
        diff_fp_matches_dicts.append(diff_fp_dict)

    # Common matches are reported with the first model's object and image ids.
    same_fp_matches_dict = {}
    for x in same_fp_matches:
        obj_id = self.coco_to_sly_ids[0]["FP"][x[0]["id"]]["dt_obj_id"]
        img_id = sly_ids_list[0][obj_id]
        same_fp_matches_dict[obj_id] = img_id

    return diff_fp_matches_dicts, same_fp_matches_dict
|
|
244
|
+
|
|
245
|
+
def _find_common_and_diff_tp(self) -> tuple:
|
|
246
|
+
|
|
247
|
+
ids = [
|
|
248
|
+
dict([(x["gt_obj_id"], x) for x in r.click_data.outcome_counts["TP"]])
|
|
249
|
+
for r in self.eval_results
|
|
250
|
+
]
|
|
251
|
+
|
|
252
|
+
same = set.intersection(*[set(x.keys()) for x in ids])
|
|
253
|
+
diffs = [set(x.keys()) - same for x in ids]
|
|
254
|
+
|
|
255
|
+
same = {s["dt_obj_id"]: s["dt_img_id"] for s in [ids[0][i] for i in same]}
|
|
256
|
+
diffs = [{s[i]["dt_obj_id"]: s[i]["dt_img_id"] for i in d} for s, d in zip(ids, diffs)]
|
|
257
|
+
|
|
258
|
+
return diffs, same
|
|
259
|
+
|
|
260
|
+
def get_main_click_data(self):
    """Build the click data for the main outcome counts chart.

    Returns a dict with a ``layoutTemplate`` entry and a ``clickData`` dict
    keyed by ``"Model <n>_<outcome>"``. Every entry holds the image ids, a
    title and the gallery filters: specific ground truth objects for "FN",
    outcome and confidence tags (from ``f1_optimal_conf`` up to 1) otherwise.
    """
    click_data = {}
    for model_number, eval_result in enumerate(self.eval_results, 1):
        model_name = f"Model {model_number}"
        for outcome, matches_data in eval_result.click_data.outcome_counts.items():
            entry = click_data.setdefault(f"{model_name}_{outcome}", {})
            entry["imagesIds"] = []

            # False negatives are described by ground truth ids, the rest by detection ids.
            if outcome != "FN":
                img_field, obj_field = "dt_img_id", "dt_obj_id"
            else:
                img_field, obj_field = "gt_img_id", "gt_obj_id"

            img_ids = set()
            obj_ids = set()
            for match in matches_data:
                img_ids.add(match[img_field])
                obj_ids.add(match[obj_field])

            entry["title"] = (
                f"{model_name}. {outcome}: {len(obj_ids)} object{'s' if len(obj_ids) > 1 else ''}"
            )
            entry["imagesIds"] = list(img_ids)
            thr = eval_result.f1_optimal_conf
            if outcome != "FN":
                entry["filters"] = [
                    {"type": "tag", "tagId": "outcome", "value": outcome},
                    {"type": "tag", "tagId": "confidence", "value": [thr, 1]},
                ]
            else:
                entry["filters"] = [
                    {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
                ]

    return {"layoutTemplate": [None, None, None], "clickData": click_data}
|
|
293
|
+
|
|
294
|
+
def get_comparison_click_data(self):
    """Build the click data for the comparison outcome counts chart.

    Returns a dict with a ``layoutTemplate`` entry and a ``clickData`` dict
    keyed by ``"Common_<outcome>"`` and ``"Model <n>_<outcome>"``. Every
    entry holds a title, the image ids and a "specific_objects" filter;
    outcomes other than "FN" also get confidence and outcome tag filters.
    """
    click_data = {}

    def _fill_entry(label, outcome, entry, ids):
        # `ids` maps object id -> image id.
        img_ids = set()
        obj_ids = set()
        for obj_id, img_id in ids.items():
            img_ids.add(img_id)
            obj_ids.add(obj_id)

        entry["title"] = (
            f"{label}. {outcome}: {len(obj_ids)} object{'s' if len(obj_ids) > 1 else ''}"
        )
        entry["imagesIds"] = list(img_ids)
        filters = entry.setdefault("filters", [])
        filters.append({"type": "specific_objects", "tagId": None, "value": list(obj_ids)})
        if outcome == "FN":
            return
        filters.append({"type": "tag", "tagId": "confidence", "value": [0, 1]})
        filters.append({"type": "tag", "tagId": "outcome", "value": outcome})

    ids_by_outcome = {
        "TP": self.common_and_diff_tp,
        "FN": self.common_and_diff_fn,
        "FP": self.common_and_diff_fp,
    }

    for outcome, (per_model_ids, common_ids) in ids_by_outcome.items():
        common_entry = click_data.setdefault(f"Common_{outcome}", {})
        _fill_entry("Common", outcome, common_entry, common_ids)

        for model_number, model_ids in enumerate(per_model_ids, 1):
            name = f"Model {model_number}"
            model_entry = click_data.setdefault(f"{name}_{outcome}", {})
            _fill_entry(name, outcome, model_entry, model_ids)

    return {"layoutTemplate": [None, None, None], "clickData": click_data}
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from supervisely._utils import abs_url
|
|
4
|
+
from supervisely.nn.benchmark.visualization.evaluation_result import EvalResult
|
|
5
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
|
|
6
|
+
BaseVisMetric,
|
|
7
|
+
)
|
|
8
|
+
from supervisely.nn.benchmark.visualization.widgets import (
|
|
9
|
+
ChartWidget,
|
|
10
|
+
MarkdownWidget,
|
|
11
|
+
TableWidget,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Overview(BaseVisMetric):
    """Overview block of the comparison report: markdown summaries, metrics table and chart."""

    # Not referenced by the methods visible in this class.
    MARKDOWN_OVERVIEW = "markdown_overview"
    # Widget name and `vis_texts` attribute used for the per-model info blocks.
    MARKDOWN_OVERVIEW_INFO = "markdown_overview_info"
    # Widget name and `vis_texts` attribute used for the shared overview text.
    MARKDOWN_COMMON_OVERVIEW = "markdown_common_overview"
    # Widget name of the key metrics chart.
    CHART = "chart_key_metrics"
|
|
21
|
+
|
|
22
|
+
def __init__(self, vis_texts, eval_results: List[EvalResult]) -> None:
    """
    Create the widgets provider for the overview block.

    The ``overview_widgets`` property returns a list of MarkdownWidget, one
    info block per evaluation result. The ``chart_widget`` property returns a
    ChartWidget with a Scatterpolar chart that shows the base metrics of each
    evaluation result.

    Both arguments are passed unchanged to the parent initializer.
    """
    super().__init__(vis_texts, eval_results)
|
|
30
|
+
|
|
31
|
+
@property
def overview_md(self) -> MarkdownWidget:
    """Markdown widget with the overview text shared by all compared models.

    The template stored in ``vis_texts`` under ``MARKDOWN_COMMON_OVERVIEW``
    is formatted with four positional values: the display name of the
    compared models, then the ground truth project id, project name and task
    type of the first evaluation result.

    The display name is "Custom models" when no model has a name, the common
    name when all names are equal, and the names joined with " vs. "
    otherwise.
    """
    # Underscores are escaped so markdown does not treat them as emphasis markers.
    model_names = [
        (eval_result.inference_info.get("model_name") or "Custom").replace("_", "\\_")
        for eval_result in self.eval_results
    ]

    if all(name == "Custom" for name in model_names):
        title = "Custom models"
    elif all(name == model_names[0] for name in model_names):
        title = model_names[0]
    else:
        title = " vs. ".join(model_names)

    # Only the first evaluation result contributes project and task information.
    first_result = self.eval_results[0]
    info = [
        title,
        first_result.gt_project_info.id,
        first_result.gt_project_info.name,
        first_result.inference_info.get("task_type"),
    ]

    text_template: str = getattr(self.vis_texts, self.MARKDOWN_COMMON_OVERVIEW)
    return MarkdownWidget(
        name=self.MARKDOWN_COMMON_OVERVIEW,
        title="Overview",
        text=text_template.format(*info),
    )
|
|
62
|
+
|
|
63
|
+
@property
def overview_widgets(self) -> List[MarkdownWidget]:
    """Markdown info blocks, one per evaluation result.

    For every evaluation result the template stored in ``vis_texts`` under
    ``MARKDOWN_OVERVIEW_INFO`` is formatted with, in order: checkpoint name,
    escaped model name, escaped checkpoint name, architecture, runtime,
    checkpoint URL, escaped link text and the link to the model benchmark
    report. The collected values are also stored in ``self.formats``.
    """
    self.formats = []
    for eval_result in self.eval_results:
        inference_info = eval_result.inference_info

        url = inference_info.get("checkpoint_url")
        link_text = inference_info.get("custom_checkpoint_path")
        if link_text is None:
            link_text = url
        # NOTE(review): if both "custom_checkpoint_path" and "checkpoint_url"
        # are missing, link_text is None here and .replace raises AttributeError.
        link_text = link_text.replace("_", "\\_")

        checkpoint_name = inference_info.get("deploy_params", {}).get("checkpoint_name", "")
        model_name = inference_info.get("model_name") or "Custom"

        # The report file id is needed to build the model benchmark link.
        report = eval_result.api.file.get_info_by_path(
            eval_result.team_id, eval_result.report_path
        )
        report_link = abs_url(f"/model-benchmark?id={report.id}")

        # Underscores are escaped so markdown does not treat them as emphasis markers.
        formats = [
            checkpoint_name,
            model_name.replace("_", "\\_"),
            checkpoint_name.replace("_", "\\_"),
            inference_info.get("architecture"),
            inference_info.get("runtime"),
            url,
            link_text,
            report_link,
        ]
        self.formats.append(formats)

    text_template: str = getattr(self.vis_texts, self.MARKDOWN_OVERVIEW_INFO)
    widgets = []
    for formats in self.formats:
        md = MarkdownWidget(
            name=self.MARKDOWN_OVERVIEW_INFO,
            title="Overview",
            text=text_template.format(*formats),
        )
        md.is_info_block = True
        widgets.append(md)
    return widgets
|
|
107
|
+
|
|
108
|
+
def get_table_widget(self, latency, fps) -> TableWidget:
    """Build the key metrics table with one column per evaluation result.

    Rows come from ``metric_table()`` of every result, followed by a
    "Latency (ms)" row and an "FPS" row built from ``latency`` and ``fps``.
    The fourth metric row is skipped unless all results share the same,
    non-None IoU threshold.
    """
    metric_renames_map = {"f1": "F1-score"}

    columns = ["metrics"]
    columns += [f"[{number}] {r.name}" for number, r in enumerate(self.eval_results, 1)]

    all_metrics = [eval_result.mp.metric_table() for eval_result in self.eval_results]
    rows = []

    iou_thresholds = set([r.mp.iou_threshold for r in self.eval_results])
    same_iou_thr = (
        len(iou_thresholds) == 1 and self.eval_results[0].mp.iou_threshold is not None
    )

    def _cell(value):
        # Missing values are shown as a bar, floats are rounded to two digits.
        if value is None:
            return "―"
        if isinstance(value, float):
            return round(value, 2)
        return value

    for position, metric in enumerate(all_metrics[0].keys()):
        if position == 3 and not same_iou_thr:
            continue
        row = [metric_renames_map.get(metric, metric)]
        row = row + [_cell(metrics[metric]) for metrics in all_metrics]
        rows.append({"row": row, "id": metric, "items": row})

    latency_row = ["Latency (ms)"] + latency
    rows.append({"row": latency_row, "id": latency_row[0], "items": latency_row})

    fps_row = ["FPS"] + fps
    rows.append({"row": fps_row, "id": fps_row[0], "items": fps_row})

    res = {
        "content": rows,
        "columns": columns,
        "columnsOptions": [{"disableSort": True} for _ in columns],
    }

    return TableWidget(
        name="table_key_metrics",
        data=res,
        show_header_controls=False,
        fix_columns=1,
        page_size=len(rows),
    )
|
|
152
|
+
|
|
153
|
+
@property
def chart_widget(self) -> ChartWidget:
    """Chart widget wrapping the key metrics figure."""
    figure = self.get_figure()
    return ChartWidget(name=self.CHART, figure=figure)
|
|
156
|
+
|
|
157
|
+
def get_overview_info(self, eval_result: EvalResult):
    """Describe the classes, images and training session of one evaluation result.

    Returns:
        A ``(classes_str, images_str, train_session)`` tuple of strings.
        ``classes_str`` is the class count with a singular/plural noun,
        ``images_str`` describes the train/validation image counts and what
        the evaluation ran on, and ``train_session`` is a markdown line with
        a link to the training dashboard (empty when there is no training
        session id).
    """
    classes_cnt = len(eval_result.classes_whitelist)
    classes_str = "classes" if classes_cnt > 1 else "class"
    classes_str = f"{classes_cnt} {classes_str}"

    train_session, images_str = "", ""
    gt_project_id = eval_result.gt_project_info.id
    gt_dataset_ids = eval_result.gt_dataset_ids
    gt_images_ids = eval_result.gt_images_ids
    train_info = eval_result.train_info
    # Validation image count: explicit image subset first, then datasets, then whole project.
    if gt_images_ids is not None:
        val_imgs_cnt = len(gt_images_ids)
    elif gt_dataset_ids is not None:
        datasets = eval_result.gt_dataset_infos
        val_imgs_cnt = sum(ds.items_count for ds in datasets)
    else:
        val_imgs_cnt = eval_result.gt_project_info.items_count

    if train_info:
        train_task_id = train_info.get("app_session_id")
        if train_task_id:
            # The app id of the training task is needed for the dashboard link.
            task_info = eval_result.api.task.get_info_by_id(int(train_task_id))
            app_id = task_info["meta"]["app"]["id"]
            train_session = f'- **Training dashboard**: <a href="/apps/{app_id}/sessions/{train_task_id}" target="_blank">open</a>'

        train_imgs_cnt = train_info.get("images_count")
        images_str = f", {train_imgs_cnt} images in train, {val_imgs_cnt} images in validation"

    # Same precedence as above; `datasets` is only bound in the dataset branch.
    if gt_images_ids is not None:
        images_str += f". Evaluated using subset - {val_imgs_cnt} images"
    elif gt_dataset_ids is not None:
        links = [
            f'<a href="/projects/{gt_project_id}/datasets/{ds.id}" target="_blank">{ds.name}</a>'
            for ds in datasets
        ]
        images_str += (
            f". Evaluated on the dataset{'s' if len(links) > 1 else ''}: {', '.join(links)}"
        )
    else:
        images_str += f". Evaluated on the whole project ({val_imgs_cnt} images)"

    return classes_str, images_str, train_session
|
|
199
|
+
|
|
200
|
+
def get_figure(self):  # -> Optional[go.Figure]
    """Build a polar chart with one closed trace of base metrics per evaluation result.

    The radial axis is fixed to the 0..1 range, dragging is disabled and the
    zoom/pan/select mode bar buttons are removed.
    """
    import plotly.graph_objects as go  # pylint: disable=import-error

    # Overall Metrics
    fig = go.Figure()
    for number, eval_result in enumerate(self.eval_results, 1):
        trace_name = f"[{number}] {eval_result.name}"
        base_metrics = eval_result.mp.base_metrics()
        radii = list(base_metrics.values())
        labels = [eval_result.mp.metric_names[key] for key in base_metrics.keys()]
        # Repeat the first point at the end so the outline is closed.
        trace = go.Scatterpolar(
            r=radii + [radii[0]],
            theta=labels + [labels[0]],
            name=trace_name,
            marker={"color": eval_result.color},
            hovertemplate=trace_name + "<br>%{theta}: %{r:.2f}<extra></extra>",
        )
        fig.add_trace(trace)

    polar_axes = {
        "radialaxis": {"range": [0.0, 1.0], "ticks": "outside"},
        "angularaxis": {"rotation": 90, "direction": "clockwise"},
    }
    fig.update_layout(
        polar=polar_axes,
        dragmode=False,
        margin={"l": 25, "r": 25, "t": 25, "b": 25},
    )

    removed_buttons = [
        "zoom2d",
        "pan2d",
        "select2d",
        "lasso2d",
        "zoomIn2d",
        "zoomOut2d",
        "autoScale2d",
        "resetScale2d",
    ]
    fig.update_layout(modebar={"remove": removed_buttons})
    return fig
|