supervisely 6.73.254__py3-none-any.whl → 6.73.256__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/api/api.py +16 -8
- supervisely/api/file_api.py +16 -5
- supervisely/api/task_api.py +4 -2
- supervisely/app/widgets/field/field.py +10 -7
- supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
- supervisely/io/network_exceptions.py +14 -2
- supervisely/nn/benchmark/base_benchmark.py +33 -35
- supervisely/nn/benchmark/base_evaluator.py +27 -1
- supervisely/nn/benchmark/base_visualizer.py +8 -11
- supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
- supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
- supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
- supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
- supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
- supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
- supervisely/nn/benchmark/visualization/renderer.py +25 -10
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
- supervisely/nn/inference/inference.py +1 -0
- supervisely/nn/training/gui/gui.py +32 -10
- supervisely/nn/training/gui/training_artifacts.py +145 -0
- supervisely/nn/training/gui/training_process.py +3 -19
- supervisely/nn/training/train_app.py +179 -70
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/METADATA +1 -1
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/RECORD +60 -48
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/LICENSE +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/WHEEL +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/top_level.txt +0 -0
supervisely/api/api.py
CHANGED
|
@@ -69,6 +69,7 @@ from supervisely.io.network_exceptions import (
|
|
|
69
69
|
process_requests_exception,
|
|
70
70
|
process_requests_exception_async,
|
|
71
71
|
process_unhandled_request,
|
|
72
|
+
RetryableRequestException,
|
|
72
73
|
)
|
|
73
74
|
from supervisely.project.project_meta import ProjectMeta
|
|
74
75
|
from supervisely.sly_logger import logger
|
|
@@ -1283,16 +1284,20 @@ class Api:
|
|
|
1283
1284
|
Api._raise_for_status_httpx(resp)
|
|
1284
1285
|
|
|
1285
1286
|
hhash = resp.headers.get("x-content-checksum-sha256", None)
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1287
|
+
try:
|
|
1288
|
+
for chunk in resp.iter_raw(chunk_size):
|
|
1289
|
+
yield chunk, hhash
|
|
1290
|
+
total_streamed += len(chunk)
|
|
1291
|
+
except Exception as e:
|
|
1292
|
+
raise RetryableRequestException(repr(e))
|
|
1293
|
+
|
|
1289
1294
|
if expected_size != 0 and total_streamed != expected_size:
|
|
1290
1295
|
raise ValueError(
|
|
1291
1296
|
f"Streamed size does not match the expected: {total_streamed} != {expected_size}"
|
|
1292
1297
|
)
|
|
1293
1298
|
logger.trace(f"Streamed size: {total_streamed}, expected size: {expected_size}")
|
|
1294
1299
|
return
|
|
1295
|
-
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
|
1300
|
+
except (httpx.RequestError, httpx.HTTPStatusError, RetryableRequestException) as e:
|
|
1296
1301
|
if (
|
|
1297
1302
|
isinstance(e, httpx.HTTPStatusError)
|
|
1298
1303
|
and resp.status_code == 400
|
|
@@ -1531,9 +1536,12 @@ class Api:
|
|
|
1531
1536
|
|
|
1532
1537
|
# received hash of the content to check integrity of the data stream
|
|
1533
1538
|
hhash = resp.headers.get("x-content-checksum-sha256", None)
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1539
|
+
try:
|
|
1540
|
+
async for chunk in resp.aiter_raw(chunk_size):
|
|
1541
|
+
yield chunk, hhash
|
|
1542
|
+
total_streamed += len(chunk)
|
|
1543
|
+
except Exception as e:
|
|
1544
|
+
raise RetryableRequestException(repr(e))
|
|
1537
1545
|
|
|
1538
1546
|
if expected_size != 0 and total_streamed != expected_size:
|
|
1539
1547
|
raise ValueError(
|
|
@@ -1541,7 +1549,7 @@ class Api:
|
|
|
1541
1549
|
)
|
|
1542
1550
|
logger.trace(f"Streamed size: {total_streamed}, expected size: {expected_size}")
|
|
1543
1551
|
return
|
|
1544
|
-
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
|
1552
|
+
except (httpx.RequestError, httpx.HTTPStatusError, RetryableRequestException) as e:
|
|
1545
1553
|
if (
|
|
1546
1554
|
isinstance(e, httpx.HTTPStatusError)
|
|
1547
1555
|
and resp.status_code == 400
|
supervisely/api/file_api.py
CHANGED
|
@@ -1361,11 +1361,22 @@ class FileApi(ModuleApiBase):
|
|
|
1361
1361
|
# Output: /My_App_Test_001
|
|
1362
1362
|
"""
|
|
1363
1363
|
res_dir = dir_path.rstrip("/")
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
|
|
1364
|
+
if not self.dir_exists(team_id, res_dir + "/"):
|
|
1365
|
+
return res_dir
|
|
1366
|
+
|
|
1367
|
+
low, high = 0, 1
|
|
1368
|
+
while self.dir_exists(team_id, f"{res_dir}_{high:03d}/"):
|
|
1369
|
+
low = high
|
|
1370
|
+
high *= 2
|
|
1371
|
+
|
|
1372
|
+
while low < high:
|
|
1373
|
+
mid = (low + high) // 2
|
|
1374
|
+
if self.dir_exists(team_id, f"{res_dir}_{mid:03d}/"):
|
|
1375
|
+
low = mid + 1
|
|
1376
|
+
else:
|
|
1377
|
+
high = mid
|
|
1378
|
+
|
|
1379
|
+
return f"{res_dir}_{low:03d}"
|
|
1369
1380
|
|
|
1370
1381
|
def upload_directory(
|
|
1371
1382
|
self,
|
supervisely/api/task_api.py
CHANGED
|
@@ -821,10 +821,12 @@ class TaskApi(ModuleApiBase, ModuleWithStatus):
|
|
|
821
821
|
)
|
|
822
822
|
return resp.json()
|
|
823
823
|
|
|
824
|
-
def set_output_report(
|
|
824
|
+
def set_output_report(
|
|
825
|
+
self, task_id: int, file_id: int, file_name: str, description: Optional[str] = "Report"
|
|
826
|
+
) -> Dict:
|
|
825
827
|
"""set_output_report"""
|
|
826
828
|
return self._set_custom_output(
|
|
827
|
-
task_id, file_id, file_name, description=
|
|
829
|
+
task_id, file_id, file_name, description=description, icon="zmdi zmdi-receipt"
|
|
828
830
|
)
|
|
829
831
|
|
|
830
832
|
def _set_custom_output(
|
|
@@ -58,14 +58,12 @@ class Field(Widget):
|
|
|
58
58
|
image_url: Optional[str] = None,
|
|
59
59
|
) -> Field.Icon:
|
|
60
60
|
if zmdi_class is None and image_url is None:
|
|
61
|
-
raise ValueError("One of the arguments has to be defined: zmdi_class or image_url")
|
|
62
|
-
if zmdi_class is not None and image_url is not None:
|
|
63
61
|
raise ValueError(
|
|
64
|
-
"
|
|
62
|
+
"One of the arguments has to be defined: zmdi_class or image_url"
|
|
65
63
|
)
|
|
66
|
-
if
|
|
64
|
+
if zmdi_class is not None and image_url is not None:
|
|
67
65
|
raise ValueError(
|
|
68
|
-
"
|
|
66
|
+
"Only one of the arguments has to be defined: zmdi_class or image_url"
|
|
69
67
|
)
|
|
70
68
|
|
|
71
69
|
if image_url is None and color_rgb is None:
|
|
@@ -104,6 +102,7 @@ class Field(Widget):
|
|
|
104
102
|
res["bgColor"] = sly_color.rgb2hex(self._bg_color)
|
|
105
103
|
if self._image_url is not None:
|
|
106
104
|
res["imageUrl"] = self._image_url
|
|
105
|
+
res["bgColor"] = sly_color.rgb2hex(self._bg_color)
|
|
107
106
|
return res
|
|
108
107
|
|
|
109
108
|
def __init__(
|
|
@@ -123,9 +122,13 @@ class Field(Widget):
|
|
|
123
122
|
self._icon = icon
|
|
124
123
|
self._content = content
|
|
125
124
|
if self._title_url is not None and self._title is None:
|
|
126
|
-
raise ValueError(
|
|
125
|
+
raise ValueError(
|
|
126
|
+
"Title can not be specified only as url without text value"
|
|
127
|
+
)
|
|
127
128
|
if self._description_url is not None and self._description is None:
|
|
128
|
-
raise ValueError(
|
|
129
|
+
raise ValueError(
|
|
130
|
+
"Description can not be specified only as url without text value"
|
|
131
|
+
)
|
|
129
132
|
|
|
130
133
|
super().__init__(widget_id=widget_id, file_path=__file__)
|
|
131
134
|
|
|
@@ -144,6 +144,7 @@ class GridGalleryV2(Widget):
|
|
|
144
144
|
title: str = "",
|
|
145
145
|
column_index: int = None,
|
|
146
146
|
ignore_tags_filtering: Union[bool, List[str]] = False,
|
|
147
|
+
call_update: bool = True,
|
|
147
148
|
):
|
|
148
149
|
column_index = self.get_column_index(incoming_value=column_index)
|
|
149
150
|
cell_uuid = str(
|
|
@@ -168,7 +169,8 @@ class GridGalleryV2(Widget):
|
|
|
168
169
|
}
|
|
169
170
|
)
|
|
170
171
|
|
|
171
|
-
|
|
172
|
+
if call_update:
|
|
173
|
+
self._update()
|
|
172
174
|
return cell_uuid
|
|
173
175
|
|
|
174
176
|
def clean_up(self):
|
|
@@ -28,6 +28,14 @@ RETRY_STATUS_CODES = {
|
|
|
28
28
|
}
|
|
29
29
|
|
|
30
30
|
|
|
31
|
+
class RetryableRequestException(Exception):
|
|
32
|
+
"""Exception that indicates that the request should be retried."""
|
|
33
|
+
|
|
34
|
+
def __init__(self, message, response=None):
|
|
35
|
+
super().__init__(message)
|
|
36
|
+
self.response = response
|
|
37
|
+
|
|
38
|
+
|
|
31
39
|
async def process_requests_exception_async(
|
|
32
40
|
external_logger,
|
|
33
41
|
exc,
|
|
@@ -50,6 +58,8 @@ async def process_requests_exception_async(
|
|
|
50
58
|
except Exception:
|
|
51
59
|
pass
|
|
52
60
|
|
|
61
|
+
is_retryable_exception = isinstance(exc, RetryableRequestException)
|
|
62
|
+
|
|
53
63
|
is_connection_error = isinstance(
|
|
54
64
|
exc,
|
|
55
65
|
(
|
|
@@ -82,7 +92,7 @@ async def process_requests_exception_async(
|
|
|
82
92
|
except (AttributeError, ValueError):
|
|
83
93
|
pass
|
|
84
94
|
|
|
85
|
-
if is_connection_error
|
|
95
|
+
if any([is_connection_error, is_server_retryable_error, is_retryable_exception]):
|
|
86
96
|
await process_retryable_request_async(
|
|
87
97
|
external_logger,
|
|
88
98
|
exc,
|
|
@@ -136,6 +146,8 @@ def process_requests_exception(
|
|
|
136
146
|
except Exception:
|
|
137
147
|
pass
|
|
138
148
|
|
|
149
|
+
is_retryable_exception = isinstance(exc, RetryableRequestException)
|
|
150
|
+
|
|
139
151
|
is_connection_error = isinstance(
|
|
140
152
|
exc,
|
|
141
153
|
(
|
|
@@ -168,7 +180,7 @@ def process_requests_exception(
|
|
|
168
180
|
except (AttributeError, ValueError):
|
|
169
181
|
pass
|
|
170
182
|
|
|
171
|
-
if is_connection_error
|
|
183
|
+
if any([is_connection_error, is_server_retryable_error, is_retryable_exception]):
|
|
172
184
|
process_retryable_request(
|
|
173
185
|
external_logger,
|
|
174
186
|
exc,
|
|
@@ -74,6 +74,8 @@ class BaseBenchmark:
|
|
|
74
74
|
self.train_info = None
|
|
75
75
|
self.evaluator_app_info = None
|
|
76
76
|
self.evaluation_params = evaluation_params
|
|
77
|
+
self.visualizer = None
|
|
78
|
+
self.remote_vis_dir = None
|
|
77
79
|
self._eval_results = None
|
|
78
80
|
self.report_id = None
|
|
79
81
|
self._validate_evaluation_params()
|
|
@@ -379,7 +381,7 @@ class BaseBenchmark:
|
|
|
379
381
|
logger.info(f"Found GT annotations in {gt_path}")
|
|
380
382
|
if not os.path.exists(dt_path):
|
|
381
383
|
with self.pbar(
|
|
382
|
-
message="Evaluation: Downloading
|
|
384
|
+
message="Evaluation: Downloading prediction annotations", total=self.num_items
|
|
383
385
|
) as p:
|
|
384
386
|
download_project(
|
|
385
387
|
self.api,
|
|
@@ -490,11 +492,12 @@ class BaseBenchmark:
|
|
|
490
492
|
"It should be defined in the subclass of BaseBenchmark (e.g. ObjectDetectionBenchmark)."
|
|
491
493
|
)
|
|
492
494
|
eval_result = self.get_eval_result()
|
|
493
|
-
|
|
494
|
-
|
|
495
|
+
layout_dir = self.get_layout_results_dir()
|
|
496
|
+
self.visualizer = self.visualizer_cls( # pylint: disable=not-callable
|
|
497
|
+
self.api, [eval_result], layout_dir, self.pbar
|
|
495
498
|
)
|
|
496
499
|
with self.pbar(message="Visualizations: Rendering layout", total=1) as p:
|
|
497
|
-
|
|
500
|
+
self.visualizer.visualize()
|
|
498
501
|
p.update(1)
|
|
499
502
|
|
|
500
503
|
def _get_or_create_diff_project(self) -> Tuple[ProjectInfo, bool]:
|
|
@@ -540,37 +543,16 @@ class BaseBenchmark:
|
|
|
540
543
|
return diff_project_info, is_existed
|
|
541
544
|
|
|
542
545
|
def upload_visualizations(self, dest_dir: str):
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
layout_dir
|
|
546
|
-
), f"The layout dir {layout_dir!r} is empty. You should run visualizations before uploading results."
|
|
547
|
-
|
|
548
|
-
# self.api.file.remove_dir(self.team_id, dest_dir, silent=True)
|
|
549
|
-
|
|
550
|
-
remote_dir = dest_dir
|
|
551
|
-
with self.pbar(
|
|
552
|
-
message="Visualizations: Uploading layout",
|
|
553
|
-
total=get_directory_size(layout_dir),
|
|
554
|
-
unit="B",
|
|
555
|
-
unit_scale=True,
|
|
556
|
-
) as p:
|
|
557
|
-
remote_dir = self.api.file.upload_directory(
|
|
558
|
-
self.team_id,
|
|
559
|
-
layout_dir,
|
|
560
|
-
dest_dir,
|
|
561
|
-
replace_if_conflict=True,
|
|
562
|
-
change_name_if_conflict=False,
|
|
563
|
-
progress_size_cb=p,
|
|
564
|
-
)
|
|
565
|
-
|
|
566
|
-
logger.info(f"Uploaded to: {remote_dir!r}")
|
|
546
|
+
self.remote_vis_dir = self.visualizer.upload_results(self.team_id, dest_dir, self.pbar)
|
|
547
|
+
return self.remote_vis_dir
|
|
567
548
|
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
logger.info(f"Open url: {report_link}")
|
|
549
|
+
@property
|
|
550
|
+
def report(self):
|
|
551
|
+
return self.visualizer.renderer.report
|
|
572
552
|
|
|
573
|
-
|
|
553
|
+
@property
|
|
554
|
+
def lnk(self):
|
|
555
|
+
return self.visualizer.renderer.lnk
|
|
574
556
|
|
|
575
557
|
def upload_report_link(self, remote_dir: str):
|
|
576
558
|
template_path = os.path.join(remote_dir, "template.vue")
|
|
@@ -578,16 +560,25 @@ class BaseBenchmark:
|
|
|
578
560
|
self.report_id = vue_template_info.id
|
|
579
561
|
|
|
580
562
|
report_link = "/model-benchmark?id=" + str(vue_template_info.id)
|
|
581
|
-
|
|
563
|
+
|
|
564
|
+
lnk_name = "Model Evaluation Report.lnk"
|
|
565
|
+
local_path = os.path.join(self.get_layout_results_dir(), lnk_name)
|
|
582
566
|
with open(local_path, "w") as file:
|
|
583
567
|
file.write(report_link)
|
|
584
568
|
|
|
585
|
-
remote_path = os.path.join(remote_dir,
|
|
569
|
+
remote_path = os.path.join(remote_dir, lnk_name)
|
|
586
570
|
file_info = self.api.file.upload(self.team_id, local_path, remote_path)
|
|
587
571
|
|
|
588
572
|
logger.info(f"Report link: {report_link}")
|
|
589
573
|
return file_info
|
|
590
574
|
|
|
575
|
+
def get_report_link(self) -> str:
|
|
576
|
+
if self.remote_vis_dir is None:
|
|
577
|
+
raise ValueError("Visualizations are not uploaded yet.")
|
|
578
|
+
return self.visualizer.renderer._get_report_link(
|
|
579
|
+
self.api, self.team_id, self.remote_vis_dir
|
|
580
|
+
)
|
|
581
|
+
|
|
591
582
|
def _merge_metas(self, gt_project_id, pred_project_id):
|
|
592
583
|
gt_meta = self.api.project.get_meta(gt_project_id)
|
|
593
584
|
gt_meta = ProjectMeta.from_json(gt_meta)
|
|
@@ -623,3 +614,10 @@ class BaseBenchmark:
|
|
|
623
614
|
if self._eval_results is None:
|
|
624
615
|
self._eval_results = self.evaluator.get_eval_result()
|
|
625
616
|
return self._eval_results
|
|
617
|
+
|
|
618
|
+
def get_diff_project_info(self):
|
|
619
|
+
eval_result = self.get_eval_result()
|
|
620
|
+
if hasattr(eval_result, "diff_project_info"):
|
|
621
|
+
self.diff_project_info = eval_result.diff_project_info
|
|
622
|
+
return self.diff_project_info
|
|
623
|
+
return None
|
|
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Union
|
|
|
7
7
|
import yaml
|
|
8
8
|
|
|
9
9
|
from supervisely.app.widgets import SlyTqdm
|
|
10
|
+
from supervisely.io.fs import get_file_name_with_ext
|
|
10
11
|
from supervisely.task.progress import tqdm_sly
|
|
11
12
|
|
|
12
13
|
|
|
@@ -40,6 +41,14 @@ class BaseEvalResult:
|
|
|
40
41
|
or self.inference_info.get("model_name")
|
|
41
42
|
)
|
|
42
43
|
|
|
44
|
+
@property
|
|
45
|
+
def short_name(self) -> str:
|
|
46
|
+
if not self.name:
|
|
47
|
+
return
|
|
48
|
+
if len(self.name) > 20:
|
|
49
|
+
return self.name[:9] + "..." + self.name[-7:]
|
|
50
|
+
return self.name
|
|
51
|
+
|
|
43
52
|
@property
|
|
44
53
|
def gt_project_id(self) -> int:
|
|
45
54
|
return self.inference_info.get("gt_project_id")
|
|
@@ -79,11 +88,28 @@ class BaseEvalResult:
|
|
|
79
88
|
def _prepare_data(self) -> None:
|
|
80
89
|
"""Prepare data to allow easy access to the data"""
|
|
81
90
|
raise NotImplementedError()
|
|
82
|
-
|
|
91
|
+
|
|
83
92
|
@property
|
|
84
93
|
def key_metrics(self):
|
|
85
94
|
raise NotImplementedError()
|
|
86
95
|
|
|
96
|
+
@property
|
|
97
|
+
def checkpoint_name(self):
|
|
98
|
+
if self.inference_info is None:
|
|
99
|
+
return None
|
|
100
|
+
|
|
101
|
+
deploy_params = self.inference_info.get("deploy_params", {})
|
|
102
|
+
name = None
|
|
103
|
+
if deploy_params:
|
|
104
|
+
name = deploy_params.get("checkpoint_name") # not TrainApp
|
|
105
|
+
if name is None:
|
|
106
|
+
name = deploy_params.get("model_files", {}).get("checkpoint")
|
|
107
|
+
if name is not None:
|
|
108
|
+
name = get_file_name_with_ext(name)
|
|
109
|
+
if name is None:
|
|
110
|
+
name = self.inference_info.get("checkpoint_name", "")
|
|
111
|
+
return name
|
|
112
|
+
|
|
87
113
|
|
|
88
114
|
class BaseEvaluator:
|
|
89
115
|
EVALUATION_PARAMS_YAML_PATH: str = None
|
|
@@ -61,6 +61,7 @@ class BaseVisMetric(BaseVisMetrics):
|
|
|
61
61
|
|
|
62
62
|
class BaseVisualizer:
|
|
63
63
|
cv_task = None
|
|
64
|
+
report_name = "Model Evaluation Report.lnk"
|
|
64
65
|
|
|
65
66
|
def __init__(
|
|
66
67
|
self,
|
|
@@ -141,7 +142,7 @@ class BaseVisualizer:
|
|
|
141
142
|
def visualize(self):
|
|
142
143
|
if self.renderer is None:
|
|
143
144
|
layout = self._create_layout()
|
|
144
|
-
self.renderer = Renderer(layout, self.workdir)
|
|
145
|
+
self.renderer = Renderer(layout, self.workdir, report_name=self.report_name)
|
|
145
146
|
return self.renderer.visualize()
|
|
146
147
|
|
|
147
148
|
def upload_results(self, team_id: int, remote_dir: str, progress=None):
|
|
@@ -178,19 +179,15 @@ class BaseVisualizer:
|
|
|
178
179
|
diff_ds_infos.append(diff_dataset)
|
|
179
180
|
return diff_dataset
|
|
180
181
|
|
|
182
|
+
is_existed = False
|
|
181
183
|
project_name = self._generate_diff_project_name(self.eval_result.pred_project_info.name)
|
|
182
184
|
workspace_id = self.eval_result.pred_project_info.workspace_id
|
|
183
|
-
project_info = self.api.project.
|
|
184
|
-
workspace_id, project_name,
|
|
185
|
+
project_info = self.api.project.create(
|
|
186
|
+
workspace_id, project_name, change_name_if_conflict=True
|
|
185
187
|
)
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
workspace_id, project_name, change_name_if_conflict=True
|
|
190
|
-
)
|
|
191
|
-
pred_datasets = {ds.id: ds for ds in self.eval_result.pred_dataset_infos}
|
|
192
|
-
for dataset in pred_datasets:
|
|
193
|
-
_get_or_create_diff_dataset(dataset, pred_datasets)
|
|
188
|
+
pred_datasets = {ds.id: ds for ds in self.eval_result.pred_dataset_infos}
|
|
189
|
+
for dataset in pred_datasets:
|
|
190
|
+
_get_or_create_diff_dataset(dataset, pred_datasets)
|
|
194
191
|
return project_info, diff_ds_infos, is_existed
|
|
195
192
|
|
|
196
193
|
def _generate_diff_project_name(self, pred_project_name):
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from supervisely.api.module_api import ApiField
|
|
6
|
+
from supervisely.nn.benchmark.comparison.semantic_segmentation.vis_metrics import (
|
|
7
|
+
Overview,
|
|
8
|
+
)
|
|
9
|
+
from supervisely.nn.benchmark.visualization.renderer import Renderer
|
|
10
|
+
from supervisely.nn.benchmark.visualization.widgets import (
|
|
11
|
+
ContainerWidget,
|
|
12
|
+
GalleryWidget,
|
|
13
|
+
MarkdownWidget,
|
|
14
|
+
)
|
|
15
|
+
from supervisely.project.project_meta import ProjectMeta
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseComparisonVisualizer:
|
|
19
|
+
vis_texts = None
|
|
20
|
+
ann_opacity = None
|
|
21
|
+
report_name = "Model Comparison Report.lnk"
|
|
22
|
+
|
|
23
|
+
def __init__(self, comparison):
|
|
24
|
+
self.comparison = comparison
|
|
25
|
+
self.api = comparison.api
|
|
26
|
+
self.eval_results = comparison.eval_results
|
|
27
|
+
self.gt_project_info = None
|
|
28
|
+
self.gt_project_meta = None
|
|
29
|
+
# self._widgets_created = False
|
|
30
|
+
|
|
31
|
+
for eval_result in self.eval_results:
|
|
32
|
+
eval_result.api = self.api # add api to eval_result for overview widget
|
|
33
|
+
self._get_eval_project_infos(eval_result)
|
|
34
|
+
|
|
35
|
+
self._create_widgets()
|
|
36
|
+
layout = self._create_layout()
|
|
37
|
+
|
|
38
|
+
self.renderer = Renderer(
|
|
39
|
+
layout,
|
|
40
|
+
str(Path(self.comparison.workdir, "visualizations")),
|
|
41
|
+
report_name=self.report_name,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
def visualize(self):
|
|
45
|
+
return self.renderer.visualize()
|
|
46
|
+
|
|
47
|
+
def upload_results(self, team_id: int, remote_dir: str, progress=None):
|
|
48
|
+
return self.renderer.upload_results(self.api, team_id, remote_dir, progress)
|
|
49
|
+
|
|
50
|
+
def _create_widgets(self):
|
|
51
|
+
raise NotImplementedError("Have to implement in subclasses")
|
|
52
|
+
|
|
53
|
+
def _create_layout(self):
|
|
54
|
+
raise NotImplementedError("Have to implement in subclasses")
|
|
55
|
+
|
|
56
|
+
def _create_header(self) -> MarkdownWidget:
|
|
57
|
+
"""Creates header widget"""
|
|
58
|
+
me = self.api.user.get_my_info().login
|
|
59
|
+
current_date = datetime.datetime.now().strftime("%d %B %Y, %H:%M")
|
|
60
|
+
header_main_text = " ∣ ".join( # vs. or | or ∣
|
|
61
|
+
eval_res.name for eval_res in self.comparison.eval_results
|
|
62
|
+
)
|
|
63
|
+
header_text = self.vis_texts.markdown_header.format(header_main_text, me, current_date)
|
|
64
|
+
header = MarkdownWidget("markdown_header", "Header", text=header_text)
|
|
65
|
+
return header
|
|
66
|
+
|
|
67
|
+
def _create_overviews(self, vm: Overview, grid_cols: Optional[int] = None) -> ContainerWidget:
|
|
68
|
+
"""Creates overview widgets"""
|
|
69
|
+
overview_widgets = vm.overview_widgets
|
|
70
|
+
if grid_cols is None:
|
|
71
|
+
grid_cols = 2
|
|
72
|
+
if len(overview_widgets) > 2:
|
|
73
|
+
grid_cols = 3
|
|
74
|
+
if len(overview_widgets) % 4 == 0:
|
|
75
|
+
grid_cols = 4
|
|
76
|
+
return ContainerWidget(
|
|
77
|
+
overview_widgets,
|
|
78
|
+
name="overview_container",
|
|
79
|
+
title="Overview",
|
|
80
|
+
grid=True,
|
|
81
|
+
grid_cols=grid_cols,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
def _create_explore_modal_table(
|
|
85
|
+
self, columns_number=3, click_gallery_id=None, hover_text=None
|
|
86
|
+
) -> GalleryWidget:
|
|
87
|
+
gallery = GalleryWidget(
|
|
88
|
+
"all_predictions_modal_gallery",
|
|
89
|
+
is_modal=True,
|
|
90
|
+
columns_number=columns_number,
|
|
91
|
+
click_gallery_id=click_gallery_id,
|
|
92
|
+
opacity=self.ann_opacity,
|
|
93
|
+
)
|
|
94
|
+
gallery.set_project_meta(self.eval_results[0].pred_project_meta)
|
|
95
|
+
if hover_text:
|
|
96
|
+
gallery.add_image_left_header(hover_text)
|
|
97
|
+
return gallery
|
|
98
|
+
|
|
99
|
+
def _create_diff_modal_table(self, columns_number=3) -> GalleryWidget:
|
|
100
|
+
gallery = GalleryWidget(
|
|
101
|
+
"diff_predictions_modal_gallery",
|
|
102
|
+
is_modal=True,
|
|
103
|
+
columns_number=columns_number,
|
|
104
|
+
opacity=self.ann_opacity,
|
|
105
|
+
)
|
|
106
|
+
gallery.set_project_meta(self.eval_results[0].pred_project_meta)
|
|
107
|
+
return gallery
|
|
108
|
+
|
|
109
|
+
def _create_clickable_label(self):
|
|
110
|
+
return MarkdownWidget("clickable_label", "", text=self.vis_texts.clickable_label)
|
|
111
|
+
|
|
112
|
+
def _get_eval_project_infos(self, eval_result):
|
|
113
|
+
# get project infos
|
|
114
|
+
if self.gt_project_info is None:
|
|
115
|
+
self.gt_project_info = self.api.project.get_info_by_id(eval_result.gt_project_id)
|
|
116
|
+
eval_result.gt_project_info = self.gt_project_info
|
|
117
|
+
eval_result.pred_project_info = self.api.project.get_info_by_id(eval_result.pred_project_id)
|
|
118
|
+
|
|
119
|
+
# get project metas
|
|
120
|
+
if self.gt_project_meta is None:
|
|
121
|
+
self.gt_project_meta = ProjectMeta.from_json(
|
|
122
|
+
self.api.project.get_meta(eval_result.gt_project_id)
|
|
123
|
+
)
|
|
124
|
+
eval_result.gt_project_meta = self.gt_project_meta
|
|
125
|
+
eval_result.pred_project_meta = ProjectMeta.from_json(
|
|
126
|
+
self.api.project.get_meta(eval_result.pred_project_id)
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
# get dataset infos
|
|
130
|
+
filters = None
|
|
131
|
+
if eval_result.gt_dataset_ids is not None:
|
|
132
|
+
filters = [
|
|
133
|
+
{
|
|
134
|
+
ApiField.FIELD: ApiField.ID,
|
|
135
|
+
ApiField.OPERATOR: "in",
|
|
136
|
+
ApiField.VALUE: eval_result.gt_dataset_ids,
|
|
137
|
+
}
|
|
138
|
+
]
|
|
139
|
+
eval_result.gt_dataset_infos = self.api.dataset.get_list(
|
|
140
|
+
eval_result.gt_project_id,
|
|
141
|
+
filters=filters,
|
|
142
|
+
recursive=True,
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
# eval_result.pred_dataset_infos = self.api.dataset.get_list(
|
|
146
|
+
# eval_result.pred_project_id, recursive=True
|
|
147
|
+
# )
|
|
@@ -4,7 +4,7 @@ from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.avg
|
|
|
4
4
|
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.calibration_score import (
|
|
5
5
|
CalibrationScore,
|
|
6
6
|
)
|
|
7
|
-
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.
|
|
7
|
+
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.explore_predictions import (
|
|
8
8
|
ExplorePredictions,
|
|
9
9
|
)
|
|
10
10
|
from supervisely.nn.benchmark.comparison.detection_visualization.vis_metrics.localization_accuracy import (
|
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
from supervisely.nn.benchmark.
|
|
2
|
-
BaseVisMetric,
|
|
3
|
-
)
|
|
1
|
+
from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
|
|
4
2
|
from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
|
|
5
3
|
|
|
6
4
|
|
|
7
|
-
class AveragePrecisionByClass(
|
|
5
|
+
class AveragePrecisionByClass(BaseVisMetrics):
|
|
8
6
|
MARKDOWN_CLASS_AP = "markdown_class_ap_polar"
|
|
9
7
|
MARKDOWN_CLASS_AP_BAR = "markdown_class_ap_bar"
|
|
10
8
|
|
|
@@ -43,8 +41,8 @@ class AveragePrecisionByClass(BaseVisMetric):
|
|
|
43
41
|
x=eval_result.mp.cat_names,
|
|
44
42
|
y=ap_per_class,
|
|
45
43
|
name=trace_name,
|
|
46
|
-
width=0.2,
|
|
47
|
-
marker=dict(color=eval_result.color),
|
|
44
|
+
width=0.2 if cls_cnt >= 5 else None,
|
|
45
|
+
marker=dict(color=eval_result.color, line=dict(width=0.7)),
|
|
48
46
|
)
|
|
49
47
|
)
|
|
50
48
|
|
|
@@ -116,7 +114,7 @@ class AveragePrecisionByClass(BaseVisMetric):
|
|
|
116
114
|
{
|
|
117
115
|
"type": "tag",
|
|
118
116
|
"tagId": "confidence",
|
|
119
|
-
"value": [eval_result.f1_optimal_conf, 1],
|
|
117
|
+
"value": [eval_result.mp.f1_optimal_conf, 1],
|
|
120
118
|
},
|
|
121
119
|
{"type": "tag", "tagId": "outcome", "value": "TP"},
|
|
122
120
|
{"type": "specific_objects", "tagId": None, "value": list(obj_ids)},
|
supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
from supervisely.nn.benchmark.
|
|
2
|
-
BaseVisMetric,
|
|
3
|
-
)
|
|
1
|
+
from supervisely.nn.benchmark.base_visualizer import BaseVisMetrics
|
|
4
2
|
from supervisely.nn.benchmark.visualization.widgets import (
|
|
5
3
|
ChartWidget,
|
|
6
4
|
CollapseWidget,
|
|
@@ -10,7 +8,7 @@ from supervisely.nn.benchmark.visualization.widgets import (
|
|
|
10
8
|
)
|
|
11
9
|
|
|
12
10
|
|
|
13
|
-
class CalibrationScore(
|
|
11
|
+
class CalibrationScore(BaseVisMetrics):
|
|
14
12
|
@property
|
|
15
13
|
def header_md(self) -> MarkdownWidget:
|
|
16
14
|
text_template = self.vis_texts.markdown_calibration_score_1
|
|
@@ -139,7 +137,7 @@ class CalibrationScore(BaseVisMetric):
|
|
|
139
137
|
x=eval_result.dfsp_down["scores"],
|
|
140
138
|
y=eval_result.dfsp_down["f1"],
|
|
141
139
|
mode="lines",
|
|
142
|
-
name=f"[{i+1}] {eval_result.
|
|
140
|
+
name=f"[{i+1}] {eval_result.name}",
|
|
143
141
|
line=dict(color=eval_result.color),
|
|
144
142
|
hovertemplate="Confidence Score: %{x:.2f}<br>Value: %{y:.2f}<extra></extra>",
|
|
145
143
|
)
|
|
@@ -194,7 +192,7 @@ class CalibrationScore(BaseVisMetric):
|
|
|
194
192
|
x=pred_probs,
|
|
195
193
|
y=true_probs,
|
|
196
194
|
mode="lines+markers",
|
|
197
|
-
name=f"[{i+1}] {eval_result.
|
|
195
|
+
name=f"[{i+1}] {eval_result.name}",
|
|
198
196
|
line=dict(color=eval_result.color),
|
|
199
197
|
hovertemplate=f"{eval_result.name}<br>"
|
|
200
198
|
+ "Confidence Score: %{x:.2f}<br>Fraction of True Positives: %{y:.2f}<extra></extra>",
|