supervisely 6.73.254__py3-none-any.whl → 6.73.256__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (61) hide show
  1. supervisely/api/api.py +16 -8
  2. supervisely/api/file_api.py +16 -5
  3. supervisely/api/task_api.py +4 -2
  4. supervisely/app/widgets/field/field.py +10 -7
  5. supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
  6. supervisely/io/network_exceptions.py +14 -2
  7. supervisely/nn/benchmark/base_benchmark.py +33 -35
  8. supervisely/nn/benchmark/base_evaluator.py +27 -1
  9. supervisely/nn/benchmark/base_visualizer.py +8 -11
  10. supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
  11. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
  12. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
  13. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
  14. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
  15. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
  16. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
  17. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
  18. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
  19. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
  20. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
  21. supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
  22. supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
  23. supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
  24. supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
  25. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
  26. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
  27. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
  28. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
  29. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
  30. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
  31. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
  32. supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
  33. supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
  34. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
  35. supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
  36. supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
  37. supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
  38. supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
  39. supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
  40. supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
  41. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
  42. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
  43. supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
  44. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
  45. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
  46. supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
  47. supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
  48. supervisely/nn/benchmark/visualization/renderer.py +25 -10
  49. supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
  50. supervisely/nn/inference/inference.py +1 -0
  51. supervisely/nn/training/gui/gui.py +32 -10
  52. supervisely/nn/training/gui/training_artifacts.py +145 -0
  53. supervisely/nn/training/gui/training_process.py +3 -19
  54. supervisely/nn/training/train_app.py +179 -70
  55. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/METADATA +1 -1
  56. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/RECORD +60 -48
  57. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
  58. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/LICENSE +0 -0
  59. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/WHEEL +0 -0
  60. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/entry_points.txt +0 -0
  61. {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,141 @@
1
+ from typing import List, Tuple
2
+
3
+ from supervisely.annotation.annotation import Annotation
4
+ from supervisely.api.image_api import ImageInfo
5
+ from supervisely.api.module_api import ApiField
6
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
7
+ from supervisely.nn.benchmark.visualization.widgets import GalleryWidget, MarkdownWidget
8
+ from supervisely.project.project_meta import ProjectMeta
9
+
10
+
11
class ExplorePredictions(BaseVisMetrics):
    """Markdown and gallery widgets for exploring GT and predictions side by side.

    Each gallery row is built from one ground-truth image followed by the image
    with the same name taken from every compared model's prediction project.
    """

    MARKDOWN = "markdown_explorer"
    GALLERY_DIFFERENCE = "explore_difference_gallery"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache for the merged project meta; filled on the first _merged_meta() call.
        self.meta = None

    def _merged_meta(self) -> ProjectMeta:
        """Return the GT project meta merged with every prediction project meta.

        The merged meta is cached in ``self.meta``. It is only stored once all
        merges succeeded, so a failing merge cannot leave a partial result cached.
        """
        if self.meta is None:
            merged = self.eval_results[0].gt_project_meta
            for eval_res in self.eval_results:
                merged = merged.merge(eval_res.pred_project_meta)
            self.meta = merged
        return self.meta

    @property
    def difference_predictions_md(self) -> MarkdownWidget:
        """Markdown widget with the explanatory text of the section."""
        text = self.vis_texts.markdown_explorer
        return MarkdownWidget(self.MARKDOWN, "Explore Predictions", text)

    @property
    def explore_gallery(self) -> GalleryWidget:
        """Gallery with a few sample images: one GT column plus one column per model."""
        columns_number = len(self.eval_results) + 1
        images, annotations = self._get_sample_data()
        gallery = GalleryWidget(self.GALLERY_DIFFERENCE, columns_number=columns_number)
        gallery.add_image_left_header("Click to explore more")
        gallery.set_project_meta(self._merged_meta())
        gallery.set_images(images, annotations)
        click_data = self.get_click_data_explore_all()
        gallery.set_click_data(self.explore_modal_table.id, click_data)
        gallery.set_show_all_data(self.explore_modal_table.id, click_data)
        gallery._gallery._update_filters()

        return gallery

    def _get_sample_data(self) -> Tuple[List[ImageInfo], List[Annotation]]:
        """Collect sample images and annotations for the preview gallery.

        Up to 5 images are taken from the first GT dataset of the first eval
        result. For every eval result, the images with the same names are looked
        up in the prediction dataset that has the same name as the GT dataset.

        :return: ``(images, annotations)``; both lists are interleaved row by row
            as GT, model 1, model 2, ...
        """
        # Local import as in the original code; only needed for the annotation below.
        from supervisely.api.api import Api

        api: Api = self.eval_results[0].api

        gt_dataset = self.eval_results[0].gt_dataset_infos[0]
        gt_infos = api.image.get_list(gt_dataset.id, limit=5, force_metadata_for_links=False)
        names = [image_info.name for image_info in gt_infos]
        # Position of each GT image name, used to put prediction images in GT order.
        name_order = {}
        for position, name in enumerate(names):
            name_order.setdefault(name, position)

        images = [gt_infos]
        annotations = [
            api.annotation.download_batch(
                gt_dataset.id,
                [image_info.id for image_info in gt_infos],
                force_metadata_for_links=False,
            )
        ]

        for eval_res in self.eval_results:
            # NOTE(review): if the prediction project has no dataset with this name,
            # the lookup presumably yields None and `.id` below fails - confirm.
            pred_dataset = api.dataset.get_info_by_name(eval_res.pred_project_id, gt_dataset.name)
            pred_infos = api.image.get_list(
                pred_dataset.id,
                filters=[
                    {ApiField.FIELD: ApiField.NAME, ApiField.OPERATOR: "in", ApiField.VALUE: names}
                ],
                force_metadata_for_links=False,
            )
            # Align with the GT order (as get_click_data_explore_all does); the
            # zip() interleave below pairs images purely by position.
            pred_infos = sorted(
                pred_infos, key=lambda info: name_order.get(info.name, len(names))
            )
            images.append(pred_infos)
            annotations.append(
                api.annotation.download_batch(
                    pred_dataset.id,
                    [image_info.id for image_info in pred_infos],
                    force_metadata_for_links=False,
                )
            )

        # Transpose "per source" lists into a flat row-major list.
        images = [item for row in zip(*images) for item in row]
        annotations = [item for row in zip(*annotations) for item in row]
        return images, annotations

    def get_click_data_explore_all(self) -> dict:
        """Build the click payload that opens all evaluated images in the modal gallery.

        :return: dict with ``projectMeta``, ``layoutTemplate`` (column titles) and
            ``clickData["explore"]`` holding the title and the interleaved image ids
            (GT, model 1, model 2, ... for every image).
        """
        res = {}

        res["projectMeta"] = self._merged_meta().to_json()
        res["layoutTemplate"] = [{"columnTitle": "Ground Truth"}]
        for idx, eval_res in enumerate(self.eval_results, 1):
            res["layoutTemplate"].append({"columnTitle": f"[{idx}] {eval_res.short_name}"})

        click_data = res.setdefault("clickData", {})
        explore = click_data.setdefault("explore", {})
        explore["title"] = "Explore all predictions"

        # Names of all images that took part in at least one evaluation.
        image_names = set()
        for eval_res in self.eval_results:
            eval_res.mp.per_image_metrics["img_names"].apply(image_names.add)

        filters = [{"field": "name", "operator": "in", "value": list(image_names)}]

        api = self.eval_results[0].api

        # Ground truth column: images of every GT dataset, sorted by name per dataset.
        gt_datasets = self.eval_results[0].gt_dataset_infos
        ds_names = [ds.name for ds in gt_datasets]
        gt_ids = []
        names = []
        for ds in gt_datasets:
            image_infos = api.image.get_list(ds.id, filters, force_metadata_for_links=False)
            image_infos = sorted(image_infos, key=lambda x: x.name)
            names.extend(image_info.name for image_info in image_infos)
            gt_ids.extend(image_info.id for image_info in image_infos)
        images_ids = [gt_ids]

        # First-occurrence positions, equivalent to list.index() but O(1) per lookup.
        ds_order = {}
        for position, ds_name in enumerate(ds_names):
            ds_order.setdefault(ds_name, position)
        name_order = {}
        for position, name in enumerate(names):
            name_order.setdefault(name, position)

        # One column per model: same datasets and images, in the GT order.
        for eval_res in self.eval_results:
            pred_datasets = [
                ds
                for ds in api.dataset.get_list(eval_res.pred_project_id)
                if ds.name in ds_order
            ]
            pred_datasets.sort(key=lambda ds: ds_order[ds.name])
            pred_infos = []
            for ds in pred_datasets:
                image_infos = api.image.get_list(ds.id, filters, force_metadata_for_links=False)
                pred_infos.extend(info for info in image_infos if info.name in name_order)
            pred_infos.sort(key=lambda info: name_order[info.name])
            images_ids.append([image_info.id for image_info in pred_infos])

        explore["imagesIds"] = [i for row in zip(*images_ids) for i in row]

        return res
@@ -0,0 +1,71 @@
1
+ from typing import List
2
+
3
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
4
+ from supervisely.nn.benchmark.semantic_segmentation.evaluator import (
5
+ SemanticSegmentationEvalResult,
6
+ )
7
+ from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
8
+
9
+
10
class FrequentlyConfused(BaseVisMetrics):
    """Bar chart of the most frequently confused class pairs across compared models."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.eval_results: List[SemanticSegmentationEvalResult]
        self.clickable = True
        self._keypair_sep = "-"

    @property
    def md(self) -> MarkdownWidget:
        """Markdown widget; uses the "empty" text when no model has confused pairs."""
        text = (
            self.vis_texts.markdown_frequently_confused_empty
            if self.is_empty
            else self.vis_texts.markdown_frequently_confused
        )
        return MarkdownWidget("frequently_confused", "Frequently Confused Classes", text=text)

    @property
    def chart(self) -> ChartWidget:
        """Chart widget wrapping the figure built by ``get_figure``."""
        return ChartWidget("frequently_confused", self.get_figure())

    @property
    def is_empty(self) -> bool:
        """True when the first item of ``mp.frequently_confused`` is empty for every model."""
        return not any(len(res.mp.frequently_confused[0]) for res in self.eval_results)

    def get_figure(self):
        """Build a grouped bar chart with one trace per model.

        The (row-reversed) confusion matrices of all models are summed, the
        diagonal is zeroed, and the (at most 10) largest off-diagonal cells
        define the class pairs shown on the x axis.
        """
        import numpy as np
        import plotly.graph_objects as go  # pylint: disable=import-error

        figure = go.Figure()

        class_names = self.eval_results[0].classes_whitelist
        n_classes = len(class_names)

        # One matrix per model, stacked along the first axis.
        stacked = np.zeros((len(self.eval_results), n_classes, n_classes))
        for position, result in enumerate(self.eval_results):
            matrix, _ = result.mp.confusion_matrix
            stacked[position] = matrix[::-1].copy()

        totals = stacked.sum(axis=0)
        np.fill_diagonal(totals, 0)

        # Flat indices of the largest summed cells, in descending order.
        n_pairs = min(10, n_classes * (n_classes - 1))
        top_cells = np.argsort(totals.flatten())[::-1][:n_pairs]
        rows = top_cells // n_classes
        cols = top_cells % n_classes
        labels = [f"{class_names[r]}-{class_names[c]}" for r, c in zip(rows, cols)]

        for position, result in enumerate(self.eval_results):
            percentages = stacked[position][rows, cols] * 100
            bar = go.Bar(
                name=result.name,
                x=labels,
                y=percentages,
                hovertemplate="%{x}: %{y:.2f}%<extra></extra>",
                marker=dict(color=result.color, line=dict(width=0.7)),
            )
            figure.add_trace(bar)

        return figure
@@ -0,0 +1,68 @@
1
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
2
+ from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
3
+
4
+
5
class IntersectionErrorOverUnion(BaseVisMetrics):
    """Donut charts of mIoU and the three EoU components, one donut per model."""

    @property
    def md(self) -> MarkdownWidget:
        """Markdown widget with the section description."""
        return MarkdownWidget(
            "intersection_error_over_union",
            "Intersection & Error Over Union",
            text=self.vis_texts.markdown_iou,
        )

    @property
    def chart(self) -> ChartWidget:
        """Chart widget wrapping the figure built by ``get_figure``."""
        return ChartWidget("intersection_error_over_union", self.get_figure())

    def get_figure(self):
        """Lay out one pie per eval result on a grid and label each with the model name."""
        import plotly.graph_objects as go  # pylint: disable=import-error
        from plotly.subplots import make_subplots  # pylint: disable=import-error

        labels = ["mIoU", "mBoundaryEoU", "mExtentEoU", "mSegmentEoU"]
        palette = ["#8ACAA1", "#FFE4B5", "#F7ADAA", "#dd3f3f"]

        # Grid width: 4 when the model count is a multiple of 4, else 3 for more
        # than two models, else 2.
        total = len(self.eval_results)
        if total % 4 == 0:
            cols = 4
        elif total > 2:
            cols = 3
        else:
            cols = 2
        full_rows, remainder = divmod(total, cols)
        rows = full_rows + (1 if remainder else 0)

        fig = make_subplots(rows=rows, cols=cols, specs=[[{"type": "domain"}] * cols] * rows)

        annotations = []
        for idx, eval_result in enumerate(self.eval_results, start=1):
            # 1-based grid cell, filled row by row.
            row_offset, col_offset = divmod(idx - 1, cols)
            row = row_offset + 1
            col = col_offset + 1

            metrics = eval_result.mp
            pie = go.Pie(
                labels=labels,
                values=[
                    metrics.iou,
                    metrics.boundary_eou,
                    metrics.extent_eou,
                    metrics.segment_eou,
                ],
                hole=0.5,
                textposition="outside",
                textinfo="percent+label",
                marker=dict(colors=palette),
            )
            fig.add_trace(pie, row=row, col=col)

            # Caption in the middle of the donut: model index plus a name cut to 7 chars.
            model_name = eval_result.name
            ellipsis = "..." if len(model_name) > 7 else ""
            cell = fig.get_subplot(row, col)
            annotations.append(
                dict(
                    text=f"[{idx}] {model_name[:7]}{ellipsis}",
                    x=sum(cell.x) / 2,
                    y=sum(cell.y) / 2,
                    showarrow=False,
                    xanchor="center",
                )
            )
        fig.update_layout(annotations=annotations)

        return fig
@@ -0,0 +1,223 @@
1
+ from typing import List
2
+
3
+ from supervisely._utils import abs_url
4
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
5
+ from supervisely.nn.benchmark.visualization.evaluation_result import EvalResult
6
+ from supervisely.nn.benchmark.visualization.widgets import (
7
+ ChartWidget,
8
+ MarkdownWidget,
9
+ TableWidget,
10
+ )
11
+
12
+
13
class Overview(BaseVisMetrics):
    """Overview section of the model comparison report: texts, key-metrics table and chart."""

    MARKDOWN_OVERVIEW = "markdown_overview"
    MARKDOWN_OVERVIEW_INFO = "markdown_overview_info"
    MARKDOWN_COMMON_OVERVIEW = "markdown_common_overview"
    CHART = "chart_key_metrics"

    def __init__(self, vis_texts, eval_results: List[EvalResult]) -> None:
        super().__init__(vis_texts, eval_results)

    @property
    def overview_md(self) -> MarkdownWidget:
        """Common overview markdown shared by all compared models.

        The template ``vis_texts.markdown_common_overview`` is formatted with the
        combined model name followed by the GT project id, GT project name and
        task type of the first eval result.
        """
        info = []
        model_names = []
        for eval_result in self.eval_results:
            model_name = eval_result.name or "Custom"
            # Escape underscores so markdown does not treat them as emphasis.
            model_names.append(model_name.replace("_", "\\_"))
            info.append(
                [
                    eval_result.gt_project_info.id,
                    eval_result.gt_project_info.name,
                    eval_result.inference_info.get("task_type"),
                ]
            )

        if all(name == "Custom" for name in model_names):
            model_name = "Custom models"
        elif all(name == model_names[0] for name in model_names):
            model_name = model_names[0]
        else:
            model_name = " vs. ".join(model_names)

        info = [model_name] + info[0]

        text_template: str = getattr(self.vis_texts, self.MARKDOWN_COMMON_OVERVIEW)
        return MarkdownWidget(
            name=self.MARKDOWN_COMMON_OVERVIEW,
            title="Overview",
            text=text_template.format(*info),
        )

    @property
    def overview_widgets(self) -> List[MarkdownWidget]:
        """One info-block markdown widget per compared model.

        Each widget formats ``vis_texts.markdown_overview_info`` with, in order:
        checkpoint name, escaped model name, escaped checkpoint name, architecture,
        runtime, checkpoint URL, checkpoint link text and the report link.
        """
        text_template: str = getattr(self.vis_texts, self.MARKDOWN_OVERVIEW_INFO)

        widgets = []
        for eval_result in self.eval_results:
            inference_info = eval_result.inference_info

            url = inference_info.get("checkpoint_url")
            link_text = inference_info.get("custom_checkpoint_path")
            if link_text is None:
                link_text = url
            # Both values may be absent; fall back to an empty link text instead of
            # failing on None.replace().
            link_text = "" if link_text is None else link_text.replace("_", "\\_")

            checkpoint_name = eval_result.checkpoint_name
            model_name = inference_info.get("model_name") or "Custom"

            report = eval_result.api.file.get_info_by_path(self.team_id, eval_result.report_path)
            report_link = abs_url(f"/model-benchmark?id={report.id}")

            formats = [
                checkpoint_name,
                model_name.replace("_", "\\_"),
                checkpoint_name.replace("_", "\\_"),
                inference_info.get("architecture"),
                inference_info.get("runtime"),
                url,
                link_text,
                report_link,
            ]

            md = MarkdownWidget(
                name=self.MARKDOWN_OVERVIEW_INFO,
                title="Overview",
                text=text_template.format(*formats),
            )
            md.is_info_block = True
            widgets.append(md)
        return widgets

    def get_table_widget(self, latency, fps) -> TableWidget:
        """Build the key-metrics comparison table.

        :param latency: per-model latency cells (any iterable), appended as the
            "Latency (ms)" row.
        :param fps: per-model FPS cells (any iterable), appended as the "FPS" row.
        :return: table with one column per model and one row per key metric; the
            metric names are taken from the first model's ``key_metrics()``.
        """
        res = {}

        columns = ["metrics"] + [f"[{i}] {r.name}" for i, r in enumerate(self.eval_results, 1)]

        all_metrics = [eval_result.mp.key_metrics() for eval_result in self.eval_results]
        res["content"] = []

        for metric in all_metrics[0]:
            row = [metric]
            for metrics in all_metrics:
                value = metrics[metric]
                if value is None:
                    value = "―"  # placeholder for a missing metric
                elif isinstance(value, float):
                    value = round(value, 2)
                row.append(value)
            res["content"].append({"row": row, "id": metric, "items": row})

        latency_row = ["Latency (ms)", *latency]
        res["content"].append({"row": latency_row, "id": latency_row[0], "items": latency_row})

        fps_row = ["FPS", *fps]
        res["content"].append({"row": fps_row, "id": fps_row[0], "items": fps_row})

        res["columns"] = columns
        res["columnsOptions"] = [{"disableSort": True} for _ in columns]

        return TableWidget(
            name="table_key_metrics",
            data=res,
            show_header_controls=False,
            fix_columns=1,
            # Single page: show every row at once.
            page_size=len(res["content"]),
        )

    @property
    def chart_widget(self) -> ChartWidget:
        """Chart widget wrapping the figure built by ``get_figure``."""
        return ChartWidget(name=self.CHART, figure=self.get_figure())

    def get_overview_info(self, eval_result: EvalResult):
        """Describe classes, image counts and the training session of one eval result.

        :param eval_result: evaluation result to describe.
        :return: tuple ``(classes_str, images_str, train_session)`` of strings;
            ``train_session`` is empty when no training app session id is known.
        """
        classes_cnt = len(eval_result.classes_whitelist)
        classes_str = f"{classes_cnt} {'classes' if classes_cnt > 1 else 'class'}"

        gt_project_id = eval_result.gt_project_info.id
        gt_dataset_ids = eval_result.gt_dataset_ids
        gt_images_cnt = eval_result.val_images_cnt
        train_info = eval_result.train_info
        total_imgs_cnt = eval_result.gt_project_info.items_count

        # Work out the validation image count and how the evaluated scope is described.
        if gt_images_cnt is not None:
            val_imgs_cnt = gt_images_cnt
            scope_str = (
                f", total {total_imgs_cnt} images. Evaluated using subset - {val_imgs_cnt} images"
            )
        elif gt_dataset_ids is not None:
            datasets = eval_result.gt_dataset_infos
            val_imgs_cnt = sum(ds.items_count for ds in datasets)
            links = [
                f'<a href="/projects/{gt_project_id}/datasets/{ds.id}" target="_blank">{ds.name}</a>'
                for ds in datasets
            ]
            scope_str = f", total {total_imgs_cnt} images. Evaluated on the dataset{'s' if len(links) > 1 else ''}: {', '.join(links)}"
        else:
            val_imgs_cnt = eval_result.gt_project_info.items_count
            scope_str = f", total {total_imgs_cnt} images. Evaluated on the whole project ({val_imgs_cnt} images)"

        train_session, images_str = "", ""
        if train_info:
            train_task_id = train_info.get("app_session_id")
            if train_task_id:
                task_info = eval_result.api.task.get_info_by_id(int(train_task_id))
                app_id = task_info["meta"]["app"]["id"]
                train_session = f'- **Training dashboard**: <a href="/apps/{app_id}/sessions/{train_task_id}" target="_blank">open</a>'

            train_imgs_cnt = train_info.get("images_count")
            images_str = f", {train_imgs_cnt} images in train, {val_imgs_cnt} images in validation"

        images_str += scope_str

        return classes_str, images_str, train_session

    def get_figure(self):  # -> Optional[go.Figure]
        """Radar chart of the key metrics with one closed trace per model."""
        import plotly.graph_objects as go  # pylint: disable=import-error

        # Overall Metrics
        fig = go.Figure()
        for i, eval_result in enumerate(self.eval_results, 1):
            name = f"[{i}] {eval_result.name}"
            base_metrics = eval_result.mp.key_metrics().copy()
            # Scale this metric by 100 before plotting on the 0-105 radial axis.
            base_metrics["mPixel accuracy"] = round(base_metrics["mPixel accuracy"] * 100, 2)
            r = list(base_metrics.values())
            theta = list(base_metrics.keys())
            fig.add_trace(
                go.Scatterpolar(
                    # Repeat the first point to close the polygon.
                    r=r + [r[0]],
                    theta=theta + [theta[0]],
                    name=name,
                    marker=dict(color=eval_result.color),
                    hovertemplate=name + "<br>%{theta}: %{r:.2f}<extra></extra>",
                )
            )
        fig.update_layout(
            polar=dict(
                radialaxis=dict(
                    range=[0, 105],
                    ticks="outside",
                ),
                angularaxis=dict(rotation=90, direction="clockwise"),
            ),
            dragmode=False,
            height=500,
            margin=dict(l=25, r=25, t=25, b=25),
            modebar=dict(
                remove=[
                    "zoom2d",
                    "pan2d",
                    "select2d",
                    "lasso2d",
                    "zoomIn2d",
                    "zoomOut2d",
                    "autoScale2d",
                    "resetScale2d",
                ]
            ),
        )
        return fig
@@ -0,0 +1,57 @@
1
+ from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
2
+ from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
3
+
4
+
5
class RenormalizedErrorOverUnion(BaseVisMetrics):
    """Grouped bar chart of the renormalized EoU components, one bar group per model."""

    @property
    def md(self) -> MarkdownWidget:
        """Markdown widget with the section description."""
        return MarkdownWidget(
            "renormalized_error_over_union",
            "Renormalized Error over Union",
            text=self.vis_texts.markdown_renormalized_error_ou,
        )

    @property
    def chart(self) -> ChartWidget:
        """Chart widget wrapping the figure built by ``get_figure``."""
        # Named after this section (same name as the markdown widget) rather than
        # reusing the "intersection_error_over_union" chart name.
        return ChartWidget("renormalized_error_over_union", self.get_figure())

    def get_figure(self):
        """Build the bar chart with boundary, extent and segment renormalized EoU per model."""
        import plotly.graph_objects as go  # pylint: disable=import-error

        fig = go.Figure()

        labels = ["Boundary EoU", "Extent EoU", "Segment EoU"]

        for idx, eval_result in enumerate(self.eval_results, 1):
            model_name = f"[{idx}] {eval_result.short_name}"
            # The same values are used for the bar heights and the bar captions.
            values = [
                eval_result.mp.boundary_renormed_eou,
                eval_result.mp.extent_renormed_eou,
                eval_result.mp.segment_renormed_eou,
            ]

            fig.add_trace(
                go.Bar(
                    x=labels,
                    y=values,
                    name=model_name,
                    text=list(values),
                    textposition="outside",
                    marker=dict(color=eval_result.color, line=dict(width=0.7)),
                    width=0.4,
                )
            )

        fig.update_traces(hovertemplate="%{x}: %{y:.2f}<extra></extra>")
        fig.update_layout(
            barmode="group",
            bargap=0.15,
            bargroupgap=0.05,
            # Wider canvas once four or more models are compared.
            width=700 if len(self.eval_results) < 4 else 1000,
        )

        return fig