supervisely 6.73.213__py3-none-any.whl → 6.73.215__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/app/widgets/report_thumbnail/report_thumbnail.py +17 -5
- supervisely/app/widgets/team_files_selector/team_files_selector.py +3 -0
- supervisely/io/network_exceptions.py +89 -32
- supervisely/nn/benchmark/comparison/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/detection_visualization/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/detection_visualization/text_templates.py +437 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +27 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +125 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +224 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/explore_predicttions.py +112 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +161 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +336 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +249 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +142 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +300 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +308 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +19 -0
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +298 -0
- supervisely/nn/benchmark/comparison/model_comparison.py +84 -0
- supervisely/nn/benchmark/evaluation/coco/metric_provider.py +9 -7
- supervisely/nn/benchmark/visualization/evaluation_result.py +266 -0
- supervisely/nn/benchmark/visualization/renderer.py +100 -0
- supervisely/nn/benchmark/visualization/report_template.html +46 -0
- supervisely/nn/benchmark/visualization/visualizer.py +1 -1
- supervisely/nn/benchmark/visualization/widgets/__init__.py +17 -0
- supervisely/nn/benchmark/visualization/widgets/chart/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/chart/chart.py +72 -0
- supervisely/nn/benchmark/visualization/widgets/chart/template.html +16 -0
- supervisely/nn/benchmark/visualization/widgets/collapse/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/collapse/collapse.py +33 -0
- supervisely/nn/benchmark/visualization/widgets/container/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/container/container.py +54 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +125 -0
- supervisely/nn/benchmark/visualization/widgets/gallery/template.html +49 -0
- supervisely/nn/benchmark/visualization/widgets/markdown/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/markdown/markdown.py +53 -0
- supervisely/nn/benchmark/visualization/widgets/notification/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/notification/notification.py +38 -0
- supervisely/nn/benchmark/visualization/widgets/sidebar/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/sidebar/sidebar.py +67 -0
- supervisely/nn/benchmark/visualization/widgets/table/__init__.py +0 -0
- supervisely/nn/benchmark/visualization/widgets/table/table.py +116 -0
- supervisely/nn/benchmark/visualization/widgets/widget.py +22 -0
- supervisely/nn/inference/cache.py +8 -5
- {supervisely-6.73.213.dist-info → supervisely-6.73.215.dist-info}/METADATA +5 -5
- {supervisely-6.73.213.dist-info → supervisely-6.73.215.dist-info}/RECORD +51 -12
- {supervisely-6.73.213.dist-info → supervisely-6.73.215.dist-info}/LICENSE +0 -0
- {supervisely-6.73.213.dist-info → supervisely-6.73.215.dist-info}/WHEEL +0 -0
- {supervisely-6.73.213.dist-info → supervisely-6.73.215.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.213.dist-info → supervisely-6.73.215.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
|
|
2
|
+
BaseVisMetric,
|
|
3
|
+
)
|
|
4
|
+
from supervisely.nn.benchmark.visualization.widgets import (
|
|
5
|
+
ChartWidget,
|
|
6
|
+
CollapseWidget,
|
|
7
|
+
MarkdownWidget,
|
|
8
|
+
NotificationWidget,
|
|
9
|
+
TableWidget,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CalibrationScore(BaseVisMetric):
    """Calibration widgets built from ``self.eval_results``.

    The markdown, table and chart widgets are exposed as properties; their
    texts come from ``self.vis_texts``. The two plotly figures are produced
    by ``get_rel_figure`` and ``get_conf_figure``.
    """

    @property
    def header_md(self) -> MarkdownWidget:
        """Section header; the template is filled with the confidence-score definition."""
        filled = self.vis_texts.markdown_calibration_score_1.format(
            self.vis_texts.definitions.confidence_score
        )
        return MarkdownWidget(
            name="markdown_calibration_score",
            title="Calibration Score",
            text=filled,
        )

    @property
    def collapse_tip(self) -> CollapseWidget:
        """Collapsible block wrapping the "What is calibration?" markdown."""
        explanation = MarkdownWidget(
            name="what_is_calibration",
            title="What is calibration?",
            text=self.vis_texts.markdown_what_is_calibration,
        )
        return CollapseWidget([explanation])

    @property
    def header_md_2(self) -> MarkdownWidget:
        """Untitled markdown showing the second calibration-score text."""
        second_text = self.vis_texts.markdown_calibration_score_2
        return MarkdownWidget(
            name="markdown_calibration_score_2",
            title="",
            text=second_text,
        )

    @property
    def table(self) -> TableWidget:
        """Table with one row per evaluation result: threshold, ECE and MCE.

        Every number is rounded to two decimals. A falsy F1-optimal
        confidence is displayed as ``0.0``.
        """
        header = [" ", "confidence threshold", "ECE", "MCE"]
        rows = []
        for number, eval_result in enumerate(self.eval_results, start=1):
            label = f"[{number}] {eval_result.name}"
            threshold = eval_result.mp.m_full.get_f1_optimal_conf()[0] or 0.0
            ece = eval_result.mp.m_full.expected_calibration_error()
            mce = eval_result.mp.m_full.maximum_calibration_error()
            cells = [label, round(threshold, 2), round(ece, 2), round(mce, 2)]
            # "row" and "items" intentionally reference the same list.
            rows.append({"row": cells, "id": label, "items": cells})

        payload = {
            "columns": header,
            # Sorting is disabled for every column.
            "columnsOptions": [{"disableSort": True} for _ in header],
            "content": rows,
        }
        return TableWidget(
            "table_reliability",
            payload,
            show_header_controls=False,
            fix_columns=1,
        )

    @property
    def reliability_diagram_md(self) -> MarkdownWidget:
        """Markdown introducing the reliability diagram."""
        diagram_text = self.vis_texts.markdown_reliability_diagram
        return MarkdownWidget(
            name="markdown_reliability_diagram",
            title="Reliability Diagram",
            text=diagram_text,
        )

    @property
    def reliability_chart(self) -> ChartWidget:
        """Chart widget wrapping the figure from ``get_rel_figure``."""
        figure = self.get_rel_figure()
        return ChartWidget(name="chart_reliability", figure=figure)

    @property
    def collapse_ece(self) -> CollapseWidget:
        """Collapsible block explaining how to read the calibration curve."""
        interpretation = MarkdownWidget(
            name="markdown_calibration_curve_interpretation",
            title="How to interpret the Calibration curve",
            text=self.vis_texts.markdown_calibration_curve_interpretation,
        )
        return CollapseWidget([interpretation])

    @property
    def confidence_score_md(self) -> MarkdownWidget:
        """Markdown for the confidence profile, filled with the threshold definition."""
        filled = self.vis_texts.markdown_confidence_score_1.format(
            self.vis_texts.definitions.confidence_threshold
        )
        return MarkdownWidget(
            "markdown_confidence_score_1",
            "Confidence Score Profile",
            filled,
        )

    @property
    def confidence_chart(self) -> ChartWidget:
        """Chart widget wrapping the figure from ``get_conf_figure``."""
        figure = self.get_conf_figure()
        return ChartWidget(name="chart_confidence", figure=figure)

    @property
    def confidence_score_md_2(self) -> MarkdownWidget:
        """Untitled markdown showing the second confidence-score text."""
        second_text = self.vis_texts.markdown_confidence_score_2
        return MarkdownWidget(
            name="markdown_confidence_score_2",
            title="",
            text=second_text,
        )

    @property
    def collapse_conf_score(self) -> CollapseWidget:
        """Collapsible block describing how the confidence profile is plotted."""
        how_to = MarkdownWidget(
            name="markdown_plot_confidence_profile",
            title="How to plot Confidence Profile?",
            text=self.vis_texts.markdown_plot_confidence_profile,
        )
        return CollapseWidget([how_to])

    def get_conf_figure(self):
        """Build the confidence-profile figure.

        For every evaluation result a line of ``dfsp_down["f1"]`` over
        ``dfsp_down["scores"]`` is drawn. When both the F1-optimal confidence
        and the best F1 are available, a dashed gray vertical segment from 0
        up to the best F1 is added at that confidence.

        Returns:
            The plotly figure with both axes fixed to ``[0, 1]``.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error

        figure = go.Figure()

        for number, eval_result in enumerate(self.eval_results, start=1):
            curve = go.Scatter(
                x=eval_result.dfsp_down["scores"],
                y=eval_result.dfsp_down["f1"],
                mode="lines",
                name=f"[{number}] {eval_result.name}",
                line=dict(color=eval_result.color),
                hovertemplate="Confidence Score: %{x:.2f}<br>Value: %{y:.2f}<extra></extra>",
            )
            figure.add_trace(curve)

            # The threshold marker needs both values; otherwise move on.
            if eval_result.mp.f1_optimal_conf is None or eval_result.mp.best_f1 is None:
                continue

            figure.add_shape(
                type="line",
                x0=eval_result.mp.f1_optimal_conf,
                x1=eval_result.mp.f1_optimal_conf,
                y0=0,
                y1=eval_result.mp.best_f1,
                line=dict(color="gray", width=2, dash="dash"),
                name=f"F1-optimal threshold ({eval_result.name})",
            )

        # Interactive zoom/pan/select controls are removed from the mode bar.
        hidden_buttons = [
            "zoom2d",
            "pan2d",
            "select2d",
            "lasso2d",
            "zoomIn2d",
            "zoomOut2d",
            "autoScale2d",
            "resetScale2d",
        ]
        figure.update_layout(
            yaxis=dict(range=[0, 1], title="Scores"),
            xaxis=dict(range=[0, 1], tick0=0, dtick=0.1, title="Confidence Score"),
            height=500,
            dragmode=False,
            modebar=dict(remove=hidden_buttons),
            showlegend=True,  # the legend tells the compared results apart
        )

        return figure

    def get_rel_figure(self):
        """Build the reliability diagram.

        Plots the output of ``mp.m_full.calibration_curve()`` for every
        evaluation result (second returned array on x, first on y) and then
        adds the dashed orange diagonal labelled "Perfectly calibrated".

        Returns:
            The plotly figure, 700x500, with both axes ranging over ``[0, 1.1]``.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error

        figure = go.Figure()

        for number, eval_result in enumerate(self.eval_results, start=1):
            observed, predicted = eval_result.mp.m_full.calibration_curve()

            # Only the name part is an f-string; the rest keeps plotly's %{...} placeholders.
            hover = (
                f"{eval_result.name}<br>"
                + "Confidence Score: %{x:.2f}<br>Fraction of True Positives: %{y:.2f}<extra></extra>"
            )
            figure.add_trace(
                go.Scatter(
                    x=predicted,
                    y=observed,
                    mode="lines+markers",
                    name=f"[{number}] {eval_result.name}",
                    line=dict(color=eval_result.color),
                    hovertemplate=hover,
                )
            )

        diagonal = go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode="lines",
            name="Perfectly calibrated",
            line=dict(color="orange", dash="dash"),
        )
        figure.add_trace(diagonal)

        figure.update_layout(
            xaxis_title="Confidence Score",
            yaxis_title="Fraction of True Positives",
            legend=dict(x=0.6, y=0.1),
            xaxis=dict(range=[0, 1.1]),
            yaxis=dict(range=[0, 1.1]),
            width=700,
            height=500,
        )
        return figure
