supervisely 6.73.325__py3-none-any.whl → 6.73.327__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/annotation/annotation.py +1 -1
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +17 -14
- supervisely/app/widgets/pretrained_models_selector/template.html +2 -1
- supervisely/convert/image/yolo/yolo_helper.py +95 -25
- supervisely/convert/volume/nii/nii_planes_volume_converter.py +54 -6
- supervisely/convert/volume/nii/nii_volume_converter.py +7 -7
- supervisely/convert/volume/nii/nii_volume_helper.py +49 -0
- supervisely/nn/inference/gui/serving_gui_template.py +2 -3
- supervisely/nn/inference/inference.py +33 -25
- supervisely/nn/training/gui/classes_selector.py +24 -19
- supervisely/nn/training/gui/gui.py +90 -37
- supervisely/nn/training/gui/hyperparameters_selector.py +32 -15
- supervisely/nn/training/gui/input_selector.py +13 -2
- supervisely/nn/training/gui/model_selector.py +16 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +10 -1
- supervisely/nn/training/gui/training_artifacts.py +23 -4
- supervisely/nn/training/gui/training_logs.py +15 -3
- supervisely/nn/training/gui/training_process.py +14 -13
- supervisely/nn/training/train_app.py +59 -24
- supervisely/nn/utils.py +9 -0
- supervisely/project/project.py +16 -3
- supervisely/volume/volume.py +19 -21
- {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/METADATA +1 -1
- {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/RECORD +28 -28
- {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/LICENSE +0 -0
- {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/WHEEL +0 -0
- {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/top_level.txt +0 -0
|
@@ -75,6 +75,7 @@ from supervisely.nn.utils import (
|
|
|
75
75
|
ModelPrecision,
|
|
76
76
|
ModelSource,
|
|
77
77
|
RuntimeType,
|
|
78
|
+
_get_model_name,
|
|
78
79
|
)
|
|
79
80
|
from supervisely.project import ProjectType
|
|
80
81
|
from supervisely.project.download import download_to_cache, read_from_cached_project
|
|
@@ -173,9 +174,7 @@ class Inference:
|
|
|
173
174
|
self._use_gui = False
|
|
174
175
|
deploy_params, need_download = self._get_deploy_params_from_args()
|
|
175
176
|
if need_download:
|
|
176
|
-
local_model_files = self._download_model_files(
|
|
177
|
-
deploy_params["model_source"], deploy_params["model_files"], False
|
|
178
|
-
)
|
|
177
|
+
local_model_files = self._download_model_files(deploy_params, False)
|
|
179
178
|
deploy_params["model_files"] = local_model_files
|
|
180
179
|
self._load_model_headless(**deploy_params)
|
|
181
180
|
|
|
@@ -210,14 +209,12 @@ class Inference:
|
|
|
210
209
|
self.initialize_gui()
|
|
211
210
|
|
|
212
211
|
def on_serve_callback(
|
|
213
|
-
gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
|
|
212
|
+
gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
|
|
214
213
|
):
|
|
215
214
|
Progress("Deploying model ...", 1)
|
|
216
215
|
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
217
216
|
deploy_params = self.get_params_from_gui()
|
|
218
|
-
model_files = self._download_model_files(
|
|
219
|
-
deploy_params["model_source"], deploy_params["model_files"]
|
|
220
|
-
)
|
|
217
|
+
model_files = self._download_model_files(deploy_params)
|
|
221
218
|
deploy_params["model_files"] = model_files
|
|
222
219
|
self._load_model_headless(**deploy_params)
|
|
223
220
|
elif isinstance(self.gui, GUI.ServingGUI):
|
|
@@ -230,7 +227,7 @@ class Inference:
|
|
|
230
227
|
gui.show_deployed_model_info(self)
|
|
231
228
|
|
|
232
229
|
def on_change_model_callback(
|
|
233
|
-
gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
|
|
230
|
+
gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
|
|
234
231
|
):
|
|
235
232
|
self.shutdown_model()
|
|
236
233
|
if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
|
|
@@ -567,13 +564,23 @@ class Inference:
|
|
|
567
564
|
def _checkpoints_cache_dir(self):
|
|
568
565
|
return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
|
|
569
566
|
|
|
570
|
-
def _download_model_files(
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
567
|
+
def _download_model_files(self, deploy_params: dict, log_progress: bool = True) -> dict:
|
|
568
|
+
if deploy_params["runtime"] != RuntimeType.PYTORCH:
|
|
569
|
+
export = deploy_params["model_info"].get("export", {})
|
|
570
|
+
export_model = export.get(deploy_params["runtime"], None)
|
|
571
|
+
if export_model is not None:
|
|
572
|
+
if sly_fs.get_file_name(export_model) == sly_fs.get_file_name(
|
|
573
|
+
deploy_params["model_files"]["checkpoint"]
|
|
574
|
+
):
|
|
575
|
+
deploy_params["model_files"]["checkpoint"] = (
|
|
576
|
+
deploy_params["model_info"]["artifacts_dir"] + export_model
|
|
577
|
+
)
|
|
578
|
+
logger.info(f"Found model checkpoint for '{deploy_params['runtime']}'")
|
|
579
|
+
|
|
580
|
+
if deploy_params["model_source"] == ModelSource.PRETRAINED:
|
|
581
|
+
return self._download_pretrained_model(deploy_params["model_files"], log_progress)
|
|
582
|
+
elif deploy_params["model_source"] == ModelSource.CUSTOM:
|
|
583
|
+
return self._download_custom_model(deploy_params["model_files"], log_progress)
|
|
577
584
|
|
|
578
585
|
def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
|
|
579
586
|
"""
|
|
@@ -2929,9 +2936,7 @@ class Inference:
|
|
|
2929
2936
|
state = request.state.state
|
|
2930
2937
|
deploy_params = state["deploy_params"]
|
|
2931
2938
|
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
2932
|
-
model_files = self._download_model_files(
|
|
2933
|
-
deploy_params["model_source"], deploy_params["model_files"]
|
|
2934
|
-
)
|
|
2939
|
+
model_files = self._download_model_files(deploy_params)
|
|
2935
2940
|
deploy_params["model_files"] = model_files
|
|
2936
2941
|
self._load_model_headless(**deploy_params)
|
|
2937
2942
|
elif isinstance(self.gui, GUI.ServingGUI):
|
|
@@ -3061,7 +3066,7 @@ class Inference:
|
|
|
3061
3066
|
raise ValueError("No pretrained models found.")
|
|
3062
3067
|
|
|
3063
3068
|
model = self.pretrained_models[0]
|
|
3064
|
-
model_name = model
|
|
3069
|
+
model_name = _get_model_name(model)
|
|
3065
3070
|
if model_name is None:
|
|
3066
3071
|
raise ValueError("No model name found in the first pretrained model.")
|
|
3067
3072
|
|
|
@@ -3126,7 +3131,7 @@ class Inference:
|
|
|
3126
3131
|
meta = m.get("meta", None)
|
|
3127
3132
|
if meta is None:
|
|
3128
3133
|
continue
|
|
3129
|
-
model_name =
|
|
3134
|
+
model_name = _get_model_name(m)
|
|
3130
3135
|
if model_name is None:
|
|
3131
3136
|
continue
|
|
3132
3137
|
m_files = meta.get("model_files", None)
|
|
@@ -3135,7 +3140,7 @@ class Inference:
|
|
|
3135
3140
|
checkpoint = m_files.get("checkpoint", None)
|
|
3136
3141
|
if checkpoint is None:
|
|
3137
3142
|
continue
|
|
3138
|
-
if model ==
|
|
3143
|
+
if model.lower() == model_name.lower():
|
|
3139
3144
|
model_info = m
|
|
3140
3145
|
model_source = ModelSource.PRETRAINED
|
|
3141
3146
|
model_files = {"checkpoint": checkpoint}
|
|
@@ -3153,8 +3158,6 @@ class Inference:
|
|
|
3153
3158
|
model_meta_path = os.path.join(artifacts_dir, "model_meta.json")
|
|
3154
3159
|
model_info["model_meta"] = self._load_json_file(model_meta_path)
|
|
3155
3160
|
original_model_files = model_info.get("model_files")
|
|
3156
|
-
if not original_model_files:
|
|
3157
|
-
raise ValueError("Invalid 'experiment_info.json'. Missing 'model_files' key.")
|
|
3158
3161
|
return model_info, original_model_files
|
|
3159
3162
|
|
|
3160
3163
|
def _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files):
|
|
@@ -3201,6 +3204,7 @@ class Inference:
|
|
|
3201
3204
|
model_files = _prepare_local_model_files(
|
|
3202
3205
|
artifacts_dir, checkpoint_path, original_model_files
|
|
3203
3206
|
)
|
|
3207
|
+
|
|
3204
3208
|
else:
|
|
3205
3209
|
local_artifacts_dir = os.path.join(
|
|
3206
3210
|
self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
|
|
@@ -3298,7 +3302,11 @@ class Inference:
|
|
|
3298
3302
|
if draw:
|
|
3299
3303
|
raise ValueError("Draw visualization is not supported for project inference")
|
|
3300
3304
|
|
|
3301
|
-
state = {
|
|
3305
|
+
state = {
|
|
3306
|
+
"projectId": project_id,
|
|
3307
|
+
"dataset_ids": dataset_ids,
|
|
3308
|
+
"settings": settings,
|
|
3309
|
+
}
|
|
3302
3310
|
if upload:
|
|
3303
3311
|
source_project = api.project.get_info_by_id(project_id)
|
|
3304
3312
|
workspace_id = source_project.workspace_id
|
|
@@ -3472,7 +3480,7 @@ class Inference:
|
|
|
3472
3480
|
def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
|
|
3473
3481
|
if model_source == ModelSource.PRETRAINED:
|
|
3474
3482
|
checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
|
|
3475
|
-
checkpoint_name = model_info
|
|
3483
|
+
checkpoint_name = _get_model_name(model_info)
|
|
3476
3484
|
else:
|
|
3477
3485
|
checkpoint_name = sly_fs.get_file_name_with_ext(model_files["checkpoint"])
|
|
3478
3486
|
checkpoint_url = os.path.join(
|
|
@@ -4,21 +4,28 @@ from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
|
|
|
4
4
|
|
|
5
5
|
class ClassesSelector:
|
|
6
6
|
title = "Classes Selector"
|
|
7
|
-
description =
|
|
8
|
-
"Select classes that will be used for training. "
|
|
9
|
-
"Supported shapes are Bitmap, Polygon, Rectangle."
|
|
10
|
-
)
|
|
7
|
+
description = "Select classes that will be used for training"
|
|
11
8
|
lock_message = "Select training and validation splits to unlock"
|
|
12
9
|
|
|
13
10
|
def __init__(self, project_id: int, classes: list, app_options: dict = {}):
|
|
11
|
+
# Init widgets
|
|
12
|
+
self.qa_stats_text = None
|
|
13
|
+
self.classes_table = None
|
|
14
|
+
self.validator_text = None
|
|
15
|
+
self.button = None
|
|
16
|
+
self.container = None
|
|
17
|
+
self.card = None
|
|
18
|
+
# -------------------------------- #
|
|
19
|
+
|
|
14
20
|
self.display_widgets = []
|
|
21
|
+
self.app_options = app_options
|
|
15
22
|
|
|
16
23
|
# GUI Components
|
|
17
24
|
if is_development() or is_debug_with_sly_net():
|
|
18
25
|
qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
|
|
19
26
|
else:
|
|
20
27
|
qa_stats_link = f"/projects/{project_id}/stats/datasets"
|
|
21
|
-
qa_stats_text = Text(
|
|
28
|
+
self.qa_stats_text = Text(
|
|
22
29
|
text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
|
|
23
30
|
)
|
|
24
31
|
|
|
@@ -32,7 +39,7 @@ class ClassesSelector:
|
|
|
32
39
|
self.validator_text.hide()
|
|
33
40
|
self.button = Button("Select")
|
|
34
41
|
self.display_widgets.extend(
|
|
35
|
-
[qa_stats_text, self.classes_table, self.validator_text, self.button]
|
|
42
|
+
[self.qa_stats_text, self.classes_table, self.validator_text, self.button]
|
|
36
43
|
)
|
|
37
44
|
# -------------------------------- #
|
|
38
45
|
|
|
@@ -42,7 +49,7 @@ class ClassesSelector:
|
|
|
42
49
|
description=self.description,
|
|
43
50
|
content=self.container,
|
|
44
51
|
lock_message=self.lock_message,
|
|
45
|
-
collapsable=app_options.get("collapsable", False),
|
|
52
|
+
collapsable=self.app_options.get("collapsable", False),
|
|
46
53
|
)
|
|
47
54
|
self.card.lock()
|
|
48
55
|
|
|
@@ -62,14 +69,14 @@ class ClassesSelector:
|
|
|
62
69
|
def validate_step(self) -> bool:
|
|
63
70
|
self.validator_text.hide()
|
|
64
71
|
|
|
65
|
-
|
|
72
|
+
project_classes = self.classes_table.project_meta.obj_classes
|
|
73
|
+
if len(project_classes) == 0:
|
|
66
74
|
self.validator_text.set(text="Project has no classes", status="error")
|
|
67
75
|
self.validator_text.show()
|
|
68
76
|
return False
|
|
69
77
|
|
|
70
78
|
selected_classes = self.classes_table.get_selected_classes()
|
|
71
79
|
table_data = self.classes_table._table_data
|
|
72
|
-
|
|
73
80
|
empty_classes = [
|
|
74
81
|
row[0]["data"]
|
|
75
82
|
for row in table_data
|
|
@@ -78,23 +85,21 @@ class ClassesSelector:
|
|
|
78
85
|
|
|
79
86
|
n_classes = len(selected_classes)
|
|
80
87
|
if n_classes == 0:
|
|
81
|
-
|
|
88
|
+
message = "Please select at least one class"
|
|
89
|
+
status = "error"
|
|
82
90
|
else:
|
|
83
|
-
|
|
91
|
+
class_text = "class" if n_classes == 1 else "classes"
|
|
92
|
+
message = f"Selected {n_classes} {class_text}"
|
|
84
93
|
status = "success"
|
|
85
94
|
if empty_classes:
|
|
86
95
|
intersections = set(selected_classes).intersection(empty_classes)
|
|
87
96
|
if intersections:
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
else f". Selected classes have no annotations: {', '.join(intersections)}"
|
|
97
|
+
class_text = "class" if len(intersections) == 1 else "classes"
|
|
98
|
+
message += (
|
|
99
|
+
f". Selected {class_text} have no annotations: {', '.join(intersections)}"
|
|
92
100
|
)
|
|
93
101
|
status = "warning"
|
|
94
102
|
|
|
95
|
-
|
|
96
|
-
self.validator_text.set(
|
|
97
|
-
text=f"Selected {n_classes} {class_text}{warning_text}", status=status
|
|
98
|
-
)
|
|
103
|
+
self.validator_text.set(text=message, status=status)
|
|
99
104
|
self.validator_text.show()
|
|
100
105
|
return n_classes > 0
|
|
@@ -6,11 +6,15 @@ training workflows in Supervisely.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from os import environ
|
|
9
|
+
from typing import Union
|
|
9
10
|
|
|
10
11
|
import supervisely.io.env as sly_env
|
|
12
|
+
import supervisely.io.json as sly_json
|
|
11
13
|
from supervisely import Api, ProjectMeta
|
|
12
14
|
from supervisely._utils import is_production
|
|
13
15
|
from supervisely.app.widgets import Stepper, Widget
|
|
16
|
+
from supervisely.geometry.bitmap import Bitmap
|
|
17
|
+
from supervisely.geometry.graph import GraphNodes
|
|
14
18
|
from supervisely.geometry.polygon import Polygon
|
|
15
19
|
from supervisely.geometry.rectangle import Rectangle
|
|
16
20
|
from supervisely.nn.task_type import TaskType
|
|
@@ -63,7 +67,7 @@ class TrainGUI:
|
|
|
63
67
|
self.models = models
|
|
64
68
|
self.hyperparameters = hyperparameters
|
|
65
69
|
self.app_options = app_options
|
|
66
|
-
self.collapsable = app_options.get("collapsable", False)
|
|
70
|
+
self.collapsable = self.app_options.get("collapsable", False)
|
|
67
71
|
self.need_convert_shapes_for_bm = False
|
|
68
72
|
|
|
69
73
|
self.team_id = sly_env.team_id(raise_not_found=False)
|
|
@@ -142,33 +146,73 @@ class TrainGUI:
|
|
|
142
146
|
self.training_process.set_experiment_name(experiment_name)
|
|
143
147
|
|
|
144
148
|
def need_convert_class_shapes() -> bool:
|
|
145
|
-
if
|
|
146
|
-
self.hyperparameters_selector.
|
|
147
|
-
self.need_convert_shapes_for_bm = False
|
|
148
|
-
else:
|
|
149
|
-
task_type = self.model_selector.get_selected_task_type()
|
|
150
|
-
|
|
151
|
-
def _need_convert(shape):
|
|
152
|
-
if task_type == TaskType.OBJECT_DETECTION:
|
|
153
|
-
return shape != Rectangle.geometry_name()
|
|
154
|
-
elif task_type in [
|
|
155
|
-
TaskType.INSTANCE_SEGMENTATION,
|
|
156
|
-
TaskType.SEMANTIC_SEGMENTATION,
|
|
157
|
-
]:
|
|
158
|
-
return shape == Polygon.geometry_name()
|
|
159
|
-
return
|
|
160
|
-
|
|
161
|
-
data = self.classes_selector.classes_table._table_data
|
|
162
|
-
selected_classes = set(self.classes_selector.classes_table.get_selected_classes())
|
|
163
|
-
empty = set(r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0)
|
|
164
|
-
need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
|
|
165
|
-
|
|
166
|
-
if need_convert.intersection(selected_classes - empty):
|
|
167
|
-
self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
|
|
168
|
-
self.need_convert_shapes_for_bm = True
|
|
169
|
-
else:
|
|
149
|
+
if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
|
|
150
|
+
if not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
|
|
170
151
|
self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
|
|
171
152
|
self.need_convert_shapes_for_bm = False
|
|
153
|
+
else:
|
|
154
|
+
task_type = self.model_selector.get_selected_task_type()
|
|
155
|
+
|
|
156
|
+
def _need_convert(shape):
|
|
157
|
+
if task_type == TaskType.OBJECT_DETECTION:
|
|
158
|
+
return shape != Rectangle.geometry_name()
|
|
159
|
+
elif task_type in [
|
|
160
|
+
TaskType.INSTANCE_SEGMENTATION,
|
|
161
|
+
TaskType.SEMANTIC_SEGMENTATION,
|
|
162
|
+
]:
|
|
163
|
+
return shape == Polygon.geometry_name()
|
|
164
|
+
return
|
|
165
|
+
|
|
166
|
+
data = self.classes_selector.classes_table._table_data
|
|
167
|
+
selected_classes = set(
|
|
168
|
+
self.classes_selector.classes_table.get_selected_classes()
|
|
169
|
+
)
|
|
170
|
+
empty = set(
|
|
171
|
+
r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
|
|
172
|
+
)
|
|
173
|
+
need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
|
|
174
|
+
|
|
175
|
+
if need_convert.intersection(selected_classes - empty):
|
|
176
|
+
self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
|
|
177
|
+
self.need_convert_shapes_for_bm = True
|
|
178
|
+
else:
|
|
179
|
+
self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
|
|
180
|
+
self.need_convert_shapes_for_bm = False
|
|
181
|
+
else:
|
|
182
|
+
self.need_convert_shapes_for_bm = False
|
|
183
|
+
|
|
184
|
+
def validate_class_shape_for_model_task():
|
|
185
|
+
task_type = self.model_selector.get_selected_task_type()
|
|
186
|
+
classes = self.classes_selector.get_selected_classes()
|
|
187
|
+
|
|
188
|
+
required_geometries = {
|
|
189
|
+
TaskType.INSTANCE_SEGMENTATION: {Polygon, Bitmap},
|
|
190
|
+
TaskType.SEMANTIC_SEGMENTATION: {Polygon, Bitmap},
|
|
191
|
+
TaskType.POSE_ESTIMATION: {GraphNodes},
|
|
192
|
+
}
|
|
193
|
+
task_specific_texts = {
|
|
194
|
+
TaskType.INSTANCE_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
|
|
195
|
+
TaskType.SEMANTIC_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
|
|
196
|
+
TaskType.POSE_ESTIMATION: "Only keypoint (graph) shape is supported for pose estimation task",
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
if task_type not in required_geometries:
|
|
200
|
+
return
|
|
201
|
+
|
|
202
|
+
wrong_shape_classes = [
|
|
203
|
+
class_name
|
|
204
|
+
for class_name in classes
|
|
205
|
+
if self.project_meta.get_obj_class(class_name).geometry_type
|
|
206
|
+
not in required_geometries[task_type]
|
|
207
|
+
]
|
|
208
|
+
|
|
209
|
+
if wrong_shape_classes:
|
|
210
|
+
specific_text = task_specific_texts[task_type]
|
|
211
|
+
message_text = f"Model task type is {task_type}. {specific_text}. Selected classes have wrong shapes for the model task: {', '.join(wrong_shape_classes)}"
|
|
212
|
+
self.model_selector.validator_text.set(
|
|
213
|
+
text=message_text,
|
|
214
|
+
status="warning",
|
|
215
|
+
)
|
|
172
216
|
|
|
173
217
|
# ------------------------------------------------- #
|
|
174
218
|
|
|
@@ -201,7 +245,11 @@ class TrainGUI:
|
|
|
201
245
|
callback=self.hyperparameters_selector_cb,
|
|
202
246
|
validation_text=self.model_selector.validator_text,
|
|
203
247
|
validation_func=self.model_selector.validate_step,
|
|
204
|
-
on_select_click=[
|
|
248
|
+
on_select_click=[
|
|
249
|
+
set_experiment_name,
|
|
250
|
+
need_convert_class_shapes,
|
|
251
|
+
validate_class_shape_for_model_task,
|
|
252
|
+
],
|
|
205
253
|
collapse_card=(self.model_selector.card, self.collapsable),
|
|
206
254
|
)
|
|
207
255
|
|
|
@@ -299,17 +347,19 @@ class TrainGUI:
|
|
|
299
347
|
# ------------------------------------------------- #
|
|
300
348
|
|
|
301
349
|
# Other Buttons
|
|
302
|
-
if app_options.get("show_logs_in_gui", False):
|
|
350
|
+
if self.app_options.get("show_logs_in_gui", False):
|
|
303
351
|
|
|
304
352
|
@self.training_logs.logs_button.click
|
|
305
353
|
def show_logs():
|
|
306
354
|
self.training_logs.toggle_logs()
|
|
307
355
|
|
|
308
356
|
# Other handlers
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
self.hyperparameters_selector.
|
|
312
|
-
|
|
357
|
+
if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
|
|
358
|
+
|
|
359
|
+
@self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
|
|
360
|
+
def show_mb_speedtest(is_checked: bool):
|
|
361
|
+
self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
|
|
362
|
+
need_convert_class_shapes()
|
|
313
363
|
|
|
314
364
|
# ------------------------------------------------- #
|
|
315
365
|
|
|
@@ -361,7 +411,6 @@ class TrainGUI:
|
|
|
361
411
|
raise ValueError("app_state must be a dictionary")
|
|
362
412
|
|
|
363
413
|
required_keys = {
|
|
364
|
-
"input": ["project_id"],
|
|
365
414
|
"train_val_split": ["method"],
|
|
366
415
|
"classes": list,
|
|
367
416
|
"model": ["source"],
|
|
@@ -453,7 +502,7 @@ class TrainGUI:
|
|
|
453
502
|
raise ValueError("percent must be an integer in range 1 to 99")
|
|
454
503
|
return app_state
|
|
455
504
|
|
|
456
|
-
def load_from_app_state(self, app_state: dict) -> None:
|
|
505
|
+
def load_from_app_state(self, app_state: Union[str, dict]) -> None:
|
|
457
506
|
"""
|
|
458
507
|
Load the GUI state from app state dictionary.
|
|
459
508
|
|
|
@@ -463,7 +512,6 @@ class TrainGUI:
|
|
|
463
512
|
app_state example:
|
|
464
513
|
|
|
465
514
|
app_state = {
|
|
466
|
-
"input": {"project_id": 43192},
|
|
467
515
|
"train_val_split": {
|
|
468
516
|
"method": "random",
|
|
469
517
|
"split": "train",
|
|
@@ -489,10 +537,13 @@ class TrainGUI:
|
|
|
489
537
|
}
|
|
490
538
|
}
|
|
491
539
|
"""
|
|
540
|
+
if isinstance(app_state, str):
|
|
541
|
+
app_state = sly_json.load_json_file(app_state)
|
|
542
|
+
|
|
492
543
|
app_state = self.validate_app_state(app_state)
|
|
493
544
|
|
|
494
545
|
options = app_state.get("options", {})
|
|
495
|
-
input_settings = app_state
|
|
546
|
+
input_settings = app_state.get("input")
|
|
496
547
|
train_val_splits_settings = app_state["train_val_split"]
|
|
497
548
|
classes_settings = app_state["classes"]
|
|
498
549
|
model_settings = app_state["model"]
|
|
@@ -504,7 +555,7 @@ class TrainGUI:
|
|
|
504
555
|
self._init_model(model_settings)
|
|
505
556
|
self._init_hyperparameters(hyperparameters_settings, options)
|
|
506
557
|
|
|
507
|
-
def _init_input(self, input_settings: dict, options: dict) -> None:
|
|
558
|
+
def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
|
|
508
559
|
"""
|
|
509
560
|
Initialize the input selector with the given settings.
|
|
510
561
|
|
|
@@ -604,6 +655,8 @@ class TrainGUI:
|
|
|
604
655
|
)
|
|
605
656
|
self.hyperparameters_selector.set_speedtest_checkbox_value(
|
|
606
657
|
model_benchmark_settings["speed_test"]
|
|
658
|
+
if model_benchmark_settings["enable"]
|
|
659
|
+
else False
|
|
607
660
|
)
|
|
608
661
|
export_weights_settings = options.get("export", None)
|
|
609
662
|
if export_weights_settings is not None:
|
|
@@ -9,7 +9,6 @@ from supervisely.app.widgets import (
|
|
|
9
9
|
Field,
|
|
10
10
|
Text,
|
|
11
11
|
)
|
|
12
|
-
from supervisely.nn.utils import RuntimeType
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
class HyperparametersSelector:
|
|
@@ -18,6 +17,22 @@ class HyperparametersSelector:
|
|
|
18
17
|
lock_message = "Select model to unlock"
|
|
19
18
|
|
|
20
19
|
def __init__(self, hyperparameters: dict, app_options: dict = {}):
|
|
20
|
+
# Init widgets
|
|
21
|
+
self.editor = None
|
|
22
|
+
self.run_model_benchmark_checkbox = None
|
|
23
|
+
self.run_speedtest_checkbox = None
|
|
24
|
+
self.model_benchmark_field = None
|
|
25
|
+
self.model_benchmark_learn_more = None
|
|
26
|
+
self.model_benchmark_auto_convert_warning = None
|
|
27
|
+
self.export_onnx_checkbox = None
|
|
28
|
+
self.export_tensorrt_checkbox = None
|
|
29
|
+
self.export_field = None
|
|
30
|
+
self.validator_text = None
|
|
31
|
+
self.button = None
|
|
32
|
+
self.container = None
|
|
33
|
+
self.card = None
|
|
34
|
+
# -------------------------------- #
|
|
35
|
+
|
|
21
36
|
self.display_widgets = []
|
|
22
37
|
self.app_options = app_options
|
|
23
38
|
|
|
@@ -28,7 +43,7 @@ class HyperparametersSelector:
|
|
|
28
43
|
self.display_widgets.extend([self.editor])
|
|
29
44
|
|
|
30
45
|
# Optional Model Benchmark
|
|
31
|
-
if app_options.get("model_benchmark", True):
|
|
46
|
+
if self.app_options.get("model_benchmark", True):
|
|
32
47
|
# Model Benchmark
|
|
33
48
|
self.run_model_benchmark_checkbox = Checkbox(
|
|
34
49
|
content="Run Model Benchmark evaluation", checked=True
|
|
@@ -37,7 +52,7 @@ class HyperparametersSelector:
|
|
|
37
52
|
|
|
38
53
|
self.model_benchmark_field = Field(
|
|
39
54
|
title="Model Evaluation Benchmark",
|
|
40
|
-
description=
|
|
55
|
+
description="Generate evaluation dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
|
|
41
56
|
content=Container([self.run_model_benchmark_checkbox, self.run_speedtest_checkbox]),
|
|
42
57
|
)
|
|
43
58
|
docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
|
|
@@ -60,8 +75,8 @@ class HyperparametersSelector:
|
|
|
60
75
|
# -------------------------------- #
|
|
61
76
|
|
|
62
77
|
# Optional Export Weights
|
|
63
|
-
export_onnx_supported = app_options.get("export_onnx_supported", False)
|
|
64
|
-
export_tensorrt_supported = app_options.get("export_tensorrt_supported", False)
|
|
78
|
+
export_onnx_supported = self.app_options.get("export_onnx_supported", False)
|
|
79
|
+
export_tensorrt_supported = self.app_options.get("export_tensorrt_supported", False)
|
|
65
80
|
|
|
66
81
|
onnx_name = "ONNX"
|
|
67
82
|
tensorrt_name = "TensorRT engine"
|
|
@@ -120,26 +135,28 @@ class HyperparametersSelector:
|
|
|
120
135
|
return self.editor.get_value()
|
|
121
136
|
|
|
122
137
|
def get_model_benchmark_checkbox_value(self) -> bool:
|
|
123
|
-
if self.
|
|
138
|
+
if self.run_model_benchmark_checkbox is not None:
|
|
124
139
|
return self.run_model_benchmark_checkbox.is_checked()
|
|
125
140
|
return False
|
|
126
141
|
|
|
127
142
|
def set_model_benchmark_checkbox_value(self, is_checked: bool) -> bool:
|
|
128
|
-
if
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
143
|
+
if self.run_model_benchmark_checkbox is not None:
|
|
144
|
+
if is_checked:
|
|
145
|
+
self.run_model_benchmark_checkbox.check()
|
|
146
|
+
else:
|
|
147
|
+
self.run_model_benchmark_checkbox.uncheck()
|
|
132
148
|
|
|
133
149
|
def get_speedtest_checkbox_value(self) -> bool:
|
|
134
|
-
if self.
|
|
150
|
+
if self.run_speedtest_checkbox is not None:
|
|
135
151
|
return self.run_speedtest_checkbox.is_checked()
|
|
136
152
|
return False
|
|
137
153
|
|
|
138
154
|
def set_speedtest_checkbox_value(self, is_checked: bool) -> bool:
|
|
139
|
-
if
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
155
|
+
if self.run_speedtest_checkbox is not None:
|
|
156
|
+
if is_checked:
|
|
157
|
+
self.run_speedtest_checkbox.check()
|
|
158
|
+
else:
|
|
159
|
+
self.run_speedtest_checkbox.uncheck()
|
|
143
160
|
|
|
144
161
|
def toggle_mb_speedtest(self, is_checked: bool) -> None:
|
|
145
162
|
if is_checked:
|
|
@@ -4,7 +4,6 @@ from supervisely.app.widgets import (
|
|
|
4
4
|
Card,
|
|
5
5
|
Checkbox,
|
|
6
6
|
Container,
|
|
7
|
-
Field,
|
|
8
7
|
ProjectThumbnail,
|
|
9
8
|
Text,
|
|
10
9
|
)
|
|
@@ -17,7 +16,19 @@ class InputSelector:
|
|
|
17
16
|
lock_message = None
|
|
18
17
|
|
|
19
18
|
def __init__(self, project_info: ProjectInfo, app_options: dict = {}):
|
|
19
|
+
# Init widgets
|
|
20
|
+
self.project_thumbnail = None
|
|
21
|
+
self.use_cache_text = None
|
|
22
|
+
self.use_cache_checkbox = None
|
|
23
|
+
self.validator_text = None
|
|
24
|
+
self.button = None
|
|
25
|
+
self.container = None
|
|
26
|
+
self.card = None
|
|
27
|
+
# -------------------------------- #
|
|
28
|
+
|
|
20
29
|
self.display_widgets = []
|
|
30
|
+
self.app_options = app_options
|
|
31
|
+
|
|
21
32
|
self.project_id = project_info.id
|
|
22
33
|
self.project_info = project_info
|
|
23
34
|
|
|
@@ -49,7 +60,7 @@ class InputSelector:
|
|
|
49
60
|
title=self.title,
|
|
50
61
|
description=self.description,
|
|
51
62
|
content=self.container,
|
|
52
|
-
collapsable=app_options.get("collapsable", False),
|
|
63
|
+
collapsable=self.app_options.get("collapsable", False),
|
|
53
64
|
)
|
|
54
65
|
|
|
55
66
|
@property
|
|
@@ -14,7 +14,7 @@ from supervisely.app.widgets import (
|
|
|
14
14
|
)
|
|
15
15
|
from supervisely.nn.artifacts.utils import FrameworkMapper
|
|
16
16
|
from supervisely.nn.experiments import get_experiment_infos
|
|
17
|
-
from supervisely.nn.utils import ModelSource
|
|
17
|
+
from supervisely.nn.utils import ModelSource, _get_model_name
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class ModelSelector:
|
|
@@ -23,15 +23,26 @@ class ModelSelector:
|
|
|
23
23
|
lock_message = "Select classes to unlock"
|
|
24
24
|
|
|
25
25
|
def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
|
|
26
|
+
# Init widgets
|
|
27
|
+
self.pretrained_models_table = None
|
|
28
|
+
self.experiment_selector = None
|
|
29
|
+
self.model_source_tabs = None
|
|
30
|
+
self.validator_text = None
|
|
31
|
+
self.button = None
|
|
32
|
+
self.container = None
|
|
33
|
+
self.card = None
|
|
34
|
+
# -------------------------------- #
|
|
35
|
+
|
|
26
36
|
self.display_widgets = []
|
|
37
|
+
self.app_options = app_options
|
|
38
|
+
|
|
27
39
|
self.team_id = sly_env.team_id()
|
|
28
40
|
self.models = models
|
|
29
41
|
|
|
30
42
|
# GUI Components
|
|
31
43
|
self.pretrained_models_table = PretrainedModelsSelector(self.models)
|
|
32
|
-
|
|
33
44
|
experiment_infos = get_experiment_infos(api, self.team_id, framework)
|
|
34
|
-
if app_options.get("legacy_checkpoints", False):
|
|
45
|
+
if self.app_options.get("legacy_checkpoints", False):
|
|
35
46
|
try:
|
|
36
47
|
framework_cls = FrameworkMapper.get_framework_cls(framework, self.team_id)
|
|
37
48
|
legacy_experiment_infos = framework_cls.get_list_experiment_info()
|
|
@@ -61,7 +72,7 @@ class ModelSelector:
|
|
|
61
72
|
description=self.description,
|
|
62
73
|
content=self.container,
|
|
63
74
|
lock_message=self.lock_message,
|
|
64
|
-
collapsable=app_options.get("collapsable", False),
|
|
75
|
+
collapsable=self.app_options.get("collapsable", False),
|
|
65
76
|
)
|
|
66
77
|
self.card.lock()
|
|
67
78
|
|
|
@@ -82,8 +93,7 @@ class ModelSelector:
|
|
|
82
93
|
def get_model_name(self) -> str:
|
|
83
94
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
84
95
|
selected_row = self.pretrained_models_table.get_selected_row()
|
|
85
|
-
|
|
86
|
-
model_name = model_meta.get("model_name", None)
|
|
96
|
+
model_name = _get_model_name(selected_row)
|
|
87
97
|
else:
|
|
88
98
|
selected_row = self.experiment_selector.get_selected_experiment_info()
|
|
89
99
|
model_name = selected_row.get("model_name", None)
|