supervisely 6.73.254-py3-none-any.whl → 6.73.255-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic.
- supervisely/api/file_api.py +16 -5
- supervisely/api/task_api.py +4 -2
- supervisely/app/widgets/field/field.py +10 -7
- supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
- supervisely/nn/benchmark/base_benchmark.py +33 -35
- supervisely/nn/benchmark/base_evaluator.py +27 -1
- supervisely/nn/benchmark/base_visualizer.py +8 -11
- supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
- supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
- supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
- supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
- supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
- supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
- supervisely/nn/benchmark/visualization/renderer.py +25 -10
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
- supervisely/nn/inference/inference.py +1 -0
- supervisely/nn/training/gui/gui.py +32 -10
- supervisely/nn/training/gui/training_artifacts.py +145 -0
- supervisely/nn/training/gui/training_process.py +3 -19
- supervisely/nn/training/train_app.py +179 -70
- {supervisely-6.73.254.dist-info → supervisely-6.73.255.dist-info}/METADATA +1 -1
- {supervisely-6.73.254.dist-info → supervisely-6.73.255.dist-info}/RECORD +58 -46
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
- {supervisely-6.73.254.dist-info → supervisely-6.73.255.dist-info}/LICENSE +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.255.dist-info}/WHEEL +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.255.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.255.dist-info}/top_level.txt +0 -0
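
The recurring change across the comparison vis-metrics diffs below is a refactor: each metric class now subclasses `BaseVisMetrics` from `supervisely.nn.benchmark.base_visualizer` instead of the old `BaseVisMetric` helper, whose module (`vis_metrics/vis_metric.py`) is deleted in this release; model labels are built from `eval_result.name`; and the F1-optimal confidence threshold is read from the metric provider as `eval_result.mp.f1_optimal_conf`. A minimal sketch of the new pattern, assuming the `BaseVisMetrics` constructor accepts `(vis_texts, eval_results)` as the call sites in the visualizer diff suggest; the subclass and its widget are hypothetical:

# Hypothetical metric class, illustrating the pattern the diffs below converge on.
from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
from supervisely.nn.benchmark.visualization.widgets import MarkdownWidget

class DummyComparisonMetric(BaseVisMetrics):
    @property
    def header_md(self) -> MarkdownWidget:
        # Models are labelled "[<1-based index>] <eval_result.name>" throughout.
        names = [f"[{i}] {r.name}" for i, r in enumerate(self.eval_results, 1)]
        return MarkdownWidget("header", "Header", text=" vs. ".join(names))

# Usage, mirroring the visualizer: DummyComparisonMetric(vis_texts, comparison.eval_results)
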
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py
CHANGED
@@ -1,8 +1,6 @@
 import numpy as np
 
-from supervisely.nn.benchmark.
-    BaseVisMetric,
-)
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
 from supervisely.nn.benchmark.cv_tasks import CVTask
 from supervisely.nn.benchmark.visualization.widgets import (
     ChartWidget,
@@ -12,7 +10,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
 )
 
 
-class LocalizationAccuracyIoU(BaseVisMetric):
+class LocalizationAccuracyIoU(BaseVisMetrics):
     @property
     def header_md(self) -> MarkdownWidget:
         title = "Localization Accuracy (IoU)"
@@ -90,7 +88,7 @@ class LocalizationAccuracyIoU(BaseVisMetric):
         bin_width = min([bin_edges[1] - bin_edges[0] for _, bin_edges in hist_data])
 
         for i, (eval_result, (hist, bin_edges)) in enumerate(zip(self.eval_results, hist_data)):
-            name = f"[{i+1}] {eval_result.
+            name = f"[{i+1}] {eval_result.name}"
             kde = gaussian_kde(eval_result.mp.ious)
             density = kde(x_range)
 
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py
CHANGED

@@ -3,14 +3,12 @@ from typing import List
 
 import numpy as np
 
-from supervisely.nn.benchmark.
-    BaseVisMetric,
-)
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
 from supervisely.nn.benchmark.visualization.widgets import ChartWidget
 from supervisely.nn.task_type import TaskType
 
 
-class OutcomeCounts(BaseVisMetric):
+class OutcomeCounts(BaseVisMetrics):
     CHART_MAIN = "chart_outcome_counts"
     CHART_COMPARISON = "chart_outcome_counts_comparison"
 
@@ -97,7 +95,7 @@ class OutcomeCounts(BaseVisMetric):
         tp_counts = [eval_result.mp.TP_count for eval_result in self.eval_results][::-1]
         fn_counts = [eval_result.mp.FN_count for eval_result in self.eval_results][::-1]
         fp_counts = [eval_result.mp.FP_count for eval_result in self.eval_results][::-1]
-        model_names = [f"[{i}] {e.
+        model_names = [f"[{i}] {e.short_name}" for i, e in enumerate(self.eval_results, 1)][::-1]
         counts = [tp_counts, fn_counts, fp_counts]
         names = ["TP", "FN", "FP"]
         colors = ["#8ACAA1", "#dd3f3f", "#F7ADAA"]
@@ -123,7 +121,7 @@ class OutcomeCounts(BaseVisMetric):
         fig = go.Figure()
 
         colors = ["#8ACAA1", "#dd3f3f", "#F7ADAA"]
-        model_names = [f"[{i}] {e.
+        model_names = [f"[{i}] {e.short_name}" for i, e in enumerate(self.eval_results, 1)][::-1]
         model_names.append("Common")
 
         diff_tps, common_tps = self.common_and_diff_tp
@@ -263,7 +261,7 @@ class OutcomeCounts(BaseVisMetric):
         res["layoutTemplate"] = [None, None, None]
         res["clickData"] = {}
         for i, eval_result in enumerate(self.eval_results, 1):
-            model_name = f"[{i}] {eval_result.
+            model_name = f"[{i}] {eval_result.name}"
             for outcome, matches_data in eval_result.click_data.outcome_counts.items():
                 key = f"{model_name}_{outcome}"
                 outcome_dict = res["clickData"].setdefault(key, {})
@@ -278,7 +276,7 @@ class OutcomeCounts(BaseVisMetric):
                 title = f"{model_name}. {outcome}: {len(obj_ids)} object{'s' if len(obj_ids) > 1 else ''}"
                 outcome_dict["title"] = title
                 outcome_dict["imagesIds"] = list(img_ids)
-                thr = eval_result.f1_optimal_conf
+                thr = eval_result.mp.f1_optimal_conf
                 if outcome == "FN":
                     outcome_dict["filters"] = [
                         {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
@@ -327,7 +325,7 @@ class OutcomeCounts(BaseVisMetric):
             _update_outcome_dict("Common", outcome, outcome_dict, common_ids)
 
         for i, diff_ids in enumerate(diff_ids, 1):
-            name = f"[{i}] {self.eval_results[i - 1].
+            name = f"[{i}] {self.eval_results[i - 1].name}"
             key = f"{name}_{outcome}"
             outcome_dict = res["clickData"].setdefault(key, {})
 
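In the click-data hunks above, the confidence threshold for gallery filters is now read from the metric provider (`eval_result.mp.f1_optimal_conf`) rather than from the eval result itself. A self-contained sketch of how one such filter entry is assembled, using the filter shapes visible in these diffs; the threshold value is illustrative:

# Sketch: assembling one clickData filter entry with the F1-optimal threshold.
# `thr` stands in for eval_result.mp.f1_optimal_conf (value illustrative).
thr = 0.45
outcome_dict = {}
outcome_dict["filters"] = [
    {"type": "tag", "tagId": "confidence", "value": [thr, 1]},  # confidence in [thr, 1]
    {"type": "tag", "tagId": "outcome", "value": "FP"},         # only FP predictions
]
print(outcome_dict)
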
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py
CHANGED

@@ -1,9 +1,7 @@
 from typing import List
 
 from supervisely._utils import abs_url
-from supervisely.nn.benchmark.
-    BaseVisMetric,
-)
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
 from supervisely.nn.benchmark.visualization.evaluation_result import EvalResult
 from supervisely.nn.benchmark.visualization.widgets import (
     ChartWidget,
@@ -12,7 +10,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
 )
 
 
-class Overview(BaseVisMetric):
+class Overview(BaseVisMetrics):
 
     MARKDOWN_OVERVIEW = "markdown_overview"
     MARKDOWN_OVERVIEW_INFO = "markdown_overview_info"
@@ -62,7 +60,7 @@ class Overview(BaseVisMetric):
 
     @property
     def overview_widgets(self) -> List[MarkdownWidget]:
-
+        all_formats = []
         for eval_result in self.eval_results:
 
             url = eval_result.inference_info.get("checkpoint_url")
@@ -71,14 +69,10 @@ class Overview(BaseVisMetric):
                 link_text = url
             link_text = link_text.replace("_", "\_")
 
-            checkpoint_name = eval_result.
-
-            )
-            model_name = eval_result.inference_info.get("model_name") or "Custom"
+            checkpoint_name = eval_result.checkpoint_name
+            model_name = eval_result.name or "Custom"
 
-            report = eval_result.api.file.get_info_by_path(
-                eval_result.team_id, eval_result.report_path
-            )
+            report = eval_result.api.file.get_info_by_path(self.team_id, eval_result.report_path)
             report_link = abs_url(f"/model-benchmark?id={report.id}")
 
             formats = [
@@ -91,11 +85,11 @@ class Overview(BaseVisMetric):
                 link_text,
                 report_link,
             ]
-
+            all_formats.append(formats)
 
         text_template: str = getattr(self.vis_texts, self.MARKDOWN_OVERVIEW_INFO)
         widgets = []
-        for formats in
+        for formats in all_formats:
             md = MarkdownWidget(
                 name=self.MARKDOWN_OVERVIEW_INFO,
                 title="Overview",
@@ -204,7 +198,7 @@ class Overview(BaseVisMetric):
         # Overall Metrics
         fig = go.Figure()
         for i, eval_result in enumerate(self.eval_results):
-            name = f"[{i + 1}] {eval_result.
+            name = f"[{i + 1}] {eval_result.name}"
            base_metrics = eval_result.mp.base_metrics()
             r = list(base_metrics.values())
             theta = [eval_result.mp.metric_names[k] for k in base_metrics.keys()]
@@ -227,13 +221,8 @@ class Overview(BaseVisMetric):
                 angularaxis=dict(rotation=90, direction="clockwise"),
             ),
             dragmode=False,
-
-            # width=700,
-            # height=500,
-            # autosize=False,
+            height=500,
             margin=dict(l=25, r=25, t=25, b=25),
-        )
-        fig.update_layout(
             modebar=dict(
                 remove=[
                     "zoom2d",
@@ -245,6 +234,6 @@ class Overview(BaseVisMetric):
                     "autoScale2d",
                     "resetScale2d",
                 ]
-        )
+            ),
         )
         return fig
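The last two hunks of `overview.py` fold what were two consecutive `fig.update_layout` calls into one and swap the commented-out sizing options for a fixed `height=500`. A standalone sketch of the consolidated call; the radar-chart data is illustrative, only the layout keys are taken from the diff:

import plotly.graph_objects as go

# Illustrative data; the real chart plots eval_result.mp.base_metrics() per model.
fig = go.Figure(
    go.Scatterpolar(r=[0.8, 0.6, 0.9], theta=["mAP", "Precision", "Recall"], fill="toself")
)
fig.update_layout(  # one call instead of the previous two
    polar=dict(angularaxis=dict(rotation=90, direction="clockwise")),
    dragmode=False,
    height=500,
    margin=dict(l=25, r=25, t=25, b=25),
    modebar=dict(remove=["zoom2d", "autoScale2d", "resetScale2d"]),
)
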
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py
CHANGED

@@ -1,9 +1,7 @@
 import numpy as np
 
 from supervisely.imaging.color import hex2rgb
-from supervisely.nn.benchmark.
-    BaseVisMetric,
-)
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
 from supervisely.nn.benchmark.visualization.widgets import (
     ChartWidget,
     CollapseWidget,
@@ -13,7 +11,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
 )
 
 
-class PrCurve(BaseVisMetric):
+class PrCurve(BaseVisMetrics):
     MARKDOWN_PR_CURVE = "markdown_pr_curve"
     MARKDOWN_PR_TRADE_OFFS = "markdown_trade_offs"
     MARKDOWN_WHAT_IS_PR_CURVE = "markdown_what_is_pr_curve"
@@ -98,7 +96,7 @@ class PrCurve(BaseVisMetric):
             pr_curve[pr_curve == -1] = np.nan
             pr_curve = np.nanmean(pr_curve, axis=-1)
 
-            name = f"[{i}] {eval_result.
+            name = f"[{i}] {eval_result.name}"
             color = ",".join(map(str, hex2rgb(eval_result.color))) + ",0.1"
             line = go.Scatter(
                 x=eval_result.mp.recThrs,
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py
CHANGED
@@ -1,6 +1,4 @@
-from supervisely.nn.benchmark.
-    BaseVisMetric,
-)
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
 from supervisely.nn.benchmark.visualization.widgets import (
     ChartWidget,
     CollapseWidget,
@@ -10,7 +8,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
 )
 
 
-class PrecisionRecallF1(BaseVisMetric):
+class PrecisionRecallF1(BaseVisMetrics):
     MARKDOWN = "markdown_PRF1"
     MARKDOWN_PRECISION_TITLE = "markdown_precision_per_class_title"
     MARKDOWN_RECALL_TITLE = "markdown_recall_per_class_title"
@@ -136,14 +134,14 @@ class PrecisionRecallF1(BaseVisMetric):
             precision = eval_result.mp.json_metrics()["precision"]
             recall = eval_result.mp.json_metrics()["recall"]
             f1 = eval_result.mp.json_metrics()["f1"]
-            model_name = f"[{i}] {eval_result.
+            model_name = f"[{i}] {eval_result.name}"
             fig.add_trace(
                 go.Bar(
                     x=["Precision", "Recall", "F1-score"],
                     y=[precision, recall, f1],
                     name=model_name,
                     width=0.2 if classes_cnt >= 5 else None,
-                    marker=dict(color=eval_result.color),
+                    marker=dict(color=eval_result.color, line=dict(width=0.7)),
                 )
             )
 
@@ -152,7 +150,7 @@ class PrecisionRecallF1(BaseVisMetric):
             xaxis_title="Metric",
             yaxis_title="Value",
             yaxis=dict(range=[0, 1.1]),
-            width=700
+            width=700,
         )
 
         return fig
@@ -163,7 +161,7 @@ class PrecisionRecallF1(BaseVisMetric):
         fig = go.Figure()
         classes_cnt = len(self.eval_results[0].mp.cat_names)
         for i, eval_result in enumerate(self.eval_results, 1):
-            model_name = f"[{i}] {eval_result.
+            model_name = f"[{i}] {eval_result.name}"
             sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
 
             fig.add_trace(
@@ -171,8 +169,8 @@ class PrecisionRecallF1(BaseVisMetric):
                     y=sorted_by_f1["recall"],
                     x=sorted_by_f1["category"],
                     name=f"{model_name} Recall",
-                    width=0.2 if classes_cnt
-                    marker=dict(color=eval_result.color),
+                    width=0.2 if classes_cnt >= 5 else None,
+                    marker=dict(color=eval_result.color, line=dict(width=0.7)),
                 )
             )
 
@@ -191,7 +189,7 @@ class PrecisionRecallF1(BaseVisMetric):
         res["layoutTemplate"] = [None, None, None]
         res["clickData"] = {}
         for i, eval_result in enumerate(self.eval_results):
-            model_name = f"Model [{i + 1}] {eval_result.
+            model_name = f"Model [{i + 1}] {eval_result.name}"
             for key, v in eval_result.click_data.objects_by_class.items():
                 click_data = res["clickData"].setdefault(f"{i}_{key}", {})
                 img_ids, obj_ids = set(), set()
@@ -207,7 +205,7 @@ class PrecisionRecallF1(BaseVisMetric):
                     {
                         "type": "tag",
                         "tagId": "confidence",
-                        "value": [eval_result.f1_optimal_conf, 1],
+                        "value": [eval_result.mp.f1_optimal_conf, 1],
                     },
                     {"type": "tag", "tagId": "outcome", "value": "TP"},
                     {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
@@ -220,7 +218,7 @@ class PrecisionRecallF1(BaseVisMetric):
         fig = go.Figure()
         classes_cnt = len(self.eval_results[0].mp.cat_names)
         for i, eval_result in enumerate(self.eval_results, 1):
-            model_name = f"[{i}] {eval_result.
+            model_name = f"[{i}] {eval_result.name}"
             sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
 
             fig.add_trace(
@@ -228,8 +226,8 @@ class PrecisionRecallF1(BaseVisMetric):
                     y=sorted_by_f1["precision"],
                     x=sorted_by_f1["category"],
                     name=f"{model_name} Precision",
-                    width=0.2 if classes_cnt
-                    marker=dict(color=eval_result.color),
+                    width=0.2 if classes_cnt >= 5 else None,
+                    marker=dict(color=eval_result.color, line=dict(width=0.7)),
                 )
             )
 
@@ -249,7 +247,7 @@ class PrecisionRecallF1(BaseVisMetric):
         fig = go.Figure()
         classes_cnt = len(self.eval_results[0].mp.cat_names)
         for i, eval_result in enumerate(self.eval_results, 1):
-            model_name = f"[{i}] {eval_result.
+            model_name = f"[{i}] {eval_result.name}"
             sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
 
             fig.add_trace(
@@ -257,8 +255,8 @@ class PrecisionRecallF1(BaseVisMetric):
                     y=sorted_by_f1["f1"],
                     x=sorted_by_f1["category"],
                     name=f"{model_name} F1-score",
-                    width=0.2 if classes_cnt
-                    marker=dict(color=eval_result.color),
+                    width=0.2 if classes_cnt >= 5 else None,
+                    marker=dict(color=eval_result.color, line=dict(width=0.7)),
                )
            )
 
@@ -278,7 +276,7 @@ class PrecisionRecallF1(BaseVisMetric):
         res["clickData"] = {}
 
         for i, eval_result in enumerate(self.eval_results):
-            model_name = f"Model [{i + 1}] {eval_result.
+            model_name = f"Model [{i + 1}] {eval_result.name}"
             click_data = res["clickData"].setdefault(i, {})
             img_ids, obj_ids = set(), set()
             objects_cnt = 0
@@ -292,7 +290,11 @@ class PrecisionRecallF1(BaseVisMetric):
             click_data["title"] = f"{model_name}, {objects_cnt} objects"
             click_data["imagesIds"] = list(img_ids)
             click_data["filters"] = [
-                {
+                {
+                    "type": "tag",
+                    "tagId": "confidence",
+                    "value": [eval_result.mp.f1_optimal_conf, 1],
+                },
                 {"type": "tag", "tagId": "outcome", "value": "TP"},
                 {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
             ]
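Across the bar-chart hunks above, the bars are forced narrow (`width=0.2`) once there are five or more classes, since Plotly auto-sizes bar width when it is `None`, and the markers gain a thin outline. A minimal grouped-bar sketch of that pattern; the class names, scores, and colors are illustrative:

import plotly.graph_objects as go

classes = ["cat", "dog", "bird", "car", "tree"]
fig = go.Figure()
for model_name, f1_scores, color in [
    ("[1] model_a", [0.91, 0.84, 0.77, 0.88, 0.69], "#1f77b4"),
    ("[2] model_b", [0.87, 0.86, 0.81, 0.83, 0.72], "#ff7f0e"),
]:
    fig.add_trace(
        go.Bar(
            x=classes,
            y=f1_scores,
            name=f"{model_name} F1-score",
            width=0.2 if len(classes) >= 5 else None,  # narrow bars once classes pile up
            marker=dict(color=color, line=dict(width=0.7)),  # thin outline, as in the diff
        )
    )
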
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py
CHANGED

@@ -1,9 +1,7 @@
 from typing import List, Union
 
 from supervisely.imaging.color import hex2rgb
-from supervisely.nn.benchmark.
-    BaseVisMetric,
-)
+from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
 from supervisely.nn.benchmark.visualization.widgets import (
     ChartWidget,
     MarkdownWidget,
@@ -11,11 +9,19 @@ from supervisely.nn.benchmark.visualization.widgets import (
 )
 
 
-class Speedtest(BaseVisMetric):
+class Speedtest(BaseVisMetrics):
 
     def is_empty(self) -> bool:
         return not any(eval_result.speedtest_info for eval_result in self.eval_results)
 
+    def multiple_batche_sizes(self) -> bool:
+        for eval_result in self.eval_results:
+            if eval_result.speedtest_info is None:
+                continue
+            if len(eval_result.speedtest_info["speedtest"]) > 1:
+                return True
+        return False
+
     @property
     def latency(self) -> List[Union[int, str]]:
         latency = []
@@ -272,7 +278,7 @@ class Speedtest(BaseVisMetric):
             go.Scatter(
                 x=list(temp_res["ms"].keys()),
                 y=list(temp_res["ms"].values()),
-                name=f"[{idx}] {eval_result.
+                name=f"[{idx}] {eval_result.name} (ms)",
                 line=dict(color=eval_result.color),
                 customdata=list(temp_res["ms_std"].values()),
                 error_y=dict(
@@ -290,7 +296,7 @@ class Speedtest(BaseVisMetric):
             go.Scatter(
                 x=list(temp_res["fps"].keys()),
                 y=list(temp_res["fps"].values()),
-                name=f"[{idx}] {eval_result.
+                name=f"[{idx}] {eval_result.name} (fps)",
                 line=dict(color=eval_result.color),
                 hovertemplate="Batch Size: %{x}<br>FPS: %{y:.2f}<extra></extra>",  # <br> Standard deviation: %{customdata:.2f}<extra></extra>",
             ),
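The new `multiple_batche_sizes()` helper (the misspelling is in the source) only inspects the length of the `"speedtest"` list inside `speedtest_info`, so the batch-size chart is shown only when at least one model was benchmarked at more than one batch size. A sketch of the payload shape this implies; the fields inside each entry are assumptions, only the one-entry-per-batch-size structure follows from the code:

# Hypothetical speedtest_info payload; only the "speedtest" list length matters here.
speedtest_info = {
    "speedtest": [
        {"batch_size": 1},
        {"batch_size": 8},
    ]
}
has_multiple = len(speedtest_info["speedtest"]) > 1  # True -> render the batch chart
print(has_multiple)
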
supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py
CHANGED

@@ -1,7 +1,7 @@
-import
-from pathlib import Path
+from typing import List
 
-import supervisely.nn.benchmark.comparison.detection_visualization.text_templates as
+import supervisely.nn.benchmark.comparison.detection_visualization.text_templates as texts
+from supervisely.nn.benchmark.comparison.base_visualizer import BaseComparisonVisualizer
 from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics import (
     AveragePrecisionByClass,
     CalibrationScore,
@@ -13,7 +13,9 @@ from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics imp
     PrecisionRecallF1,
     Speedtest,
 )
-from supervisely.nn.benchmark.
+from supervisely.nn.benchmark.object_detection.evaluator import (
+    ObjectDetectionEvalResult,
+)
 from supervisely.nn.benchmark.visualization.widgets import (
     ContainerWidget,
     GalleryWidget,
@@ -22,22 +24,13 @@ from supervisely.nn.benchmark.visualization.widgets import (
 )
 
 
-class DetectionComparisonVisualizer:
-
-
-        self.api = comparison.api
-        self.vis_texts = vis_texts
-
-        self._create_widgets()
-        layout = self._create_layout()
-
-        self.renderer = Renderer(layout, str(Path(self.comparison.workdir, "visualizations")))
-
-    def visualize(self):
-        return self.renderer.visualize()
+class DetectionComparisonVisualizer(BaseComparisonVisualizer):
+    vis_texts = texts
+    ann_opacity = 0.5
 
-    def
-
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.eval_results: List[ObjectDetectionEvalResult]
 
     def _create_widgets(self):
         # Modal Gellery
@@ -48,10 +41,11 @@ class DetectionComparisonVisualizer:
         self.clickable_label = self._create_clickable_label()
 
         # Speedtest init here for overview
-        speedtest = Speedtest(self.vis_texts, self.comparison.
+        speedtest = Speedtest(self.vis_texts, self.comparison.eval_results)
 
         # Overview
-        overview = Overview(self.vis_texts, self.comparison.
+        overview = Overview(self.vis_texts, self.comparison.eval_results)
+        overview.team_id = self.comparison.team_id
         self.header = self._create_header()
         self.overviews = self._create_overviews(overview)
         self.overview_md = overview.overview_md
@@ -61,11 +55,11 @@ class DetectionComparisonVisualizer:
         )
         self.overview_chart = overview.chart_widget
 
-        columns_number = len(self.comparison.
+        columns_number = len(self.comparison.eval_results) + 1  # +1 for GT
         self.explore_predictions_modal_gallery = self._create_explore_modal_table(columns_number)
         explore_predictions = ExplorePredictions(
             self.vis_texts,
-            self.comparison.
+            self.comparison.eval_results,
             explore_modal_table=self.explore_predictions_modal_gallery,
         )
         self.explore_predictions_md = explore_predictions.difference_predictions_md
@@ -74,7 +68,7 @@ class DetectionComparisonVisualizer:
         # Outcome Counts
         outcome_counts = OutcomeCounts(
             self.vis_texts,
-            self.comparison.
+            self.comparison.eval_results,
             explore_modal_table=self.explore_modal_table,
         )
         self.outcome_counts_md = self._create_outcome_counts_md()
@@ -83,7 +77,7 @@ class DetectionComparisonVisualizer:
         self.outcome_counts_comparison = outcome_counts.chart_widget_comparison
 
         # Precision-Recall Curve
-        pr_curve = PrCurve(self.vis_texts, self.comparison.
+        pr_curve = PrCurve(self.vis_texts, self.comparison.eval_results)
         self.pr_curve_md = pr_curve.markdown_widget
         self.pr_curve_collapsed_widgets = pr_curve.collapsed_widget
         self.pr_curve_table = pr_curve.table_widget
@@ -92,7 +86,7 @@ class DetectionComparisonVisualizer:
         # Average Precision by Class
         avg_prec_by_class = AveragePrecisionByClass(
             self.vis_texts,
-            self.comparison.
+            self.comparison.eval_results,
             explore_modal_table=self.explore_modal_table,
         )
         self.avg_prec_by_class_md = avg_prec_by_class.markdown_widget
@@ -101,7 +95,7 @@ class DetectionComparisonVisualizer:
         # Precision, Recall, F1
         precision_recall_f1 = PrecisionRecallF1(
             self.vis_texts,
-            self.comparison.
+            self.comparison.eval_results,
             explore_modal_table=self.explore_modal_table,
         )
         self.precision_recall_f1_md = precision_recall_f1.markdown_widget
@@ -118,14 +112,14 @@ class DetectionComparisonVisualizer:
         # TODO: ???
 
         # Localization Accuracy (IoU)
-        loc_acc = LocalizationAccuracyIoU(self.vis_texts, self.comparison.
+        loc_acc = LocalizationAccuracyIoU(self.vis_texts, self.comparison.eval_results)
         self.loc_acc_header_md = loc_acc.header_md
         self.loc_acc_iou_distribution_md = loc_acc.iou_distribution_md
         self.loc_acc_chart = loc_acc.chart
         self.loc_acc_table = loc_acc.table_widget
 
         # Calibration Score
-        cal_score = CalibrationScore(self.vis_texts, self.comparison.
+        cal_score = CalibrationScore(self.vis_texts, self.comparison.eval_results)
         self.cal_score_md = cal_score.header_md
         self.cal_score_md_2 = cal_score.header_md_2
         self.cal_score_collapse_tip = cal_score.collapse_tip
@@ -140,6 +134,7 @@ class DetectionComparisonVisualizer:
 
         # SpeedTest
         self.speedtest_present = False
+        self.speedtest_multiple_batch_sizes = False
         if not speedtest.is_empty():
             self.speedtest_present = True
             self.speedtest_md_intro = speedtest.md_intro
@@ -148,8 +143,10 @@ class DetectionComparisonVisualizer:
             self.speed_inference_time_table = speedtest.inference_time_table
             self.speed_fps_md = speedtest.fps_md
             self.speed_fps_table = speedtest.fps_table
-            self.
-            self.
+            self.speedtest_multiple_batch_sizes = speedtest.multiple_batche_sizes()
+            if self.speedtest_multiple_batch_sizes:
+                self.speed_batch_inference_md = speedtest.batch_inference_md
+                self.speed_chart = speedtest.chart
 
     def _create_layout(self):
         is_anchors_widgets = [
@@ -216,10 +213,11 @@ class DetectionComparisonVisualizer:
                 (0, self.speed_inference_time_table),
                 (0, self.speed_fps_md),
                 (0, self.speed_fps_table),
-                (0, self.speed_batch_inference_md),
-                (0, self.speed_chart),
             ]
         )
+        if self.speedtest_multiple_batch_sizes:
+            is_anchors_widgets.append((0, self.speed_batch_inference_md))
+            is_anchors_widgets.append((0, self.speed_chart))
         anchors = []
         for is_anchor, widget in is_anchors_widgets:
             if is_anchor:
@@ -232,30 +230,6 @@ class DetectionComparisonVisualizer:
         )
         return layout
 
-    def _create_header(self) -> MarkdownWidget:
-        me = self.api.user.get_my_info().login
-        current_date = datetime.datetime.now().strftime("%d %B %Y, %H:%M")
-        header_main_text = " ∣ ".join(  # vs. or | or ∣
-            eval_res.name for eval_res in self.comparison.evaluation_results
-        )
-        header_text = self.vis_texts.markdown_header.format(header_main_text, me, current_date)
-        header = MarkdownWidget("markdown_header", "Header", text=header_text)
-        return header
-
-    def _create_overviews(self, vm: Overview) -> ContainerWidget:
-        grid_cols = 2
-        if len(vm.overview_widgets) > 2:
-            grid_cols = 3
-        if len(vm.overview_widgets) % 4 == 0:
-            grid_cols = 4
-        return ContainerWidget(
-            vm.overview_widgets,
-            name="overview_container",
-            title="Overview",
-            grid=True,
-            grid_cols=grid_cols,
-        )
-
     def _create_key_metrics(self) -> MarkdownWidget:
         key_metrics_text = self.vis_texts.markdown_key_metrics.format(
             self.vis_texts.definitions.average_precision,
@@ -277,22 +251,3 @@ class DetectionComparisonVisualizer:
         return MarkdownWidget(
             "markdown_outcome_counts_diff", "Outcome Counts Differences", text=outcome_counts_text
         )
-
-    def _create_explore_modal_table(self, columns_number=3):
-        # TODO: table for each evaluation?
-        all_predictions_modal_gallery = GalleryWidget(
-            "all_predictions_modal_gallery", is_modal=True, columns_number=columns_number
-        )
-        all_predictions_modal_gallery.set_project_meta(
-            self.comparison.evaluation_results[0].dt_project_meta
-        )
-        return all_predictions_modal_gallery
-
-    def _create_diff_modal_table(self, columns_number=3) -> GalleryWidget:
-        diff_modal_gallery = GalleryWidget(
-            "diff_predictions_modal_gallery", is_modal=True, columns_number=columns_number
-        )
-        return diff_modal_gallery
-
-    def _create_clickable_label(self):
-        return MarkdownWidget("clickable_label", "", text=self.vis_texts.clickable_label)