supervisely 6.73.253__py3-none-any.whl → 6.73.255__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (60)
  1. supervisely/api/file_api.py +16 -5
  2. supervisely/api/task_api.py +4 -2
  3. supervisely/app/widgets/field/field.py +10 -7
  4. supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
  5. supervisely/convert/image/sly/sly_image_converter.py +1 -1
  6. supervisely/nn/benchmark/base_benchmark.py +33 -35
  7. supervisely/nn/benchmark/base_evaluator.py +27 -1
  8. supervisely/nn/benchmark/base_visualizer.py +8 -11
  9. supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
  10. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
  11. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
  12. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
  13. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
  14. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
  15. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
  16. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
  17. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
  18. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
  19. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
  20. supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
  21. supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
  22. supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
  23. supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
  24. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
  25. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
  26. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
  27. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
  28. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
  29. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
  30. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
  31. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
  32. supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
  33. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
  34. supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
  35. supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
  36. supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
  37. supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
  38. supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
  39. supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
  40. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
  42. supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
  43. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
  44. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
  45. supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
  46. supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
  47. supervisely/nn/benchmark/visualization/renderer.py +25 -10
  48. supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
  49. supervisely/nn/inference/inference.py +1 -0
  50. supervisely/nn/training/gui/gui.py +32 -10
  51. supervisely/nn/training/gui/training_artifacts.py +145 -0
  52. supervisely/nn/training/gui/training_process.py +3 -19
  53. supervisely/nn/training/train_app.py +179 -70
  54. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/METADATA +1 -1
  55. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/RECORD +59 -47
  56. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
  57. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/LICENSE +0 -0
  58. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/WHEEL +0 -0
  59. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/entry_points.txt +0 -0
  60. {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,6 @@
1
1
  import numpy as np
2
2
 
3
- from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
4
- BaseVisMetric,
5
- )
3
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
6
4
  from supervisely.nn.benchmark.cv_tasks import CVTask
7
5
  from supervisely.nn.benchmark.visualization.widgets import (
8
6
  ChartWidget,
@@ -12,7 +10,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
12
10
  )
13
11
 
14
12
 
15
- class LocalizationAccuracyIoU(BaseVisMetric):
13
+ class LocalizationAccuracyIoU(BaseVisMetrics):
16
14
  @property
17
15
  def header_md(self) -> MarkdownWidget:
18
16
  title = "Localization Accuracy (IoU)"
@@ -90,7 +88,7 @@ class LocalizationAccuracyIoU(BaseVisMetric):
90
88
  bin_width = min([bin_edges[1] - bin_edges[0] for _, bin_edges in hist_data])
91
89
 
92
90
  for i, (eval_result, (hist, bin_edges)) in enumerate(zip(self.eval_results, hist_data)):
93
- name = f"[{i+1}] {eval_result.model_name}"
91
+ name = f"[{i+1}] {eval_result.name}"
94
92
  kde = gaussian_kde(eval_result.mp.ious)
95
93
  density = kde(x_range)
96
94
 
@@ -3,14 +3,12 @@ from typing import List
3
3
 
4
4
  import numpy as np
5
5
 
6
- from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
7
- BaseVisMetric,
8
- )
6
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
9
7
  from supervisely.nn.benchmark.visualization.widgets import ChartWidget
10
8
  from supervisely.nn.task_type import TaskType
11
9
 
12
10
 
13
- class OutcomeCounts(BaseVisMetric):
11
+ class OutcomeCounts(BaseVisMetrics):
14
12
  CHART_MAIN = "chart_outcome_counts"
15
13
  CHART_COMPARISON = "chart_outcome_counts_comparison"
16
14
 
@@ -97,7 +95,7 @@ class OutcomeCounts(BaseVisMetric):
97
95
  tp_counts = [eval_result.mp.TP_count for eval_result in self.eval_results][::-1]
98
96
  fn_counts = [eval_result.mp.FN_count for eval_result in self.eval_results][::-1]
99
97
  fp_counts = [eval_result.mp.FP_count for eval_result in self.eval_results][::-1]