|
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/explore_predicttions.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from typing import List, Tuple
|
|
2
|
+
|
|
3
|
+
from supervisely.annotation.annotation import Annotation
|
|
4
|
+
from supervisely.api.image_api import ImageInfo
|
|
5
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
|
|
6
|
+
BaseVisMetric,
|
|
7
|
+
)
|
|
8
|
+
from supervisely.nn.benchmark.visualization.widgets import GalleryWidget, MarkdownWidget
|
|
9
|
+
from supervisely.project.project_meta import ProjectMeta
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ExplorePredictions(BaseVisMetric):
    """Gallery widgets showing ground truth next to the predictions of each result.

    Images and annotations are fetched through the API clients stored on the
    evaluation results in ``self.eval_results``.
    """

    MARKDOWN_DIFFERENCE = "markdown_explore_difference"
    GALLERY_DIFFERENCE = "explore_difference_gallery"
    MARKDOWN_SAME_ERRORS = "markdown_explore_same_errors"
    GALLERY_SAME_ERRORS = "explore_same_error_gallery"

    @property
    def difference_predictions_md(self) -> MarkdownWidget:
        """Markdown widget titled "Explore Predictions"."""
        text = self.vis_texts.markdown_explore_difference
        return MarkdownWidget(self.MARKDOWN_DIFFERENCE, "Explore Predictions", text)

    @property
    def explore_gallery(self) -> GalleryWidget:
        """Preview gallery with one ground-truth column plus one column per result.

        The gallery is pre-filtered to confidences in ``[min_conf, 1]`` and a
        click on it opens the "explore all" modal described by
        ``get_click_data_explore_all``.
        """
        columns_number = len(self.eval_results) + 1
        *data, min_conf = self._get_sample_data()
        default_filters = [{"confidence": [min_conf, 1]}]
        gallery = GalleryWidget(
            self.GALLERY_DIFFERENCE, columns_number=columns_number, filters=default_filters
        )
        gallery.add_image_left_header("Click to explore more")
        gallery.show_all_button = True
        gallery.set_project_meta(self.eval_results[0].gt_project_meta)
        gallery.set_images(*data)
        gallery.add_on_click(
            self.explore_modal_table.id, self.get_click_data_explore_all(), columns_number * 3
        )
        # NOTE(review): private call on the wrapped gallery; presumably it
        # re-applies the default filters after the images were set — confirm.
        gallery._gallery._update_filters()

        return gallery

    def _min_f1_optimal_conf(self) -> float:
        """Return the smallest ``f1_optimal_conf`` among the evaluation results.

        Results whose value is ``None`` are skipped instead of breaking the
        comparison. When no result provides a usable value, ``0.0`` is
        returned so that the confidence filter never receives ``inf``.
        """
        lowest = float("inf")
        for eval_res in self.eval_results:
            conf = eval_res.f1_optimal_conf
            if conf is not None:
                lowest = min(lowest, conf)
        return 0.0 if lowest == float("inf") else lowest

    def _project_image_ids(self, eval_res, project_id) -> list:
        """Collect the ids of all images in every dataset of ``project_id``.

        Datasets are listed with the first result's API client and images
        with ``eval_res.api``, as the original inline code did.
        """
        datasets = self.eval_results[0].api.dataset.get_list(project_id)
        image_ids = []
        for dataset in datasets:
            image_infos = eval_res.api.image.get_list(dataset.id)
            image_ids.extend(image_info.id for image_info in image_infos)
        return image_ids

    def _get_sample_data(
        self,
    ) -> Tuple[List[ImageInfo], List[Annotation], List[ProjectMeta], List[bool], float]:
        """Fetch a small sample (up to 5 images per project) for the preview gallery.

        Only the first dataset of the ground-truth project and of each
        prediction project is used.

        Returns:
            A 5-tuple ``(images, annotations, metas, skip_tags_filtering, min_conf)``:

            * ``images`` / ``annotations`` - items interleaved per image, i.e.
              the ground-truth item followed by the matching item of every
              result; ``zip`` truncates to the shortest sample.
            * ``metas`` - ground-truth meta followed by one meta per result.
            * ``skip_tags_filtering`` - ``True`` for the ground-truth column,
              ``False`` for every prediction column.
            * ``min_conf`` - see ``_min_f1_optimal_conf``.
        """
        images = []
        annotations = []
        metas = [self.eval_results[0].gt_project_meta]
        skip_tags_filtering = []
        api = self.eval_results[0].api
        for idx, eval_res in enumerate(self.eval_results):
            if idx == 0:
                # Ground truth is taken once, from the first result.
                dataset_info = api.dataset.get_list(eval_res.gt_project_id)[0]
                image_infos = api.image.get_list(dataset_info.id, limit=5)
                images_ids = [image_info.id for image_info in image_infos]
                images.append(image_infos)
                annotations.append(api.annotation.download_batch(dataset_info.id, images_ids))
                skip_tags_filtering.append(True)

            metas.append(eval_res.dt_project_meta)
            dataset_info = api.dataset.get_list(eval_res.dt_project_id)[0]
            image_infos = eval_res.api.image.get_list(dataset_info.id, limit=5)
            images_ids = [image_info.id for image_info in image_infos]
            images.append(image_infos)
            annotations.append(eval_res.api.annotation.download_batch(dataset_info.id, images_ids))
            skip_tags_filtering.append(False)

        # Turn per-project lists into one flat, per-image interleaved list.
        images = [item for group in zip(*images) for item in group]
        annotations = [item for group in zip(*annotations) for item in group]
        return images, annotations, metas, skip_tags_filtering, self._min_f1_optimal_conf()

    def get_click_data_explore_all(self) -> dict:
        """Build the click payload for the "Explore all predictions" modal.

        Returns:
            A dict with ``projectMeta`` (ground-truth meta as JSON),
            ``layoutTemplate`` (a "Ground Truth" column followed by one
            "Model N" column per result) and ``clickData["explore"]`` holding
            the title, the interleaved image ids of all projects and a
            confidence tag filter ``[min_conf, 1]``.
        """
        layout = [{"skipObjectTagsFiltering": True, "columnTitle": "Ground Truth"}]
        layout.extend({"columnTitle": f"Model {i + 1}"} for i in range(len(self.eval_results)))

        # One id list per column: ground truth first, then every prediction project.
        images_ids = []
        for idx, eval_res in enumerate(self.eval_results):
            if idx == 0:
                images_ids.append(self._project_image_ids(eval_res, eval_res.gt_project_id))
            images_ids.append(self._project_image_ids(eval_res, eval_res.dt_project_id))

        min_conf = self._min_f1_optimal_conf()
        explore = {
            "title": "Explore all predictions",
            # zip() stops at the shortest list, so extra images of larger projects are dropped.
            "imagesIds": [i for group in zip(*images_ids) for i in group],
            "filters": [{"type": "tag", "tagId": "confidence", "value": [min_conf, 1]}],
        }
        return {
            "projectMeta": self.eval_results[0].gt_project_meta.to_json(),
            "layoutTemplate": layout,
            "clickData": {"explore": explore},
        }
