supervisely 6.73.214__py3-none-any.whl → 6.73.215__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/app/widgets/report_thumbnail/report_thumbnail.py +17 -5
- supervisely/app/widgets/team_files_selector/team_files_selector.py +3 -0
- supervisely/nn/benchmark/comparison/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/detection_visualization/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/detection_visualization/text_templates.py +437 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +27 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +125 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +224 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/explore_predicttions.py +112 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +161 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +336 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +249 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +142 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +300 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +308 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +19 -0
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +298 -0
- supervisely/nn/benchmark/comparison/model_comparison.py +84 -0
- supervisely/nn/benchmark/evaluation/coco/metric_provider.py +9 -7
- supervisely/nn/benchmark/visualization/evaluation_result.py +266 -0
- supervisely/nn/benchmark/visualization/renderer.py +100 -0
- supervisely/nn/benchmark/visualization/report_template.html +46 -0
- supervisely/nn/benchmark/visualization/visualizer.py +1 -1
- supervisely/nn/benchmark/visualization/widgets/__init__.py +17 -0
- supervisely/nn/benchmark/visualization/widgets/chart/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/chart/chart.py +72 -0
- supervisely/nn/benchmark/visualization/widgets/chart/template.html +16 -0
- supervisely/nn/benchmark/visualization/widgets/collapse/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/collapse/collapse.py +33 -0
- supervisely/nn/benchmark/visualization/widgets/container/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/container/container.py +54 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +125 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/template.html +49 -0
- supervisely/nn/benchmark/visualization/widgets/markdown/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/markdown/markdown.py +53 -0
- supervisely/nn/benchmark/visualization/widgets/notification/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/notification/notification.py +38 -0
- supervisely/nn/benchmark/visualization/widgets/sidebar/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/sidebar/sidebar.py +67 -0
- supervisely/nn/benchmark/visualization/widgets/table/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/table/table.py +116 -0
- supervisely/nn/benchmark/visualization/widgets/widget.py +22 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/METADATA +1 -1
- {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/RECORD +49 -10
- {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/LICENSE +0 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/WHEEL +0 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.214.dist-info → supervisely-6.73.215.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from supervisely.imaging.color import hex2rgb
|
|
4
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
|
|
5
|
+
BaseVisMetric,
|
|
6
|
+
)
|
|
7
|
+
from supervisely.nn.benchmark.visualization.widgets import (
|
|
8
|
+
ChartWidget,
|
|
9
|
+
CollapseWidget,
|
|
10
|
+
MarkdownWidget,
|
|
11
|
+
NotificationWidget,
|
|
12
|
+
TableWidget,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PrCurve(BaseVisMetric):
    """Comparison widgets for the "mAP & Precision-Recall Curve" section.

    Every widget covers all models in ``self.eval_results``; model ``i`` (1-based)
    is labelled ``"[i] <name>"`` and drawn in ``eval_result.color``.
    """

    MARKDOWN_PR_CURVE = "markdown_pr_curve"
    MARKDOWN_PR_TRADE_OFFS = "markdown_trade_offs"
    MARKDOWN_WHAT_IS_PR_CURVE = "markdown_what_is_pr_curve"

    @property
    def markdown_widget(self) -> MarkdownWidget:
        """Section introduction; the text template is filled with the F1-score definition."""
        text: str = getattr(self.vis_texts, self.MARKDOWN_PR_CURVE).format(
            self.vis_texts.definitions.f1_score
        )
        return MarkdownWidget(
            name=self.MARKDOWN_PR_CURVE, title="mAP & Precision-Recall Curve", text=text
        )

    @property
    def chart_widget(self) -> ChartWidget:
        """Chart widget wrapping the overlaid PR curves built by :meth:`get_figure`."""
        return ChartWidget(name="chart_pr_curve", figure=self.get_figure())

    @property
    def collapsed_widget(self) -> CollapseWidget:
        """Collapsible explanations: precision/recall trade-offs and how a PR curve is built."""
        text_pr_trade_offs = getattr(self.vis_texts, self.MARKDOWN_PR_TRADE_OFFS)
        text_pr_curve = getattr(self.vis_texts, self.MARKDOWN_WHAT_IS_PR_CURVE).format(
            self.vis_texts.definitions.confidence_score,
            self.vis_texts.definitions.true_positives,
            self.vis_texts.definitions.false_positives,
        )
        markdown_pr_trade_offs = MarkdownWidget(
            name=self.MARKDOWN_PR_TRADE_OFFS,
            title="About Trade-offs between precision and recall",
            text=text_pr_trade_offs,
        )
        markdown_whatis_pr_curve = MarkdownWidget(
            name=self.MARKDOWN_WHAT_IS_PR_CURVE,
            title="How the PR curve is built?",
            text=text_pr_curve,
        )
        return CollapseWidget(widgets=[markdown_pr_trade_offs, markdown_whatis_pr_curve])

    @property
    def table_widget(self) -> TableWidget:
        """Summary table with one row per model.

        Columns: model label, ``mAP`` and ``AP75`` taken from ``mp.json_metrics()``.
        Float values are rounded to 2 decimals; a missing AP75 (``None``) is shown
        as ``"-"``. Sorting is disabled for every column.
        """
        columns = [" ", "mAP (0.5:0.95)", "mAP (0.75)"]
        content = []
        for i, eval_result in enumerate(self.eval_results, 1):
            # Evaluate the metrics once per model instead of once per value.
            metrics = eval_result.mp.json_metrics()
            value_range = round(metrics["mAP"], 2)
            value_75 = metrics["AP75"]
            # Only a missing value becomes "-"; a genuine AP75 of 0.0 is displayed
            # as a number rather than being mistaken for "no value".
            if value_75 is None:
                value_75 = "-"
            elif isinstance(value_75, float):
                value_75 = round(value_75, 2)
            model_name = f"[{i}] {eval_result.name}"
            row = [model_name, value_range, value_75]
            content.append({"row": row, "id": model_name, "items": row})

        res = {
            "content": content,
            "columns": columns,
            # Exactly one options entry per column, so the options stay aligned
            # with `columns` (as in the Precision/Recall/F1 table).
            "columnsOptions": [{"disableSort": True} for _ in columns],
        }

        return TableWidget(
            name="table_pr_curve",
            data=res,
            show_header_controls=False,
            # main_column=columns[0],
            fix_columns=1,
        )

    def get_figure(self):  # -> Optional[go.Figure]:
        """Overlay every model's precision-recall curve plus a dashed "Perfect" line at 1.

        For each model, entries of ``mp.pr_curve()`` equal to -1 are treated as
        missing (NaN) and the remaining values are averaged over the last axis with
        ``np.nanmean``; the result is plotted against that model's ``mp.recThrs``
        and filled down to zero with a translucent version of the model colour.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error

        fig = go.Figure()

        # The "Perfect" reference line uses the first model's recall thresholds.
        rec_thr = self.eval_results[0].mp.recThrs
        for i, eval_result in enumerate(self.eval_results, 1):
            # Copy before masking so the in-place assignment cannot modify an
            # array that pr_curve() might return by reference.
            pr_curve = eval_result.mp.pr_curve().copy()
            pr_curve[pr_curve == -1] = np.nan
            pr_curve = np.nanmean(pr_curve, axis=-1)

            name = f"[{i}] {eval_result.name}"
            # "r,g,b,0.1" -> 10% opaque fill under the curve.
            rgba_fill = ",".join(map(str, hex2rgb(eval_result.color))) + ",0.1"
            line = go.Scatter(
                x=eval_result.mp.recThrs,
                y=pr_curve,
                mode="lines",
                name=name,
                fill="tozeroy",
                fillcolor=f"rgba({rgba_fill})",
                line=dict(color=eval_result.color),
                hovertemplate=name + "<br>Recall: %{x:.2f}<br>Precision: %{y:.2f}<extra></extra>",
                showlegend=True,
            )
            fig.add_trace(line)

        fig.add_trace(
            go.Scatter(
                x=rec_thr,
                y=[1] * len(rec_thr),
                name="Perfect",
                line=dict(color="orange", dash="dash"),
                showlegend=True,
            )
        )
        fig.update_layout(
            dragmode=False,
            modebar=dict(
                remove=[
                    "zoom2d",
                    "pan2d",
                    "select2d",
                    "lasso2d",
                    "zoomIn2d",
                    "zoomOut2d",
                    "autoScale2d",
                    "resetScale2d",
                ]
            ),
            xaxis_title="Recall",
            yaxis_title="Precision",
        )
        return fig
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
|
|
2
|
+
BaseVisMetric,
|
|
3
|
+
)
|
|
4
|
+
from supervisely.nn.benchmark.visualization.widgets import (
|
|
5
|
+
ChartWidget,
|
|
6
|
+
CollapseWidget,
|
|
7
|
+
MarkdownWidget,
|
|
8
|
+
NotificationWidget,
|
|
9
|
+
TableWidget,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PrecisionRecallF1(BaseVisMetric):
    """Comparison widgets for the "Precision, Recall, F1-score" section.

    Produces the section texts, an overall grouped bar chart, per-class bar charts
    for each metric, a summary table, and the click-data payloads that link chart
    bars to the explorer gallery identified by ``self.explore_modal_table.id``.
    Model ``i`` (1-based) is labelled ``"[i] <name>"`` and drawn in ``eval_result.color``.
    """

    MARKDOWN = "markdown_PRF1"
    MARKDOWN_PRECISION_TITLE = "markdown_precision_per_class_title"
    MARKDOWN_RECALL_TITLE = "markdown_recall_per_class_title"
    MARKDOWN_F1_TITLE = "markdown_f1_per_class_title"

    # ------------------------------------------------------------------ texts

    @property
    def markdown_widget(self) -> MarkdownWidget:
        """Section introduction; the text template is filled with the F1-score definition."""
        text: str = getattr(self.vis_texts, self.MARKDOWN).format(
            self.vis_texts.definitions.f1_score
        )
        return MarkdownWidget(name=self.MARKDOWN, title="Precision, Recall, F1-score", text=text)

    @property
    def precision_per_class_title_md(self) -> MarkdownWidget:
        """Heading text for the per-class precision chart, followed by the clickable-label hint."""
        text: str = getattr(self.vis_texts, self.MARKDOWN_PRECISION_TITLE)
        text += self.vis_texts.clickable_label
        return MarkdownWidget(
            name="markdown_precision_per_class", title="Precision by Class", text=text
        )

    @property
    def recall_per_class_title_md(self) -> MarkdownWidget:
        """Heading text for the per-class recall chart, followed by the clickable-label hint."""
        text: str = getattr(self.vis_texts, self.MARKDOWN_RECALL_TITLE)
        text += self.vis_texts.clickable_label
        return MarkdownWidget(name="markdown_recall_per_class", title="Recall by Class", text=text)

    @property
    def f1_per_class_title_md(self) -> MarkdownWidget:
        """Heading text for the per-class F1-score chart, followed by the clickable-label hint."""
        text: str = getattr(self.vis_texts, self.MARKDOWN_F1_TITLE)
        text += self.vis_texts.clickable_label
        return MarkdownWidget(name="markdown_f1_per_class", title="F1-score by Class", text=text)

    # ----------------------------------------------------------------- charts

    @property
    def chart_main_widget(self) -> ChartWidget:
        """Overall Precision/Recall/F1 bar chart; clicking a bar opens that model's TP objects."""
        chart = ChartWidget(name="chart_PRF1", figure=self.get_main_figure())
        chart.set_click_data(
            gallery_id=self.explore_modal_table.id,
            click_data=self.get_click_data_main(),
            # Key a click by the trace (= model) index, matching get_click_data_main().
            chart_click_extra="'getKey': (payload) => `${payload.points[0].curveNumber}`,",
        )
        return chart

    @property
    def chart_recall_per_class_widget(self) -> ChartWidget:
        """Clickable per-class recall chart."""
        return self._make_per_class_chart(
            "chart_recall_per_class", self.get_recall_per_class_figure()
        )

    @property
    def chart_precision_per_class_widget(self) -> ChartWidget:
        """Clickable per-class precision chart."""
        return self._make_per_class_chart(
            "chart_precision_per_class", self.get_precision_per_class_figure()
        )

    @property
    def chart_f1_per_class_widget(self) -> ChartWidget:
        """Clickable per-class F1-score chart."""
        return self._make_per_class_chart("chart_f1_per_class", self.get_f1_per_class_figure())

    def _make_per_class_chart(self, name: str, figure) -> ChartWidget:
        """Wrap a per-class figure in a chart whose bar clicks open the explorer gallery.

        A click is keyed as ``"<curveNumber>_<class label>"``, which matches the
        ``"{model_index}_{class}"`` keys produced by :meth:`get_per_class_click_data`.
        """
        chart = ChartWidget(name=name, figure=figure)
        chart.set_click_data(
            gallery_id=self.explore_modal_table.id,
            click_data=self.get_per_class_click_data(),
            chart_click_extra="'getKey': (payload) => `${payload.points[0].curveNumber}${'_'}${payload.points[0].label}`,",
        )
        return chart

    # ------------------------------------------------------------------ table

    @property
    def table_widget(self) -> TableWidget:
        """Summary table: one row per model with precision, recall and F1 rounded to 2 decimals."""
        columns = [" ", "Precision", "Recall", "F1-score"]
        content = []
        for i, eval_result in enumerate(self.eval_results, 1):
            # Evaluate the metrics once per model instead of once per value.
            metrics = eval_result.mp.json_metrics()
            precision = round(metrics["precision"], 2)
            recall = round(metrics["recall"], 2)
            f1 = round(metrics["f1"], 2)
            model_name = f"[{i}] {eval_result.name}"
            row = [model_name, precision, recall, f1]
            content.append({"row": row, "id": model_name, "items": row})

        res = {
            "content": content,
            "columns": columns,
            # One options entry per column; sorting is disabled everywhere.
            "columnsOptions": [{"disableSort": True} for _ in columns],
        }

        return TableWidget(
            name="table_precision_recall_f1",
            data=res,
            show_header_controls=False,
            # main_column=columns[0],
            fix_columns=1,
        )

    # ---------------------------------------------------------------- figures

    def get_main_figure(self):  # -> Optional[go.Figure]:
        """Grouped bars of overall Precision, Recall and F1-score, one trace per model.

        The class count of the first model selects the sizing: with fewer than 5
        classes the figure is fixed at 700 px wide, otherwise bars are 0.2 wide.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error

        fig = go.Figure()

        classes_cnt = len(self.eval_results[0].mp.cat_names)
        for i, eval_result in enumerate(self.eval_results, 1):
            metrics = eval_result.mp.json_metrics()
            model_name = f"[{i}] {eval_result.name}"
            fig.add_trace(
                go.Bar(
                    x=["Precision", "Recall", "F1-score"],
                    y=[metrics["precision"], metrics["recall"], metrics["f1"]],
                    name=model_name,
                    width=0.2 if classes_cnt >= 5 else None,
                    marker=dict(color=eval_result.color),
                )
            )

        fig.update_layout(
            barmode="group",
            xaxis_title="Metric",
            yaxis_title="Value",
            yaxis=dict(range=[0, 1.1]),
            width=700 if classes_cnt < 5 else None,
        )

        return fig

    def _get_per_class_figure(self, metric: str, label: str):
        """Grouped bar chart of one per-class metric, one trace per model.

        :param metric: column of ``mp.per_class_metrics()`` to plot
            (``"precision"``, ``"recall"`` or ``"f1"``).
        :param label: display name of the metric, used as the legend suffix and
            as the y-axis title.
        :returns: a plotly ``Figure``; each model's classes are ordered by that
            model's ascending ``f1`` value.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error

        fig = go.Figure()
        classes_cnt = len(self.eval_results[0].mp.cat_names)
        for i, eval_result in enumerate(self.eval_results, 1):
            model_name = f"[{i}] {eval_result.name}"
            sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")

            fig.add_trace(
                go.Bar(
                    y=sorted_by_f1[metric],
                    x=sorted_by_f1["category"],
                    name=f"{model_name} {label}",
                    width=0.2 if classes_cnt < 5 else None,
                    marker=dict(color=eval_result.color),
                )
            )

        fig.update_layout(
            barmode="group",
            bargap=0.15,
            bargroupgap=0.05,
            width=700 if classes_cnt < 5 else None,
        )
        fig.update_xaxes(title_text="Class")
        fig.update_yaxes(title_text=label, range=[0, 1])
        return fig

    def get_recall_per_class_figure(self):
        """Per-class recall bars for every model."""
        return self._get_per_class_figure("recall", "Recall")

    def get_precision_per_class_figure(self):
        """Per-class precision bars for every model."""
        return self._get_per_class_figure("precision", "Precision")

    def get_f1_per_class_figure(self):
        """Per-class F1-score bars for every model."""
        return self._get_per_class_figure("f1", "F1-score")

    # ------------------------------------------------------------- click data

    def get_per_class_click_data(self):
        """Gallery payload for per-class charts.

        ``clickData`` is keyed ``"{model_index}_{class}"`` (0-based model index).
        Each entry lists the images containing that class's matched objects and
        filters to outcome ``TP``, confidence in ``[f1_optimal_conf, 1]`` and
        the specific object ids.
        """
        res = {}
        res["layoutTemplate"] = [None, None, None]
        res["clickData"] = {}
        for i, eval_result in enumerate(self.eval_results):
            model_name = f"Model [{i + 1}] {eval_result.name}"
            for key, v in eval_result.click_data.objects_by_class.items():
                click_data = res["clickData"].setdefault(f"{i}_{key}", {})
                img_ids, obj_ids = set(), set()
                title = f"{model_name}. Class {key}: {len(v)} object{'s' if len(v) > 1 else ''}"
                click_data["title"] = title

                for x in v:
                    img_ids.add(x["dt_img_id"])
                    obj_ids.add(x["dt_obj_id"])

                click_data["imagesIds"] = list(img_ids)
                click_data["filters"] = [
                    {
                        "type": "tag",
                        "tagId": "confidence",
                        "value": [eval_result.f1_optimal_conf, 1],
                    },
                    {"type": "tag", "tagId": "outcome", "value": "TP"},
                    {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
                ]
        return res

    def get_click_data_main(self):
        """Gallery payload for the overall chart.

        ``clickData`` is keyed by the 0-based model index (the trace number). Each
        entry collects the model's ``TP`` objects from ``click_data.outcome_counts``
        and applies the same confidence/outcome/object filters as the per-class data.
        """
        res = {}
        res["layoutTemplate"] = [None, None, None]
        res["clickData"] = {}

        for i, eval_result in enumerate(self.eval_results):
            model_name = f"Model [{i + 1}] {eval_result.name}"
            click_data = res["clickData"].setdefault(i, {})
            img_ids, obj_ids = set(), set()
            objects_cnt = 0
            for outcome, matched_obj in eval_result.click_data.outcome_counts.items():
                if outcome == "TP":  # TODO: check if this is correct
                    objects_cnt += len(matched_obj)
                    for x in matched_obj:
                        img_ids.add(x["dt_img_id"])
                        obj_ids.add(x["dt_obj_id"])

            click_data["title"] = f"{model_name}, {objects_cnt} objects"
            click_data["imagesIds"] = list(img_ids)
            click_data["filters"] = [
                {"type": "tag", "tagId": "confidence", "value": [eval_result.f1_optimal_conf, 1]},
                {"type": "tag", "tagId": "outcome", "value": "TP"},
                {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
            ]

        return res
|