100
- model_names = [f"[{i}] {e.model_name}" for i, e in enumerate(self.eval_results, 1)][::-1]
98
+ model_names = [f"[{i}] {e.short_name}" for i, e in enumerate(self.eval_results, 1)][::-1]
101
99
  counts = [tp_counts, fn_counts, fp_counts]
102
100
  names = ["TP", "FN", "FP"]
103
101
  colors = ["#8ACAA1", "#dd3f3f", "#F7ADAA"]
@@ -123,7 +121,7 @@ class OutcomeCounts(BaseVisMetric):
123
121
  fig = go.Figure()
124
122
 
125
123
  colors = ["#8ACAA1", "#dd3f3f", "#F7ADAA"]
126
- model_names = [f"[{i}] {e.model_name}" for i, e in enumerate(self.eval_results, 1)][::-1]
124
+ model_names = [f"[{i}] {e.short_name}" for i, e in enumerate(self.eval_results, 1)][::-1]
127
125
  model_names.append("Common")
128
126
 
129
127
  diff_tps, common_tps = self.common_and_diff_tp
@@ -263,7 +261,7 @@ class OutcomeCounts(BaseVisMetric):
263
261
  res["layoutTemplate"] = [None, None, None]
264
262
  res["clickData"] = {}
265
263
  for i, eval_result in enumerate(self.eval_results, 1):
266
- model_name = f"[{i}] {eval_result.model_name}"
264
+ model_name = f"[{i}] {eval_result.name}"
267
265
  for outcome, matches_data in eval_result.click_data.outcome_counts.items():
268
266
  key = f"{model_name}_{outcome}"
269
267
  outcome_dict = res["clickData"].setdefault(key, {})
@@ -278,7 +276,7 @@ class OutcomeCounts(BaseVisMetric):
278
276
  title = f"{model_name}. {outcome}: {len(obj_ids)} object{'s' if len(obj_ids) > 1 else ''}"
279
277
  outcome_dict["title"] = title
280
278
  outcome_dict["imagesIds"] = list(img_ids)
281
- thr = eval_result.f1_optimal_conf
279
+ thr = eval_result.mp.f1_optimal_conf
282
280
  if outcome == "FN":