|
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
|
|
4
|
+
BaseVisMetric,
|
|
5
|
+
)
|
|
6
|
+
from supervisely.nn.benchmark.cv_tasks import CVTask
|
|
7
|
+
from supervisely.nn.benchmark.visualization.widgets import (
|
|
8
|
+
ChartWidget,
|
|
9
|
+
CollapseWidget,
|
|
10
|
+
MarkdownWidget,
|
|
11
|
+
TableWidget,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LocalizationAccuracyIoU(BaseVisMetric):
    """IoU widgets built from ``self.eval_results``: texts, table and distribution chart."""

    @property
    def header_md(self) -> MarkdownWidget:
        """Section header.

        The title is "Mask Accuracy (IoU)" when the first result's task is
        instance or semantic segmentation, "Localization Accuracy (IoU)"
        otherwise.
        """
        title = "Localization Accuracy (IoU)"
        if self.eval_results[0].cv_task in [
            CVTask.INSTANCE_SEGMENTATION,
            CVTask.SEMANTIC_SEGMENTATION,
        ]:
            title = "Mask Accuracy (IoU)"
        text_template = self.vis_texts.markdown_localization_accuracy
        text = text_template.format(self.vis_texts.definitions.iou_score)
        return MarkdownWidget(
            name="markdown_localization_accuracy",
            title=title,
            text=text,
        )

    @property
    def iou_distribution_md(self) -> MarkdownWidget:
        """Markdown introducing the IoU distribution chart."""
        text_template = self.vis_texts.markdown_iou_distribution
        text = text_template.format(self.vis_texts.definitions.iou_score)
        return MarkdownWidget(
            name="markdown_iou_distribution",
            title="IoU Distribution",
            text=text,
        )

    @property
    def table_widget(self) -> TableWidget:
        """Table with the ``"iou"`` base metric of every result, rounded to two decimals."""
        content = []
        for i, eval_result in enumerate(self.eval_results, 1):
            value = round(eval_result.mp.base_metrics()["iou"], 2)
            model_name = f"[{i}] {eval_result.name}"
            row = [model_name, value]
            content.append({"row": row, "id": model_name, "items": row})

        res = {
            "content": content,
            "columns": [" ", "Avg. IoU"],
            "columnsOptions": [{"disableSort": True}, {"disableSort": True}],
        }
        return TableWidget(
            name="localization_accuracy_table", data=res, show_header_controls=False, fix_columns=1
        )

    @property
    def chart(self) -> ChartWidget:
        """Chart widget wrapping the figure from ``get_figure``."""
        return ChartWidget(name="chart_iou_distribution", figure=self.get_figure())

    @property
    def collapse_tip(self) -> CollapseWidget:
        """Collapsible block explaining how IoU is calculated."""
        inner_md = MarkdownWidget(
            name="markdown_iou_calculation",
            title="How IoU is calculated?",
            text=self.vis_texts.markdown_iou_calculation,
        )
        return CollapseWidget(widgets=[inner_md])

    def get_figure(self):
        """Build the IoU distribution figure.

        For every evaluation result a 40-bin histogram of ``mp.ious`` is
        drawn as semi-transparent bars, overlaid with a Gaussian KDE scaled
        to counts, and a dashed orange vertical line marks the mean IoU.

        The KDE trace is skipped for a result with fewer than two IoU values
        or with all values identical, because ``gaussian_kde`` cannot be
        fitted to such data; its histogram is still drawn. The mean line is
        skipped when there are no IoU values at all.

        Returns:
            The plotly figure with the x axis limited to ``[0.5, 1]``.
        """
        import plotly.graph_objects as go  # pylint: disable=import-error
        from scipy.stats import gaussian_kde  # pylint: disable=import-error

        fig = go.Figure()
        nbins = 40
        x_range = np.linspace(0.5, 1, 500)
        hist_data = [np.histogram(r.mp.ious, bins=nbins) for r in self.eval_results]
        # NOTE(review): one bar width (the smallest across results) is used for
        # all results, while bars sit at each result's own bin edges; the same
        # width scales the KDE. Confirm this is the intended presentation.
        bin_width = min(bin_edges[1] - bin_edges[0] for _, bin_edges in hist_data)

        for i, (eval_result, (hist, bin_edges)) in enumerate(zip(self.eval_results, hist_data)):
            name = f"[{i+1}] {eval_result.name}"

            fig.add_trace(
                go.Bar(
                    x=bin_edges[:-1],
                    y=hist,
                    width=bin_width,
                    name=f"{name} (Bars)",
                    offset=0,
                    opacity=0.2,
                    hovertemplate=name + "<br>IoU: %{x:.2f}<br>Count: %{y}<extra></extra>",
                    marker=dict(color=eval_result.color, line=dict(width=0)),
                )
            )

            # gaussian_kde needs at least two samples with non-zero spread.
            ious = np.asarray(eval_result.mp.ious)
            if ious.size < 2 or np.ptp(ious) == 0:
                continue

            density = gaussian_kde(eval_result.mp.ious)(x_range)
            # Convert the probability density into expected counts per bin.
            scaled_density = density * (len(eval_result.mp.ious) * bin_width)

            fig.add_trace(
                go.Scatter(
                    x=x_range,
                    y=scaled_density,
                    name=f"{name} (KDE)",
                    line=dict(color=eval_result.color, width=2),
                    hovertemplate=name + "<br>IoU: %{x:.2f}<br>Count: %{y:.1f}<extra></extra>",
                )
            )

        fig.update_layout(
            xaxis_title="IoU",
            yaxis_title="Count",
        )

        # Mark the mean IoU of every result with a vertical dashed line.
        for eval_result in self.eval_results:
            if np.asarray(eval_result.mp.ious).size == 0:
                # The mean of an empty set is NaN; there is nothing to mark.
                continue
            mean_iou = eval_result.mp.ious.mean()
            # Line height equals the average number of samples per bin.
            y1 = len(eval_result.mp.ious) // nbins
            fig.add_shape(
                type="line",
                x0=mean_iou,
                x1=mean_iou,
                y0=0,
                y1=y1,
                line=dict(color="orange", width=2, dash="dash"),
            )

        fig.update_layout(
            barmode="overlay",
            bargap=0,
            bargroupgap=0,
            dragmode=False,
            yaxis=dict(rangemode="tozero"),
            xaxis=dict(range=[0.5, 1]),
            modebar=dict(
                remove=[
                    "zoom2d",
                    "pan2d",
                    "select2d",
                    "lasso2d",
                    "zoomIn2d",
                    "zoomOut2d",
                    "autoScale2d",
                    "resetScale2d",
                ]
            ),
        )
        return fig
|