supervisely 6.73.254-py3-none-any.whl → 6.73.256-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- supervisely/api/api.py +16 -8
- supervisely/api/file_api.py +16 -5
- supervisely/api/task_api.py +4 -2
- supervisely/app/widgets/field/field.py +10 -7
- supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
- supervisely/io/network_exceptions.py +14 -2
- supervisely/nn/benchmark/base_benchmark.py +33 -35
- supervisely/nn/benchmark/base_evaluator.py +27 -1
- supervisely/nn/benchmark/base_visualizer.py +8 -11
- supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
- supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
- supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
- supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
- supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
- supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
- supervisely/nn/benchmark/visualization/renderer.py +25 -10
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
- supervisely/nn/inference/inference.py +1 -0
- supervisely/nn/training/gui/gui.py +32 -10
- supervisely/nn/training/gui/training_artifacts.py +145 -0
- supervisely/nn/training/gui/training_process.py +3 -19
- supervisely/nn/training/train_app.py +179 -70
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/METADATA +1 -1
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/RECORD +60 -48
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/LICENSE +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/WHEEL +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.254.dist-info → supervisely-6.73.256.dist-info}/top_level.txt +0 -0
supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py:

```diff
@@ -1,7 +1,11 @@
 from supervisely.nn.benchmark.semantic_segmentation.base_vis_metric import (
     SemanticSegmVisMetric,
 )
-from supervisely.nn.benchmark.visualization.widgets import ChartWidget, MarkdownWidget
+from supervisely.nn.benchmark.visualization.widgets import (
+    ChartWidget,
+    MarkdownWidget,
+    TableWidget,
+)
 
 
 class KeyMetrics(SemanticSegmVisMetric):
@@ -14,6 +18,32 @@ class KeyMetrics(SemanticSegmVisMetric):
             text=self.vis_texts.markdown_key_metrics,
         )
 
+    @property
+    def table(self) -> TableWidget:
+        columns = ["metrics", "values"]
+        content = []
+
+        metrics = self.eval_result.mp.key_metrics().copy()
+        metrics["mPixel accuracy"] = round(metrics["mPixel accuracy"] * 100, 2)
+
+        for metric, value in metrics.items():
+            row = [metric, round(value, 2)]
+            dct = {"row": row, "id": metric, "items": row}
+            content.append(dct)
+
+        columns_options = [{"disableSort": True}, {"disableSort": True}]
+        data = {"columns": columns, "columnsOptions": columns_options, "content": content}
+
+        table = TableWidget(
+            name="table_key_metrics",
+            data=data,
+            fix_columns=1,
+            width="60%",
+            show_header_controls=False,
+            main_column=columns[0],
+        )
+        return table
+
     @property
     def chart(self) -> ChartWidget:
         return ChartWidget("base_metrics_chart", self.get_figure())
```
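The new `table` property feeds `TableWidget` a plain dict rather than a dataframe. A minimal sketch of the payload it builds, with made-up metric values (note that only `mPixel accuracy` is rescaled to a percentage before the generic rounding pass):

```python
# Hypothetical metric values, for illustration only.
metrics = {"mIoU": 0.8123, "mPixel accuracy": 0.9456}
metrics["mPixel accuracy"] = round(metrics["mPixel accuracy"] * 100, 2)  # -> 94.56

content = []
for metric, value in metrics.items():
    row = [metric, round(value, 2)]
    # "id" doubles as the row key; "row" and "items" carry the same cells
    content.append({"row": row, "id": metric, "items": row})

data = {
    "columns": ["metrics", "values"],
    "columnsOptions": [{"disableSort": True}, {"disableSort": True}],
    "content": content,
}
# data["content"][0] == {"row": ["mIoU", 0.81], "id": "mIoU", "items": ["mIoU", 0.81]}
```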
supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py:

```diff
@@ -26,9 +26,7 @@ class Overview(SemanticSegmVisMetric):
         link_text = link_text.replace("_", "\_")
 
         model_name = self.eval_result.inference_info.get("model_name") or "Custom"
-        checkpoint_name = self.eval_result.inference_info.get(
-            "checkpoint_name", ""
-        )
+        checkpoint_name = self.eval_result.checkpoint_name
 
         # Note about validation dataset
         classes_str, note_about_images, starter_app_info = self._get_overview_info()
```
supervisely/nn/benchmark/semantic_segmentation/visualizer.py:

```diff
@@ -93,6 +93,7 @@ class SemanticSegmentationVisualizer(BaseVisualizer):
         # key metrics
         key_metrics = KeyMetrics(self.vis_texts, self.eval_result)
         self.key_metrics_md = key_metrics.md
+        self.key_metrics_table = key_metrics.table
         self.key_metrics_chart = key_metrics.chart
 
         # explore predictions
@@ -143,15 +144,14 @@ class SemanticSegmentationVisualizer(BaseVisualizer):
         self.acknowledgement_md = acknowledgement.md
 
         # SpeedTest
-        self.speedtest_present = False
-        self.speedtest_multiple_batch_sizes = False
         speedtest = Speedtest(self.vis_texts, self.eval_result)
-        if not speedtest.is_empty():
-            self.speedtest_present = True
+        self.speedtest_present = not speedtest.is_empty()
+        self.speedtest_multiple_batch_sizes = False
+        if self.speedtest_present:
             self.speedtest_md_intro = speedtest.intro_md
             self.speedtest_intro_table = speedtest.intro_table
-            if speedtest.multiple_batche_sizes():
-                self.speedtest_multiple_batch_sizes = True
+            self.speedtest_multiple_batch_sizes = speedtest.multiple_batche_sizes()
+            if self.speedtest_multiple_batch_sizes:
                 self.speedtest_batch_inference_md = speedtest.batch_size_md
                 self.speedtest_chart = speedtest.chart
 
@@ -166,6 +166,7 @@ class SemanticSegmentationVisualizer(BaseVisualizer):
             (0, self.header),
             (1, self.overview_md),
             (1, self.key_metrics_md),
+            (0, self.key_metrics_table),
             (0, self.key_metrics_chart),
             (1, self.explore_predictions_md),
             (0, self.explore_predictions_gallery),
```
supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py (the cupy GPU path is dropped; evaluation now always runs on plain numpy):

```diff
@@ -63,29 +63,11 @@ class Evaluator:
         :param boundary_implementation: Choose "exact" for the euclidean pixel distance.
             The Boundary IoU paper uses the L1 distance ("fast").
         """
-        global torch, np, GPU
+        global torch, np, GPU, numpy
         import torch  # pylint: disable=import-error
 
-        if torch.cuda.is_available():
-            GPU = True
-            logger.info("Using GPU for evaluation.")
-            try:
-                # gpu-compatible numpy analogue
-                import cupy as np  # pylint: disable=import-error
-
-                global numpy
-                import numpy as numpy
-            except:
-                logger.warning(
-                    "Failed to import cupy. Use cupy official documentation to install this "
-                    "module: https://docs.cupy.dev/en/stable/install.html"
-                )
-        else:
-            GPU = False
-            import numpy as np
-
-            global numpy
-            numpy = np
+        numpy = np
+        GPU = False
 
         self.progress = progress or tqdm_sly
         self.class_names = class_names
```
supervisely/nn/benchmark/visualization/renderer.py (some removed lines are truncated in the source diff view and are kept here as shown):

```diff
@@ -20,6 +20,7 @@ class Renderer:
         layout: BaseWidget,
         base_dir: str = "./output",
         template: str = None,
+        report_name: str = "Model Evaluation Report.lnk",
     ) -> None:
         if template is None:
             template = (
@@ -28,6 +29,9 @@ class Renderer:
         self.main_template = template
         self.layout = layout
         self.base_dir = base_dir
+        self.report_name = report_name
+        self._report = None
+        self._lnk = None
 
         if Path(base_dir).exists():
             if not dir_empty(base_dir):
@@ -81,20 +85,31 @@ class Renderer:
             change_name_if_conflict=True,
             progress_size_cb=pbar,
         )
-        src = self.
-
+        src = self._save_report_link(api, team_id, remote_dir)
+        dst = Path(remote_dir).joinpath(self.report_name)
+        self._lnk = api.file.upload(team_id=team_id, src=src, dst=str(dst))
         return remote_dir
 
-    def
-        report_link = self.
-        pth = Path(self.base_dir).joinpath(
+    def _save_report_link(self, api: Api, team_id: int, remote_dir: str):
+        report_link = self._get_report_path(api, team_id, remote_dir)
+        pth = Path(self.base_dir).joinpath(self.report_name)
         with open(pth, "w") as f:
             f.write(report_link)
         return str(pth)
 
-    def
-
-
+    def _get_report_link(self, api: Api, team_id: int, remote_dir: str):
+        path = self._get_report_path(api, team_id, remote_dir)
+        return f"{api.server_address}{path}"
 
-
-
+    def _get_report_path(self, api: Api, team_id: int, remote_dir: str):
+        template_path = Path(remote_dir).joinpath("template.vue")
+        self._report = api.file.get_info_by_path(team_id, str(template_path))
+        return "/model-benchmark?id=" + str(self._report.id)
+
+    @property
+    def report(self):
+        return self._report
+
+    @property
+    def lnk(self):
+        return self._lnk
```
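The new `report` and `lnk` properties cache the `FileInfo` objects produced during upload, so callers can link to the benchmark page without re-querying Team Files. A minimal sketch, assuming a prebuilt `layout` widget tree and that the surrounding upload method (its name is not shown in this hunk) has already run:

```python
# A minimal sketch; `layout`, `api`, and `team_id` come from the caller,
# and the upload step's method name is an assumption, not shown in this diff.
renderer = Renderer(layout, base_dir="./output", report_name="Model Comparison Report.lnk")
# ... render and upload to Team Files here ...

template_info = renderer.report  # FileInfo of the uploaded template.vue
shortcut_info = renderer.lnk     # FileInfo of the uploaded "*.lnk" shortcut
report_url = f"{api.server_address}/model-benchmark?id={template_info.id}"
```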
supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py:

```diff
@@ -79,6 +79,7 @@ class GalleryWidget(BaseWidget):
                 column_index=idx % self.columns_number,
                 project_meta=project_metas[idx % self.columns_number],
                 ignore_tags_filtering=skip_tags_filtering[idx % self.columns_number],
+                call_update=idx == len(image_infos) - 1,
             )
 
     def _get_init_data(self):
```
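`call_update` defers the widget's state sync: every image but the last is appended silently, and only the final call (when `idx == len(image_infos) - 1`) pushes an update to the UI, so a gallery of N images triggers one update instead of N. The same pattern in isolation (only the keyword argument comes from this diff; the `add_image` method name is hypothetical):

```python
# Illustrative only: sync widget state once, after the last append.
for idx, image_info in enumerate(image_infos):
    gallery.add_image(  # hypothetical method name
        image_info,
        call_update=idx == len(image_infos) - 1,  # True only for the final item
    )
```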
supervisely/nn/inference/inference.py:

```diff
@@ -133,6 +133,7 @@ class Inference:
         if self.INFERENCE_SETTINGS is not None:
             custom_inference_settings = self.INFERENCE_SETTINGS
         else:
+            logger.debug("Custom inference settings are not provided.")
             custom_inference_settings = {}
         if isinstance(custom_inference_settings, str):
             if fs.file_exists(custom_inference_settings):
```
supervisely/nn/training/gui/gui.py (one removed line is truncated in the source diff view and is kept here as shown):

```diff
@@ -14,6 +14,7 @@ from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
 from supervisely.nn.training.gui.input_selector import InputSelector
 from supervisely.nn.training.gui.model_selector import ModelSelector
 from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplitsSelector
+from supervisely.nn.training.gui.training_artifacts import TrainingArtifacts
 from supervisely.nn.training.gui.training_logs import TrainingLogs
 from supervisely.nn.training.gui.training_process import TrainingProcess
 from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
@@ -50,7 +51,9 @@ class TrainGUI:
         if is_production():
             self.task_id = sly_env.task_id()
         else:
-            self.task_id =
+            self.task_id = sly_env.task_id(raise_not_found=False)
+            if self.task_id is None:
+                self.task_id = "debug-session"
 
         self.framework_name = framework_name
         self.models = models
@@ -86,17 +89,22 @@ class TrainGUI:
         # 7. Training logs
         self.training_logs = TrainingLogs(self.app_options)
 
+        # 8. Training Artifacts
+        self.training_artifacts = TrainingArtifacts(self.app_options)
+
         # Stepper layout
+        self.steps = [
+            self.input_selector.card,
+            self.train_val_splits_selector.card,
+            self.classes_selector.card,
+            self.model_selector.card,
+            self.hyperparameters_selector.card,
+            self.training_process.card,
+            self.training_logs.card,
+            self.training_artifacts.card,
+        ]
         self.stepper = Stepper(
-            widgets=[
-                self.input_selector.card,
-                self.train_val_splits_selector.card,
-                self.classes_selector.card,
-                self.model_selector.card,
-                self.hyperparameters_selector.card,
-                self.training_process.card,
-                self.training_logs.card,
-            ],
+            widgets=self.steps,
         )
         # ------------------------------------------------- #
 
@@ -265,6 +273,20 @@ class TrainGUI:
 
         self.layout: Widget = self.stepper
 
+    def set_next_step(self):
+        current_step = self.stepper.get_active_step()
+        self.stepper.set_active_step(current_step + 1)
+
+    def set_previous_step(self):
+        current_step = self.stepper.get_active_step()
+        self.stepper.set_active_step(current_step - 1)
+
+    def set_first_step(self):
+        self.stepper.set_active_step(1)
+
+    def set_last_step(self):
+        self.stepper.set_active_step(len(self.steps))
+
     def enable_select_buttons(self):
         """
         Makes all select buttons in the GUI available for interaction.
```
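The stepper can now be driven programmatically instead of only via the select buttons. A minimal sketch of the navigation helpers on a constructed `TrainGUI` instance (the `gui` variable is hypothetical; steps are 1-indexed, as `set_first_step` shows):

```python
# `gui` is a constructed TrainGUI instance.
gui.set_first_step()     # back to the first card
gui.set_next_step()      # advance one card in the wizard
gui.set_previous_step()  # go back one card
gui.set_last_step()      # jump to the final card, len(gui.steps) == 8 here
```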
supervisely/nn/training/gui/training_artifacts.py (new file):

```diff
@@ -0,0 +1,145 @@
+from typing import Any, Dict
+
+from supervisely import Api
+from supervisely.app.widgets import (
+    Card,
+    Container,
+    Empty,
+    Field,
+    Flexbox,
+    FolderThumbnail,
+    ReportThumbnail,
+    Text,
+)
+
+PYTORCH_ICON = "https://img.icons8.com/?size=100&id=jH4BpkMnRrU5&format=png&color=000000"
+ONNX_ICON = "https://artwork.lfaidata.foundation/projects/onnx/icon/color/onnx-icon-color.png"
+TRT_ICON = "https://img.icons8.com/?size=100&id=yqf95864UzeQ&format=png&color=000000"
+
+
+class TrainingArtifacts:
+    title = "Training Artifacts"
+    description = "All outputs of the training process will appear here"
+    lock_message = "Artifacts will be available after training is completed"
+
+    def __init__(self, app_options: Dict[str, Any]):
+        self.display_widgets = []
+        self.success_message_text = (
+            "Training completed. Training artifacts were uploaded to Team Files. "
+            "You can find and open tensorboard logs in the artifacts folder via the "
+            "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
+        )
+        self.app_options = app_options
+
+        # GUI Components
+        self.validator_text = Text("")
+        self.validator_text.hide()
+        self.display_widgets.extend([self.validator_text])
+
+        # Outputs
+        self.artifacts_thumbnail = FolderThumbnail()
+        self.artifacts_thumbnail.hide()
+
+        self.artifacts_field = Field(
+            title="Artifacts",
+            description="Contains all outputs of the training process",
+            content=self.artifacts_thumbnail,
+        )
+        self.artifacts_field.hide()
+        self.display_widgets.extend([self.artifacts_field])
+
+        # Optional Model Benchmark
+        if app_options.get("model_benchmark", False):
+            self.model_benchmark_report_thumbnail = ReportThumbnail()
+            self.model_benchmark_report_thumbnail.hide()
+
+            self.model_benchmark_fail_text = Text(
+                text="Model evaluation did not finish successfully. Please check the app logs for details.",
+                status="error",
+            )
+            self.model_benchmark_fail_text.hide()
+
+            self.model_benchmark_widgets = Container(
+                [self.model_benchmark_report_thumbnail, self.model_benchmark_fail_text]
+            )
+
+            self.model_benchmark_report_field = Field(
+                title="Model Benchmark",
+                description="Evaluation report of the trained model",
+                content=self.model_benchmark_widgets,
+            )
+            self.model_benchmark_report_field.hide()
+            self.display_widgets.extend([self.model_benchmark_report_field])
+        # -------------------------------- #
+
+        # PyTorch, ONNX, TensorRT demo
+        self.inference_demo_field = []
+        model_demo = self.app_options.get("demo", None)
+        if model_demo is not None:
+            pytorch_demo_link = model_demo.get("pytorch", None)
+            if pytorch_demo_link is not None:
+                pytorch_icon = Field.Icon(image_url=PYTORCH_ICON, bg_color_rgb=[255, 255, 255])
+                self.pytorch_instruction = Field(
+                    title="PyTorch",
+                    description="Open file",
+                    description_url=pytorch_demo_link,
+                    icon=pytorch_icon,
+                    content=Empty(),
+                )
+                self.pytorch_instruction.hide()
+                self.inference_demo_field.extend([self.pytorch_instruction])
+
+            onnx_demo_link = model_demo.get("onnx", None)
+            if onnx_demo_link is not None:
+                if self.app_options.get("export_onnx_supported", False):
+                    onnx_icon = Field.Icon(image_url=ONNX_ICON, bg_color_rgb=[255, 255, 255])
+                    self.onnx_instruction = Field(
+                        title="ONNX",
+                        description="Open file",
+                        description_url=onnx_demo_link,
+                        icon=onnx_icon,
+                        content=Empty(),
+                    )
+                    self.onnx_instruction.hide()
+                    self.inference_demo_field.extend([self.onnx_instruction])
+
+            trt_demo_link = model_demo.get("tensorrt", None)
+            if trt_demo_link is not None:
+                if self.app_options.get("export_tensorrt_supported", False):
+                    trt_icon = Field.Icon(image_url=TRT_ICON, bg_color_rgb=[255, 255, 255])
+                    self.trt_instruction = Field(
+                        title="TensorRT",
+                        description="Open file",
+                        description_url=trt_demo_link,
+                        icon=trt_icon,
+                        content=Empty(),
+                    )
+                    self.trt_instruction.hide()
+                    self.inference_demo_field.extend([self.trt_instruction])
+
+            demo_overview_link = model_demo.get("overview", None)
+            self.inference_demo_field = Field(
+                title="How to run inference",
+                description="Instructions on how to use your checkpoints outside of Supervisely Platform",
+                content=Flexbox(self.inference_demo_field),
+                title_url=demo_overview_link,
+            )
+            self.inference_demo_field.hide()
+            self.display_widgets.extend([self.inference_demo_field])
+        # -------------------------------- #
+
+        self.container = Container(self.display_widgets)
+        self.card = Card(
+            title=self.title,
+            description=self.description,
+            content=self.container,
+            lock_message=self.lock_message,
+        )
+        self.card.lock()
+
+    @property
+    def widgets_to_disable(self) -> list:
+        return []
+
+    def validate_step(self) -> bool:
+        return True
```
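Everything in `TrainingArtifacts` is gated by `app_options`. A sketch of the keys this file reads, with placeholder values and URLs:

```python
# Hypothetical app_options; only the keys read by TrainingArtifacts are shown.
app_options = {
    "model_benchmark": True,             # adds the evaluation report field
    "export_onnx_supported": True,       # required for the ONNX demo card
    "export_tensorrt_supported": False,  # TensorRT demo card is not created
    "demo": {                            # omit "demo" to skip the whole section
        "overview": "https://example.com/demo/README.md",       # placeholder URL
        "pytorch": "https://example.com/demo/demo_pytorch.py",  # placeholder URL
        "onnx": "https://example.com/demo/demo_onnx.py",        # placeholder URL
    },
}

artifacts = TrainingArtifacts(app_options)
# artifacts.card is created locked; the training app unlocks it after the
# artifacts are uploaded and the thumbnails are populated.
```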
supervisely/nn/training/gui/training_process.py (the artifacts and benchmark widgets move to the new Training Artifacts step; some removed lines are truncated in the source diff view and are kept here as shown):

```diff
@@ -23,23 +23,18 @@ class TrainingProcess:
 
     def __init__(self, app_options: Dict[str, Any]):
         self.display_widgets = []
-        self.success_message_text = (
-            "Training completed. Training artifacts were uploaded to Team Files. "
-            "You can find and open tensorboard logs in the artifacts folder via the "
-            "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
-        )
         self.app_options = app_options
 
         # GUI Components
         # Optional Select CUDA device
         if self.app_options.get("device_selector", False):
             self.select_device = SelectCudaDevice()
-            self.
+            self.select_device_field = Field(
                 title="Select CUDA device",
                 description="The device on which the model will be trained",
                 content=self.select_device,
             )
-            self.display_widgets.extend([self.
+            self.display_widgets.extend([self.select_device_field])
         # -------------------------------- #
 
         self.experiment_name_input = Input("Enter experiment name")
@@ -63,26 +58,15 @@ class TrainingProcess:
         self.validator_text = Text("")
         self.validator_text.hide()
 
-        self.artifacts_thumbnail = FolderThumbnail()
-        self.artifacts_thumbnail.hide()
-
         self.display_widgets.extend(
             [
                 self.experiment_name_field,
                 button_container,
                 self.validator_text,
-                self.artifacts_thumbnail,
             ]
         )
         # -------------------------------- #
 
-        # Optional Model Benchmark
-        if app_options.get("model_benchmark", False):
-            self.model_benchmark_report_thumbnail = ReportThumbnail()
-            self.model_benchmark_report_thumbnail.hide()
-            self.display_widgets.extend([self.model_benchmark_report_thumbnail])
-        # -------------------------------- #
-
         self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,
@@ -96,7 +80,7 @@ class TrainingProcess:
     def widgets_to_disable(self) -> list:
         widgets = [self.experiment_name_input]
         if self.app_options.get("device_selector", False):
-            widgets.
+            widgets.extend([self.select_device, self.select_device_field])
         return widgets
 
     def validate_step(self) -> bool:
```