283
281
  outcome_dict["filters"] = [
284
282
  {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
@@ -327,7 +325,7 @@ class OutcomeCounts(BaseVisMetric):
327
325
  _update_outcome_dict("Common", outcome, outcome_dict, common_ids)
328
326
 
329
327
  for i, diff_ids in enumerate(diff_ids, 1):
330
- name = f"[{i}] {self.eval_results[i - 1].model_name}"
328
+ name = f"[{i}] {self.eval_results[i - 1].name}"
331
329
  key = f"{name}_{outcome}"
332
330
  outcome_dict = res["clickData"].setdefault(key, {})
333
331
 
@@ -1,9 +1,7 @@
1
1
  from typing import List
2
2
 
3
3
  from supervisely._utils import abs_url
4
- from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
5
- BaseVisMetric,
6
- )
4
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
7
5
  from supervisely.nn.benchmark.visualization.evaluation_result import EvalResult
8
6
  from supervisely.nn.benchmark.visualization.widgets import (
9
7
  ChartWidget,
@@ -12,7 +10,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
12
10
  )
13
11
 
14
12
 
15
- class Overview(BaseVisMetric):
13
+ class Overview(BaseVisMetrics):
16
14
 
17
15
  MARKDOWN_OVERVIEW = "markdown_overview"
18
16
  MARKDOWN_OVERVIEW_INFO = "markdown_overview_info"
@@ -62,7 +60,7 @@ class Overview(BaseVisMetric):
62
60
 
63
61
  @property
64
62
  def overview_widgets(self) -> List[MarkdownWidget]:
65
- self.formats = []
63
+ all_formats = []
66
64
  for eval_result in self.eval_results:
67
65
 
68
66
  url = eval_result.inference_info.get("checkpoint_url")
@@ -71,14 +69,10 @@ class Overview(BaseVisMetric):
71
69
  link_text = url
72
70
  link_text = link_text.replace("_", "\_")
73
71
 
74
- checkpoint_name = eval_result.inference_info.get("deploy_params", {}).get(
75
- "checkpoint_name", ""
76
- )
77
- model_name = eval_result.inference_info.get("model_name") or "Custom"
72
+ checkpoint_name = eval_result.checkpoint_name
73
+ model_name = eval_result.name or "Custom"
78
74
 
79
- report = eval_result.api.file.get_info_by_path(
80
- eval_result.team_id, eval_result.report_path
81
- )
75
+ report = eval_result.api.file.get_info_by_path(self.team_id, eval_result.report_path)
82
76
  report_link = abs_url(f"/model-benchmark?id={report.id}")
83
77
 
84
78
  formats = [
@@ -91,11 +85,11 @@ class Overview(BaseVisMetric):
91
85
  link_text,
92
86
  report_link,
93
87
  ]
94
- self.formats.append(formats)
88
+ all_formats.append(formats)
95
89
 
96
90
  text_template: str = getattr(self.vis_texts, self.MARKDOWN_OVERVIEW_INFO)
97
91
  widgets = []
98
- for formats in self.formats:
92
+ for formats in all_formats:
99
93
  md = MarkdownWidget(
100
94
  name=self.MARKDOWN_OVERVIEW_INFO,
101
95
  title="Overview",
@@ -204,7 +198,7 @@ class Overview(BaseVisMetric):
204
198
  # Overall Metrics
205
199
  fig = go.Figure()
206
200
  for i, eval_result in enumerate(self.eval_results):
207
- name = f"[{i + 1}] {eval_result.model_name}"
201
+ name = f"[{i + 1}] {eval_result.name}"
208
202
  base_metrics = eval_result.mp.base_metrics()
209
203
  r = list(base_metrics.values())
210
204
  theta = [eval_result.mp.metric_names[k] for k in base_metrics.keys()]
@@ -227,13 +221,8 @@ class Overview(BaseVisMetric):
227
221
  angularaxis=dict(rotation=90, direction="clockwise"),
228
222
  ),
229
223
  dragmode=False,
230
- # title="Overall Metrics",
231
- # width=700,
232
- # height=500,
233
- # autosize=False,
224
+ height=500,
234
225
  margin=dict(l=25, r=25, t=25, b=25),
235
- )
236
- fig.update_layout(
237
226
  modebar=dict(
238
227
  remove=[
239
228
  "zoom2d",
@@ -245,6 +234,6 @@ class Overview(BaseVisMetric):
245
234
  "autoScale2d",
246
235
  "resetScale2d",
247
236
  ]
248
- )
237
+ ),
249
238
  )
250
239
  return fig
@@ -1,9 +1,7 @@
1
1
  import numpy as np
2
2
 
3
3
  from supervisely.imaging.color import hex2rgb
4
- from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
5
- BaseVisMetric,
6
- )
4
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
7
5
  from supervisely.nn.benchmark.visualization.widgets import (
8
6
  ChartWidget,
9
7
  CollapseWidget,
@@ -13,7 +11,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
13
11
  )
14
12
 
15
13
 
16
- class PrCurve(BaseVisMetric):
14
+ class PrCurve(BaseVisMetrics):
17
15
  MARKDOWN_PR_CURVE = "markdown_pr_curve"
18
16
  MARKDOWN_PR_TRADE_OFFS = "markdown_trade_offs"
19
17
  MARKDOWN_WHAT_IS_PR_CURVE = "markdown_what_is_pr_curve"
@@ -98,7 +96,7 @@ class PrCurve(BaseVisMetric):
98
96
  pr_curve[pr_curve == -1] = np.nan
99
97
  pr_curve = np.nanmean(pr_curve, axis=-1)
100
98
 
101
- name = f"[{i}] {eval_result.model_name}"
99
+ name = f"[{i}] {eval_result.name}"
102
100
  color = ",".join(map(str, hex2rgb(eval_result.color))) + ",0.1"
103
101
  line = go.Scatter(
104
102
  x=eval_result.mp.recThrs,
@@ -1,6 +1,4 @@
1
- from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
2
- BaseVisMetric,
3
- )
1
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
4
2
  from supervisely.nn.benchmark.visualization.widgets import (
5
3
  ChartWidget,
6
4
  CollapseWidget,
@@ -10,7 +8,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
10
8
  )
11
9
 
12
10
 
13
- class PrecisionRecallF1(BaseVisMetric):
11
+ class PrecisionRecallF1(BaseVisMetrics):
14
12
  MARKDOWN = "markdown_PRF1"
15
13
  MARKDOWN_PRECISION_TITLE = "markdown_precision_per_class_title"
16
14
  MARKDOWN_RECALL_TITLE = "markdown_recall_per_class_title"
@@ -136,14 +134,14 @@ class PrecisionRecallF1(BaseVisMetric):
136
134
  precision = eval_result.mp.json_metrics()["precision"]
137
135
  recall = eval_result.mp.json_metrics()["recall"]
138
136
  f1 = eval_result.mp.json_metrics()["f1"]
139
- model_name = f"[{i}] {eval_result.model_name}"
137
+ model_name = f"[{i}] {eval_result.name}"
140
138
  fig.add_trace(
141
139
  go.Bar(
142
140
  x=["Precision", "Recall", "F1-score"],
143
141
  y=[precision, recall, f1],
144
142
  name=model_name,
145
143
  width=0.2 if classes_cnt >= 5 else None,
146
- marker=dict(color=eval_result.color),
144
+ marker=dict(color=eval_result.color, line=dict(width=0.7)),
147
145
  )
148
146
  )
149
147
 
@@ -152,7 +150,7 @@ class PrecisionRecallF1(BaseVisMetric):
152
150
  xaxis_title="Metric",
153
151
  yaxis_title="Value",
154
152
  yaxis=dict(range=[0, 1.1]),
155
- width=700 if classes_cnt < 5 else None,
153
+ width=700,
156
154
  )
157
155
 
158
156
  return fig
@@ -163,7 +161,7 @@ class PrecisionRecallF1(BaseVisMetric):
163
161
  fig = go.Figure()
164
162
  classes_cnt = len(self.eval_results[0].mp.cat_names)
165
163
  for i, eval_result in enumerate(self.eval_results, 1):
166
- model_name = f"[{i}] {eval_result.model_name}"
164
+ model_name = f"[{i}] {eval_result.name}"
167
165
  sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
168
166
 
169
167
  fig.add_trace(
@@ -171,8 +169,8 @@ class PrecisionRecallF1(BaseVisMetric):
171
169
  y=sorted_by_f1["recall"],
172
170
  x=sorted_by_f1["category"],
173
171
  name=f"{model_name} Recall",
174
- width=0.2 if classes_cnt < 5 else None,
175
- marker=dict(color=eval_result.color),
172
+ width=0.2 if classes_cnt >= 5 else None,
173
+ marker=dict(color=eval_result.color, line=dict(width=0.7)),
176
174
  )
177
175
  )
178
176
 
@@ -191,7 +189,7 @@ class PrecisionRecallF1(BaseVisMetric):
191
189
  res["layoutTemplate"] = [None, None, None]
192
190
  res["clickData"] = {}
193
191
  for i, eval_result in enumerate(self.eval_results):
194
- model_name = f"Model [{i + 1}] {eval_result.model_name}"
192
+ model_name = f"Model [{i + 1}] {eval_result.name}"
195
193
  for key, v in eval_result.click_data.objects_by_class.items():
196
194
  click_data = res["clickData"].setdefault(f"{i}_{key}", {})
197
195
  img_ids, obj_ids = set(), set()
@@ -207,7 +205,7 @@ class PrecisionRecallF1(BaseVisMetric):
207
205
  {
208
206
  "type": "tag",
209
207
  "tagId": "confidence",
210
- "value": [eval_result.f1_optimal_conf, 1],
208
+ "value": [eval_result.mp.f1_optimal_conf, 1],
211
209
  },
212
210
  {"type": "tag", "tagId": "outcome", "value": "TP"},
213
211
  {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
@@ -220,7 +218,7 @@ class PrecisionRecallF1(BaseVisMetric):
220
218
  fig = go.Figure()
221
219
  classes_cnt = len(self.eval_results[0].mp.cat_names)
222
220
  for i, eval_result in enumerate(self.eval_results, 1):
223
- model_name = f"[{i}] {eval_result.model_name}"
221
+ model_name = f"[{i}] {eval_result.name}"
224
222
  sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
225
223
 
226
224
  fig.add_trace(
@@ -228,8 +226,8 @@ class PrecisionRecallF1(BaseVisMetric):
228
226
  y=sorted_by_f1["precision"],
229
227
  x=sorted_by_f1["category"],
230
228
  name=f"{model_name} Precision",
231
- width=0.2 if classes_cnt < 5 else None,
232
- marker=dict(color=eval_result.color),
229
+ width=0.2 if classes_cnt >= 5 else None,
230
+ marker=dict(color=eval_result.color, line=dict(width=0.7)),
233
231
  )
234
232
  )
235
233
 
@@ -249,7 +247,7 @@ class PrecisionRecallF1(BaseVisMetric):
249
247
  fig = go.Figure()
250
248
  classes_cnt = len(self.eval_results[0].mp.cat_names)
251
249
  for i, eval_result in enumerate(self.eval_results, 1):
252
- model_name = f"[{i}] {eval_result.model_name}"
250
+ model_name = f"[{i}] {eval_result.name}"
253
251
  sorted_by_f1 = eval_result.mp.per_class_metrics().sort_values(by="f1")
254
252
 
255
253
  fig.add_trace(
@@ -257,8 +255,8 @@ class PrecisionRecallF1(BaseVisMetric):
257
255
  y=sorted_by_f1["f1"],
258
256
  x=sorted_by_f1["category"],
259
257
  name=f"{model_name} F1-score",
260
- width=0.2 if classes_cnt < 5 else None,
261
- marker=dict(color=eval_result.color),
258
+ width=0.2 if classes_cnt >= 5 else None,
259
+ marker=dict(color=eval_result.color, line=dict(width=0.7)),
262
260
  )
263
261
  )
264
262
 
@@ -278,7 +276,7 @@ class PrecisionRecallF1(BaseVisMetric):
278
276
  res["clickData"] = {}
279
277
 
280
278
  for i, eval_result in enumerate(self.eval_results):
281
- model_name = f"Model [{i + 1}] {eval_result.model_name}"
279
+ model_name = f"Model [{i + 1}] {eval_result.name}"
282
280
  click_data = res["clickData"].setdefault(i, {})
283
281
  img_ids, obj_ids = set(), set()
284
282
  objects_cnt = 0
@@ -292,7 +290,11 @@ class PrecisionRecallF1(BaseVisMetric):
292
290
  click_data["title"] = f"{model_name}, {objects_cnt} objects"
293
291
  click_data["imagesIds"] = list(img_ids)
294
292
  click_data["filters"] = [
295
- {"type": "tag", "tagId": "confidence", "value": [eval_result.f1_optimal_conf, 1]},
293
+ {
294
+ "type": "tag",
295
+ "tagId": "confidence",
296
+ "value": [eval_result.mp.f1_optimal_conf, 1],
297
+ },
296
298
  {"type": "tag", "tagId": "outcome", "value": "TP"},
297
299
  {"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
298
300
  ]
@@ -1,9 +1,7 @@
1
1
  from typing import List, Union
2
2
 
3
3
  from supervisely.imaging.color import hex2rgb
4
- from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.vis_metric import (
5
- BaseVisMetric,
6
- )
4
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
7
5
  from supervisely.nn.benchmark.visualization.widgets import (
8
6
  ChartWidget,
9
7
  MarkdownWidget,
@@ -11,11 +9,19 @@ from supervisely.nn.benchmark.visualization.widgets import (
11
9
  )
12
10
 
13
11
 
14
- class Speedtest(BaseVisMetric):
12
+ class Speedtest(BaseVisMetrics):
15
13
 
16
14
  def is_empty(self) -> bool:
17
15
  return not any(eval_result.speedtest_info for eval_result in self.eval_results)
18
16
 
17
+ def multiple_batche_sizes(self) -> bool:
18
+ for eval_result in self.eval_results:
19
+ if eval_result.speedtest_info is None:
20
+ continue
21
+ if len(eval_result.speedtest_info["speedtest"]) > 1:
22
+ return True
23
+ return False
24
+
19
25
  @property
20
26
  def latency(self) -> List[Union[int, str]]:
21
27
  latency = []
@@ -272,7 +278,7 @@ class Speedtest(BaseVisMetric):
272
278
  go.Scatter(
273
279
  x=list(temp_res["ms"].keys()),
274
280
  y=list(temp_res["ms"].values()),
275
- name=f"[{idx}] {eval_result.model_name} (ms)",
281
+ name=f"[{idx}] {eval_result.name} (ms)",
276
282
  line=dict(color=eval_result.color),
277
283
  customdata=list(temp_res["ms_std"].values()),
278
284
  error_y=dict(
@@ -290,7 +296,7 @@ class Speedtest(BaseVisMetric):
290
296
  go.Scatter(
291
297
  x=list(temp_res["fps"].keys()),
292
298
  y=list(temp_res["fps"].values()),
293
- name=f"[{idx}] {eval_result.model_name} (fps)",
299
+ name=f"[{idx}] {eval_result.name} (fps)",
294
300
  line=dict(color=eval_result.color),
295
301
  hovertemplate="Batch Size: %{x}<br>FPS: %{y:.2f}<extra></extra>", # <br> Standard deviation: %{customdata:.2f}<extra></extra>",
296
302
  ),
@@ -1,7 +1,7 @@
1
- import datetime
2
- from pathlib import Path
1
+ from typing import List
3
2
 
4
- import supervisely.nn.benchmark.comparison.detection_visualization.text_templates as vis_texts
3
+ import supervisely.nn.benchmark.comparison.detection_visualization.text_templates as texts
4
+ from supervisely.nn.benchmark.comparison.base_visualizer import BaseComparisonVisualizer
5
5
  from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics import (
6
6
  AveragePrecisionByClass,
7
7
  CalibrationScore,
@@ -13,7 +13,9 @@ from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics imp
13
13
  PrecisionRecallF1,
14
14
  Speedtest,
15
15
  )
16
- from supervisely.nn.benchmark.visualization.renderer import Renderer
16
+ from supervisely.nn.benchmark.object_detection.evaluator import (
17
+ ObjectDetectionEvalResult,
18
+ )
17
19
  from supervisely.nn.benchmark.visualization.widgets import (
18
20
  ContainerWidget,
19
21
  GalleryWidget,
@@ -22,22 +24,13 @@ from supervisely.nn.benchmark.visualization.widgets import (
22
24
  )
23
25
 
24
26
 
25
- class DetectionComparisonVisualizer:
26
- def __init__(self, comparison):
27
- self.comparison = comparison
28
- self.api = comparison.api
29
- self.vis_texts = vis_texts
30
-
31
- self._create_widgets()
32
- layout = self._create_layout()
33
-
34
- self.renderer = Renderer(layout, str(Path(self.comparison.workdir, "visualizations")))
35
-
36
- def visualize(self):
37
- return self.renderer.visualize()
27
+ class DetectionComparisonVisualizer(BaseComparisonVisualizer):
28
+ vis_texts = texts
29
+ ann_opacity = 0.5
38
30
 
39
- def upload_results(self, team_id: int, remote_dir: str, progress=None):
40
- return self.renderer.upload_results(self.api, team_id, remote_dir, progress)
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+ self.eval_results: List[ObjectDetectionEvalResult]
41
34
 
42
35
  def _create_widgets(self):
43
36
  # Modal Gellery
@@ -48,10 +41,11 @@ class DetectionComparisonVisualizer:
48
41
  self.clickable_label = self._create_clickable_label()
49
42
 
50
43
  # Speedtest init here for overview
51
- speedtest = Speedtest(self.vis_texts, self.comparison.evaluation_results)
44
+ speedtest = Speedtest(self.vis_texts, self.comparison.eval_results)
52
45
 
53
46
  # Overview
54
- overview = Overview(self.vis_texts, self.comparison.evaluation_results)
47
+ overview = Overview(self.vis_texts, self.comparison.eval_results)
48
+ overview.team_id = self.comparison.team_id
55
49
  self.header = self._create_header()
56
50
  self.overviews = self._create_overviews(overview)
57
51
  self.overview_md = overview.overview_md
@@ -61,11 +55,11 @@ class DetectionComparisonVisualizer:
61
55
  )
62
56
  self.overview_chart = overview.chart_widget
63
57
 
64
- columns_number = len(self.comparison.evaluation_results) + 1 # +1 for GT
58
+ columns_number = len(self.comparison.eval_results) + 1 # +1 for GT
65
59
  self.explore_predictions_modal_gallery = self._create_explore_modal_table(columns_number)
66
60
  explore_predictions = ExplorePredictions(
67
61
  self.vis_texts,
68
- self.comparison.evaluation_results,
62
+ self.comparison.eval_results,
69
63
  explore_modal_table=self.explore_predictions_modal_gallery,
70
64
  )
71
65
  self.explore_predictions_md = explore_predictions.difference_predictions_md
@@ -74,7 +68,7 @@ class DetectionComparisonVisualizer:
74
68
  # Outcome Counts
75
69
  outcome_counts = OutcomeCounts(
76
70
  self.vis_texts,
77
- self.comparison.evaluation_results,
71
+ self.comparison.eval_results,
78
72
  explore_modal_table=self.explore_modal_table,
79
73
  )
80
74
  self.outcome_counts_md = self._create_outcome_counts_md()
@@ -83,7 +77,7 @@ class DetectionComparisonVisualizer:
83
77
  self.outcome_counts_comparison = outcome_counts.chart_widget_comparison
84
78
 
85
79
  # Precision-Recall Curve
86
- pr_curve = PrCurve(self.vis_texts, self.comparison.evaluation_results)
80
+ pr_curve = PrCurve(self.vis_texts, self.comparison.eval_results)
87
81
  self.pr_curve_md = pr_curve.markdown_widget
88
82
  self.pr_curve_collapsed_widgets = pr_curve.collapsed_widget
89
83
  self.pr_curve_table = pr_curve.table_widget
@@ -92,7 +86,7 @@ class DetectionComparisonVisualizer:
92
86
  # Average Precision by Class
93
87
  avg_prec_by_class = AveragePrecisionByClass(
94
88
  self.vis_texts,
95
- self.comparison.evaluation_results,
89
+ self.comparison.eval_results,
96
90
  explore_modal_table=self.explore_modal_table,
97
91
  )
98
92
  self.avg_prec_by_class_md = avg_prec_by_class.markdown_widget
@@ -101,7 +95,7 @@ class DetectionComparisonVisualizer:
101
95
  # Precision, Recall, F1
102
96
  precision_recall_f1 = PrecisionRecallF1(
103
97
  self.vis_texts,
104
- self.comparison.evaluation_results,
98
+ self.comparison.eval_results,
105
99
  explore_modal_table=self.explore_modal_table,
106
100
  )
107
101
  self.precision_recall_f1_md = precision_recall_f1.markdown_widget
@@ -118,14 +112,14 @@ class DetectionComparisonVisualizer:
118
112
  # TODO: ???
119
113
 
120
114
  # Localization Accuracy (IoU)
121
- loc_acc = LocalizationAccuracyIoU(self.vis_texts, self.comparison.evaluation_results)
115
+ loc_acc = LocalizationAccuracyIoU(self.vis_texts, self.comparison.eval_results)
122
116
  self.loc_acc_header_md = loc_acc.header_md
123
117
  self.loc_acc_iou_distribution_md = loc_acc.iou_distribution_md
124
118
  self.loc_acc_chart = loc_acc.chart
125
119
  self.loc_acc_table = loc_acc.table_widget
126
120
 
127
121
  # Calibration Score
128
- cal_score = CalibrationScore(self.vis_texts, self.comparison.evaluation_results)
122
+ cal_score = CalibrationScore(self.vis_texts, self.comparison.eval_results)
129
123
  self.cal_score_md = cal_score.header_md
130
124
  self.cal_score_md_2 = cal_score.header_md_2
131
125
  self.cal_score_collapse_tip = cal_score.collapse_tip
@@ -140,6 +134,7 @@ class DetectionComparisonVisualizer:
140
134
 
141
135
  # SpeedTest
142
136
  self.speedtest_present = False
137
+ self.speedtest_multiple_batch_sizes = False
143
138
  if not speedtest.is_empty():
144
139
  self.speedtest_present = True
145
140
  self.speedtest_md_intro = speedtest.md_intro
@@ -148,8 +143,10 @@ class DetectionComparisonVisualizer:
148
143
  self.speed_inference_time_table = speedtest.inference_time_table
149
144
  self.speed_fps_md = speedtest.fps_md
150
145
  self.speed_fps_table = speedtest.fps_table
151
- self.speed_batch_inference_md = speedtest.batch_inference_md
152
- self.speed_chart = speedtest.chart
146
+ self.speedtest_multiple_batch_sizes = speedtest.multiple_batche_sizes()
147
+ if self.speedtest_multiple_batch_sizes:
148
+ self.speed_batch_inference_md = speedtest.batch_inference_md
149
+ self.speed_chart = speedtest.chart
153
150
 
154
151
  def _create_layout(self):
155
152
  is_anchors_widgets = [
@@ -216,10 +213,11 @@ class DetectionComparisonVisualizer:
216
213
  (0, self.speed_inference_time_table),
217
214
  (0, self.speed_fps_md),
218
215
  (0, self.speed_fps_table),
219
- (0, self.speed_batch_inference_md),
220
- (0, self.speed_chart),
221
216
  ]
222
217
  )
218
+ if self.speedtest_multiple_batch_sizes:
219
+ is_anchors_widgets.append((0, self.speed_batch_inference_md))
220
+ is_anchors_widgets.append((0, self.speed_chart))
223
221
  anchors = []
224
222
  for is_anchor, widget in is_anchors_widgets:
225
223
  if is_anchor:
@@ -232,30 +230,6 @@ class DetectionComparisonVisualizer:
232
230
  )
233
231
  return layout
234
232
 
235
- def _create_header(self) -> MarkdownWidget:
236
- me = self.api.user.get_my_info().login
237
- current_date = datetime.datetime.now().strftime("%d %B %Y, %H:%M")
238
- header_main_text = " ∣ ".join( # vs. or | or ∣
239
- eval_res.name for eval_res in self.comparison.evaluation_results
240
- )
241
- header_text = self.vis_texts.markdown_header.format(header_main_text, me, current_date)
242
- header = MarkdownWidget("markdown_header", "Header", text=header_text)
243
- return header
244
-
245
- def _create_overviews(self, vm: Overview) -> ContainerWidget:
246
- grid_cols = 2
247
- if len(vm.overview_widgets) > 2:
248
- grid_cols = 3
249
- if len(vm.overview_widgets) % 4 == 0:
250
- grid_cols = 4
251
- return ContainerWidget(
252
- vm.overview_widgets,
253
- name="overview_container",
254
- title="Overview",
255
- grid=True,
256
- grid_cols=grid_cols,
257
- )
258
-
259
233
  def _create_key_metrics(self) -> MarkdownWidget:
260
234
  key_metrics_text = self.vis_texts.markdown_key_metrics.format(
261
235
  self.vis_texts.definitions.average_precision,
@@ -277,22 +251,3 @@ class DetectionComparisonVisualizer:
277
251
  return MarkdownWidget(
278
252
  "markdown_outcome_counts_diff", "Outcome Counts Differences", text=outcome_counts_text
279
253
  )
280
-
281
- def _create_explore_modal_table(self, columns_number=3):
282
- # TODO: table for each evaluation?
283
- all_predictions_modal_gallery = GalleryWidget(
284
- "all_predictions_modal_gallery", is_modal=True, columns_number=columns_number
285
- )
286
- all_predictions_modal_gallery.set_project_meta(
287
- self.comparison.evaluation_results[0].dt_project_meta
288
- )
289
- return all_predictions_modal_gallery
290
-
291
- def _create_diff_modal_table(self, columns_number=3) -> GalleryWidget:
292
- diff_modal_gallery = GalleryWidget(
293
- "diff_predictions_modal_gallery", is_modal=True, columns_number=columns_number
294
- )
295
- return diff_modal_gallery
296
-
297
- def _create_clickable_label(self):
298
- return MarkdownWidget("clickable_label", "", text=self.vis_texts.clickable_label)