supervisely 6.73.359__py3-none-any.whl → 6.73.360__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/app/widgets/project_thumbnail/project_thumbnail.py +3 -2
- supervisely/app/widgets/random_splits_table/random_splits_table.py +13 -2
- supervisely/app/widgets/random_splits_table/template.html +2 -2
- supervisely/app/widgets/select_app_session/select_app_session.py +3 -0
- supervisely/app/widgets/train_val_splits/train_val_splits.py +36 -24
- supervisely/nn/training/gui/gui.py +551 -186
- supervisely/nn/training/gui/input_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +26 -6
- supervisely/nn/training/gui/tags_selector.py +105 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +80 -18
- supervisely/nn/training/train_app.py +139 -43
- {supervisely-6.73.359.dist-info → supervisely-6.73.360.dist-info}/METADATA +80 -59
- {supervisely-6.73.359.dist-info → supervisely-6.73.360.dist-info}/RECORD +17 -16
- {supervisely-6.73.359.dist-info → supervisely-6.73.360.dist-info}/LICENSE +0 -0
- {supervisely-6.73.359.dist-info → supervisely-6.73.360.dist-info}/WHEEL +0 -0
- {supervisely-6.73.359.dist-info → supervisely-6.73.360.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.359.dist-info → supervisely-6.73.360.dist-info}/top_level.txt +0 -0
|
@@ -6,13 +6,13 @@ training workflows in Supervisely.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from os import environ
|
|
9
|
-
from typing import Union
|
|
9
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
10
10
|
|
|
11
11
|
import supervisely.io.env as sly_env
|
|
12
12
|
import supervisely.io.json as sly_json
|
|
13
13
|
from supervisely import Api, ProjectMeta
|
|
14
14
|
from supervisely._utils import is_production
|
|
15
|
-
from supervisely.app.widgets import Stepper, Widget
|
|
15
|
+
from supervisely.app.widgets import Button, Card, Stepper, Widget
|
|
16
16
|
from supervisely.geometry.bitmap import Bitmap
|
|
17
17
|
from supervisely.geometry.graph import GraphNodes
|
|
18
18
|
from supervisely.geometry.polygon import Polygon
|
|
@@ -22,6 +22,7 @@ from supervisely.nn.training.gui.classes_selector import ClassesSelector
|
|
|
22
22
|
from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
|
|
23
23
|
from supervisely.nn.training.gui.input_selector import InputSelector
|
|
24
24
|
from supervisely.nn.training.gui.model_selector import ModelSelector
|
|
25
|
+
from supervisely.nn.training.gui.tags_selector import TagsSelector
|
|
25
26
|
from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplitsSelector
|
|
26
27
|
from supervisely.nn.training.gui.training_artifacts import TrainingArtifacts
|
|
27
28
|
from supervisely.nn.training.gui.training_logs import TrainingLogs
|
|
@@ -30,6 +31,186 @@ from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_clic
|
|
|
30
31
|
from supervisely.nn.utils import ModelSource, RuntimeType
|
|
31
32
|
|
|
32
33
|
|
|
34
|
+
class StepFlow:
    """
    Manages the flow of steps in the GUI, including wrappers and button handlers.

    Steps are registered with :meth:`register_step`, linked to their successors
    with :meth:`set_next_steps`, and wired up by :meth:`build`. For every
    sequenced step that has a button, ``build`` creates a click wrapper via
    ``wrap_button_click``. It then attaches a click handler that runs the
    wrapper and moves the stepper to the following position.
    """

    def __init__(self, stepper: "Stepper", app_options: Dict[str, Any]):
        """
        Initializes the step manager.

        :param stepper: Stepper object for step navigation
        :param app_options: Application options; only the ``"collapsable"`` key
            (default ``False``) is read here
        """
        self.stepper = stepper
        self.app_options = app_options
        self.collapsable = app_options.get("collapsable", False)
        # step name -> step configuration dict (see register_step)
        self.steps = {}
        # step names indexed by position; unused positions hold None
        self.step_sequence = []

    def register_step(
        self,
        name: str,
        card: "Card",
        button: "Optional[Button]" = None,
        widgets_to_disable: "Optional[List[Widget]]" = None,
        validation_text: "Optional[Widget]" = None,
        validation_func: Optional[Callable] = None,
        position: Optional[int] = None,
    ) -> "StepFlow":
        """
        Registers a step in the GUI.

        Registering an already known name replaces its configuration and moves
        it to the new ``position``. A step never occupies more than one slot
        of the sequence.

        :param name: Unique step name
        :param card: Step card widget
        :param button: Button for proceeding to the next step (optional)
        :param widgets_to_disable: Widgets to disable during validation (optional)
        :param validation_text: Widget for displaying validation text (optional)
        :param validation_func: Validation function (optional)
        :param position: Step position in the sequence (starting from 0); steps
            without a position are not wired by :meth:`build`
        :return: Current StepFlow object for method chaining
        :raises ValueError: if ``position`` is negative
        """
        # Validate before touching any state so a bad call leaves no partial step.
        if position is not None and position < 0:
            raise ValueError(
                f"Step position must be non-negative, got {position} for step '{name}'"
            )

        self.steps[name] = {
            "card": card,
            "button": button,
            "widgets_to_disable": widgets_to_disable or [],
            "validation_text": validation_text,
            "validation_func": validation_func,
            "position": position,
            "next_steps": [],
            "on_select_click": [],
            "on_reselect_click": [],
            "wrapper": None,
            "has_button": button is not None,
        }

        # Drop any slot left over from an earlier registration of this name.
        # Otherwise the step would be built twice and counted twice when
        # stepper positions are assigned.
        for idx, existing in enumerate(self.step_sequence):
            if existing == name:
                self.step_sequence[idx] = None

        if position is not None:
            missing = position + 1 - len(self.step_sequence)
            if missing > 0:
                self.step_sequence.extend([None] * missing)
            self.step_sequence[position] = name

        return self

    def set_next_steps(self, step_name: str, next_steps: List[str]) -> "StepFlow":
        """
        Sets the list of next steps for the given step.

        Unknown ``step_name`` values are ignored.

        :param step_name: Current step name
        :param next_steps: List of names of the next steps (copied)
        :return: Current StepFlow object for method chaining
        """
        if step_name in self.steps:
            # Copy so later mutation of the caller's list cannot alter the flow.
            self.steps[step_name]["next_steps"] = list(next_steps)
        return self

    def add_on_select_actions(
        self, step_name: str, actions: List[Callable], is_reselect: bool = False
    ) -> "StepFlow":
        """
        Adds actions to be executed when a step is selected/reselected.

        Unknown ``step_name`` values are ignored.

        :param step_name: Step name
        :param actions: List of functions to execute
        :param is_reselect: True if these are actions for reselection, otherwise False
        :return: Current StepFlow object for method chaining
        """
        if step_name in self.steps:
            key = "on_reselect_click" if is_reselect else "on_select_click"
            self.steps[step_name][key].extend(actions)
        return self

    def _ordered_steps(self) -> List[str]:
        """Return registered step names in sequence order, skipping empty slots."""
        return [s for s in self.step_sequence if s is not None and s in self.steps]

    def build_wrappers(self) -> Dict[str, Callable]:
        """
        Creates wrappers for all steps based on established dependencies.

        Steps are processed from last to first, so the wrapper of a following
        step already exists when the current one is built. The wrapper of the
        first next step that has a button and a built wrapper is passed as
        ``callback`` to ``wrap_button_click``. The cards of all known next
        steps are passed as ``cards_to_unlock``.

        :return: Dictionary with created wrappers by step name
        """
        for step_name in reversed(self._ordered_steps()):
            step = self.steps[step_name]
            if not step["has_button"]:
                continue

            following = [self.steps[n] for n in step["next_steps"] if n in self.steps]
            cards_to_unlock = [nxt["card"] for nxt in following]
            callback = next(
                (nxt["wrapper"] for nxt in following if nxt["has_button"] and nxt.get("wrapper")),
                None,
            )

            step["wrapper"] = wrap_button_click(
                button=step["button"],
                cards_to_unlock=cards_to_unlock,
                widgets_to_disable=step["widgets_to_disable"],
                callback=callback,
                validation_text=step["validation_text"],
                validation_func=step["validation_func"],
                on_select_click=step["on_select_click"],
                on_reselect_click=step["on_reselect_click"],
                collapse_card=(step["card"], self.collapsable),
            )

        return {
            name: step["wrapper"]
            for name, step in self.steps.items()
            if step.get("wrapper") and step["has_button"]
        }

    def _make_click_handler(self, button: "Button", wrapper: Callable, next_pos: int) -> Callable:
        """Return a click handler that runs ``wrapper`` and then calls ``set_stepper_step``."""

        def handler():
            wrapper()
            set_stepper_step(self.stepper, button, next_pos=next_pos)

        return handler

    def setup_button_handlers(self) -> None:
        """
        Sets up handlers for buttons of all steps.

        Stepper positions are counted from 1 over the sequenced steps. Clicking
        a step's button runs its wrapper and passes the step's own position + 1
        as ``next_pos`` to ``set_stepper_step``. Only steps that have a button
        and a wrapper built by :meth:`build_wrappers` get a handler.
        """
        for position, step_name in enumerate(self._ordered_steps(), start=1):
            step = self.steps[step_name]
            if not (step["has_button"] and step.get("wrapper")):
                continue
            button = step["button"]
            button.click(self._make_click_handler(button, step["wrapper"], position + 1))

    def build(self) -> Dict[str, Callable]:
        """
        Performs the complete setup of the step system.

        NOTE(review): each call attaches new click handlers, so this is meant
        to be called once per flow.

        :return: Dictionary with created wrappers by step name
        """
        wrappers = self.build_wrappers()
        self.setup_button_handlers()
        return wrappers
|
|
212
|
+
|
|
213
|
+
|
|
33
214
|
class TrainGUI:
|
|
34
215
|
"""
|
|
35
216
|
A class representing the GUI for training workflows.
|
|
@@ -74,6 +255,9 @@ class TrainGUI:
|
|
|
74
255
|
self.workspace_id = sly_env.workspace_id(raise_not_found=False)
|
|
75
256
|
self.project_id = sly_env.project_id()
|
|
76
257
|
self.project_info = self._api.project.get_info_by_id(self.project_id)
|
|
258
|
+
if self.project_info.type is None:
|
|
259
|
+
raise ValueError(f"Project with ID: '{self.project_id}' does not exist or was archived")
|
|
260
|
+
|
|
77
261
|
self.project_meta = ProjectMeta.from_json(self._api.project.get_meta(self.project_id))
|
|
78
262
|
|
|
79
263
|
if self.workspace_id is None:
|
|
@@ -83,51 +267,78 @@ class TrainGUI:
|
|
|
83
267
|
self.team_id = self.project_info.team_id
|
|
84
268
|
environ["TEAM_ID"] = str(self.team_id)
|
|
85
269
|
|
|
86
|
-
#
|
|
270
|
+
# ---------- Parse selector options ----------
|
|
271
|
+
self._classes_selector_opts = self.app_options.get("classes_selector", {})
|
|
272
|
+
self._tags_selector_opts = self.app_options.get("tags_selector", {})
|
|
273
|
+
self._train_val_splits_opts = self.app_options.get("train_val_splits_selector", {})
|
|
274
|
+
self._model_selector_opts = self.app_options.get("model_selector", {})
|
|
275
|
+
|
|
276
|
+
self.show_classes_selector = self._classes_selector_opts.get("enabled", True)
|
|
277
|
+
self.show_tags_selector = self._tags_selector_opts.get("enabled", False)
|
|
278
|
+
self.show_train_val_splits_selector = self._train_val_splits_opts.get("enabled", True)
|
|
279
|
+
self.show_model_selector = self._model_selector_opts.get("enabled", True)
|
|
280
|
+
|
|
281
|
+
# Ensure train_val_splits_methods compatibility
|
|
282
|
+
self._train_val_methods = self._train_val_splits_opts.get("methods", [])
|
|
283
|
+
# --------------------------------------------------------- #
|
|
284
|
+
|
|
285
|
+
# ------------------------------------------------- #
|
|
286
|
+
self.steps = []
|
|
287
|
+
|
|
288
|
+
# 1. Project selection
|
|
87
289
|
self.input_selector = InputSelector(self.project_info, self.app_options)
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
290
|
+
self.steps.append(self.input_selector.card)
|
|
291
|
+
|
|
292
|
+
# 2. Train/val split
|
|
293
|
+
self.train_val_splits_selector = None
|
|
294
|
+
if self.show_train_val_splits_selector:
|
|
295
|
+
self.train_val_splits_selector = TrainValSplitsSelector(
|
|
296
|
+
self._api, self.project_id, self.app_options
|
|
297
|
+
)
|
|
298
|
+
self.steps.append(self.train_val_splits_selector.card)
|
|
299
|
+
|
|
300
|
+
# 3. Select Classes
|
|
301
|
+
self.classes_selector = None
|
|
302
|
+
if self.show_classes_selector:
|
|
303
|
+
self.classes_selector = ClassesSelector(self.project_id, [], self.app_options)
|
|
304
|
+
self.steps.append(self.classes_selector.card)
|
|
305
|
+
|
|
306
|
+
# 4. Select Tags
|
|
307
|
+
self.tags_selector = None
|
|
308
|
+
if self.show_tags_selector:
|
|
309
|
+
self.tags_selector = TagsSelector(self.project_id, [], self.app_options)
|
|
310
|
+
self.steps.append(self.tags_selector.card)
|
|
311
|
+
|
|
312
|
+
# 5. Model selection
|
|
95
313
|
self.model_selector = ModelSelector(
|
|
96
314
|
self._api, self.framework_name, self.models, self.app_options
|
|
97
315
|
)
|
|
98
|
-
|
|
316
|
+
if self.show_model_selector:
|
|
317
|
+
self.steps.append(self.model_selector.card)
|
|
318
|
+
|
|
319
|
+
# 6. Training parameters (yaml)
|
|
99
320
|
self.hyperparameters_selector = HyperparametersSelector(
|
|
100
321
|
self.hyperparameters, self.app_options
|
|
101
322
|
)
|
|
102
|
-
|
|
323
|
+
self.steps.append(self.hyperparameters_selector.card)
|
|
324
|
+
|
|
325
|
+
# 7. Start Training
|
|
103
326
|
self.training_process = TrainingProcess(self.app_options)
|
|
327
|
+
self.steps.append(self.training_process.card)
|
|
104
328
|
|
|
105
|
-
#
|
|
329
|
+
# 8. Training logs
|
|
106
330
|
self.training_logs = TrainingLogs(self.app_options)
|
|
331
|
+
self.steps.append(self.training_logs.card)
|
|
107
332
|
|
|
108
|
-
#
|
|
333
|
+
# 9. Training Artifacts
|
|
109
334
|
self.training_artifacts = TrainingArtifacts(self._api, self.app_options)
|
|
335
|
+
self.steps.append(self.training_artifacts.card)
|
|
110
336
|
|
|
111
337
|
# Stepper layout
|
|
112
|
-
self.
|
|
113
|
-
self.input_selector.card,
|
|
114
|
-
self.train_val_splits_selector.card,
|
|
115
|
-
self.classes_selector.card,
|
|
116
|
-
self.model_selector.card,
|
|
117
|
-
self.hyperparameters_selector.card,
|
|
118
|
-
self.training_process.card,
|
|
119
|
-
self.training_logs.card,
|
|
120
|
-
self.training_artifacts.card,
|
|
121
|
-
]
|
|
122
|
-
self.stepper = Stepper(
|
|
123
|
-
widgets=self.steps,
|
|
124
|
-
)
|
|
338
|
+
self.stepper = Stepper(widgets=self.steps)
|
|
125
339
|
# ------------------------------------------------- #
|
|
126
340
|
|
|
127
341
|
# Button utils
|
|
128
|
-
def update_classes_table():
|
|
129
|
-
pass
|
|
130
|
-
|
|
131
342
|
def disable_hyperparams_editor():
|
|
132
343
|
if self.hyperparameters_selector.editor.readonly:
|
|
133
344
|
self.hyperparameters_selector.editor.readonly = False
|
|
@@ -161,16 +372,28 @@ class TrainGUI:
|
|
|
161
372
|
TaskType.SEMANTIC_SEGMENTATION,
|
|
162
373
|
]:
|
|
163
374
|
return shape == Polygon.geometry_name()
|
|
164
|
-
return
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
375
|
+
return False
|
|
376
|
+
|
|
377
|
+
if self.classes_selector is not None:
|
|
378
|
+
data = self.classes_selector.classes_table._table_data
|
|
379
|
+
selected_classes = set(
|
|
380
|
+
self.classes_selector.classes_table.get_selected_classes()
|
|
381
|
+
)
|
|
382
|
+
empty = set(
|
|
383
|
+
r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
|
|
384
|
+
)
|
|
385
|
+
need_convert = set(
|
|
386
|
+
r[0]["data"] for r in data if _need_convert(r[1]["data"])
|
|
387
|
+
)
|
|
388
|
+
else:
|
|
389
|
+
# Set project meta classes when classes selector is disabled
|
|
390
|
+
selected_classes = set(cls.name for cls in self.project_meta.obj_classes)
|
|
391
|
+
need_convert = set(
|
|
392
|
+
obj_class.name
|
|
393
|
+
for obj_class in self.project_meta.obj_classes
|
|
394
|
+
if _need_convert(obj_class.geometry_type)
|
|
395
|
+
)
|
|
396
|
+
empty = set()
|
|
174
397
|
|
|
175
398
|
if need_convert.intersection(selected_classes - empty):
|
|
176
399
|
self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
|
|
@@ -183,7 +406,10 @@ class TrainGUI:
|
|
|
183
406
|
|
|
184
407
|
def validate_class_shape_for_model_task():
|
|
185
408
|
task_type = self.model_selector.get_selected_task_type()
|
|
186
|
-
|
|
409
|
+
if self.classes_selector is not None:
|
|
410
|
+
classes = self.classes_selector.get_selected_classes()
|
|
411
|
+
else:
|
|
412
|
+
classes = list(self.project_meta.obj_classes.keys())
|
|
187
413
|
|
|
188
414
|
required_geometries = {
|
|
189
415
|
TaskType.INSTANCE_SEGMENTATION: {Polygon, Bitmap},
|
|
@@ -216,144 +442,189 @@ class TrainGUI:
|
|
|
216
442
|
|
|
217
443
|
# ------------------------------------------------- #
|
|
218
444
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
button=self.hyperparameters_selector.button,
|
|
222
|
-
cards_to_unlock=[self.training_logs.card],
|
|
223
|
-
widgets_to_disable=self.training_process.widgets_to_disable,
|
|
224
|
-
callback=None,
|
|
225
|
-
validation_text=self.training_process.validator_text,
|
|
226
|
-
validation_func=self.training_process.validate_step,
|
|
227
|
-
)
|
|
445
|
+
self.step_flow = StepFlow(self.stepper, self.app_options)
|
|
446
|
+
position = 0
|
|
228
447
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
collapse_card=(self.hyperparameters_selector.card, self.collapsable),
|
|
448
|
+
# 1. Input selector
|
|
449
|
+
self.step_flow.register_step(
|
|
450
|
+
"input_selector",
|
|
451
|
+
self.input_selector.card,
|
|
452
|
+
self.input_selector.button,
|
|
453
|
+
self.input_selector.widgets_to_disable,
|
|
454
|
+
self.input_selector.validator_text,
|
|
455
|
+
self.input_selector.validate_step,
|
|
456
|
+
position=position,
|
|
239
457
|
)
|
|
458
|
+
position += 1
|
|
459
|
+
|
|
460
|
+
# 2. Train/Val splits selector
|
|
461
|
+
if self.show_train_val_splits_selector and self.train_val_splits_selector is not None:
|
|
462
|
+
self.step_flow.register_step(
|
|
463
|
+
"train_val_splits",
|
|
464
|
+
self.train_val_splits_selector.card,
|
|
465
|
+
self.train_val_splits_selector.button,
|
|
466
|
+
self.train_val_splits_selector.widgets_to_disable,
|
|
467
|
+
self.train_val_splits_selector.validator_text,
|
|
468
|
+
self.train_val_splits_selector.validate_step,
|
|
469
|
+
position=position,
|
|
470
|
+
)
|
|
471
|
+
position += 1
|
|
472
|
+
|
|
473
|
+
# 3. Classes selector
|
|
474
|
+
if self.show_classes_selector and self.classes_selector is not None:
|
|
475
|
+
self.step_flow.register_step(
|
|
476
|
+
"classes_selector",
|
|
477
|
+
self.classes_selector.card,
|
|
478
|
+
self.classes_selector.button,
|
|
479
|
+
self.classes_selector.widgets_to_disable,
|
|
480
|
+
self.classes_selector.validator_text,
|
|
481
|
+
self.classes_selector.validate_step,
|
|
482
|
+
position=position,
|
|
483
|
+
)
|
|
484
|
+
position += 1
|
|
485
|
+
|
|
486
|
+
# 4. Tags selector
|
|
487
|
+
if self.show_tags_selector and self.tags_selector is not None:
|
|
488
|
+
self.step_flow.register_step(
|
|
489
|
+
"tags_selector",
|
|
490
|
+
self.tags_selector.card,
|
|
491
|
+
self.tags_selector.button,
|
|
492
|
+
self.tags_selector.widgets_to_disable,
|
|
493
|
+
self.tags_selector.validator_text,
|
|
494
|
+
self.tags_selector.validate_step,
|
|
495
|
+
position=position,
|
|
496
|
+
)
|
|
497
|
+
position += 1
|
|
498
|
+
|
|
499
|
+
# 5. Model selector
|
|
500
|
+
if self.show_model_selector:
|
|
501
|
+
self.step_flow.register_step(
|
|
502
|
+
"model_selector",
|
|
503
|
+
self.model_selector.card,
|
|
504
|
+
self.model_selector.button,
|
|
505
|
+
self.model_selector.widgets_to_disable,
|
|
506
|
+
self.model_selector.validator_text,
|
|
507
|
+
self.model_selector.validate_step,
|
|
508
|
+
position=position,
|
|
509
|
+
).add_on_select_actions(
|
|
510
|
+
"model_selector",
|
|
511
|
+
[
|
|
512
|
+
set_experiment_name,
|
|
513
|
+
need_convert_class_shapes,
|
|
514
|
+
validate_class_shape_for_model_task,
|
|
515
|
+
],
|
|
516
|
+
)
|
|
517
|
+
position += 1
|
|
240
518
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
],
|
|
253
|
-
collapse_card=(self.model_selector.card, self.collapsable),
|
|
519
|
+
# 6. Hyperparameters selector
|
|
520
|
+
self.step_flow.register_step(
|
|
521
|
+
"hyperparameters_selector",
|
|
522
|
+
self.hyperparameters_selector.card,
|
|
523
|
+
self.hyperparameters_selector.button,
|
|
524
|
+
self.hyperparameters_selector.widgets_to_disable,
|
|
525
|
+
self.hyperparameters_selector.validator_text,
|
|
526
|
+
self.hyperparameters_selector.validate_step,
|
|
527
|
+
position=position,
|
|
528
|
+
).add_on_select_actions("hyperparameters_selector", [disable_hyperparams_editor])
|
|
529
|
+
self.step_flow.add_on_select_actions(
|
|
530
|
+
"hyperparameters_selector", [disable_hyperparams_editor], is_reselect=True
|
|
254
531
|
)
|
|
532
|
+
position += 1
|
|
255
533
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
534
|
+
# 7. Training process
|
|
535
|
+
self.step_flow.register_step(
|
|
536
|
+
"training_process",
|
|
537
|
+
self.training_process.card,
|
|
538
|
+
None,
|
|
539
|
+
self.training_process.widgets_to_disable,
|
|
540
|
+
self.training_process.validator_text,
|
|
541
|
+
self.training_process.validate_step,
|
|
542
|
+
position=position,
|
|
264
543
|
)
|
|
544
|
+
position += 1
|
|
265
545
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
546
|
+
# 8. Training logs
|
|
547
|
+
self.step_flow.register_step(
|
|
548
|
+
"training_logs",
|
|
549
|
+
self.training_logs.card,
|
|
550
|
+
None,
|
|
551
|
+
self.training_logs.widgets_to_disable,
|
|
552
|
+
self.training_logs.validator_text,
|
|
553
|
+
self.training_logs.validate_step,
|
|
554
|
+
position=position,
|
|
274
555
|
)
|
|
556
|
+
position += 1
|
|
275
557
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
558
|
+
# 9. Training artifacts
|
|
559
|
+
self.step_flow.register_step(
|
|
560
|
+
"training_artifacts",
|
|
561
|
+
self.training_artifacts.card,
|
|
562
|
+
None,
|
|
563
|
+
self.training_artifacts.widgets_to_disable,
|
|
564
|
+
self.training_artifacts.validator_text,
|
|
565
|
+
self.training_artifacts.validate_step,
|
|
566
|
+
position=position,
|
|
285
567
|
)
|
|
286
|
-
# ------------------------------------------------- #
|
|
287
568
|
|
|
288
|
-
#
|
|
569
|
+
# Set dependencies between steps
|
|
570
|
+
has_train_val_splits = (
|
|
571
|
+
self.show_train_val_splits_selector and self.train_val_splits_selector is not None
|
|
572
|
+
)
|
|
573
|
+
has_classes_selector = self.show_classes_selector and self.classes_selector is not None
|
|
574
|
+
has_tags_selector = self.show_tags_selector and self.tags_selector is not None
|
|
575
|
+
|
|
576
|
+
# Set step dependency chain
|
|
577
|
+
# 1. Input selector
|
|
578
|
+
prev_step = "input_selector"
|
|
579
|
+
if has_train_val_splits:
|
|
580
|
+
self.step_flow.set_next_steps(prev_step, ["train_val_splits"])
|
|
581
|
+
prev_step = "train_val_splits"
|
|
582
|
+
if has_classes_selector:
|
|
583
|
+
self.step_flow.set_next_steps(prev_step, ["classes_selector"])
|
|
584
|
+
prev_step = "classes_selector"
|
|
585
|
+
if has_tags_selector:
|
|
586
|
+
self.step_flow.set_next_steps(prev_step, ["tags_selector"])
|
|
587
|
+
prev_step = "tags_selector"
|
|
588
|
+
|
|
589
|
+
if self.show_model_selector and self.model_selector is not None:
|
|
590
|
+
self.step_flow.set_next_steps(prev_step, ["model_selector"])
|
|
591
|
+
# Model selector -> hyperparameters
|
|
592
|
+
self.step_flow.set_next_steps("model_selector", ["hyperparameters_selector"])
|
|
593
|
+
prev_step = "model_selector"
|
|
594
|
+
else:
|
|
595
|
+
self.step_flow.set_next_steps(prev_step, ["hyperparameters_selector"])
|
|
289
596
|
|
|
290
|
-
#
|
|
291
|
-
|
|
292
|
-
# def start_training():
|
|
293
|
-
# pass
|
|
597
|
+
# 6. Hyperparameters selector -> 7. Training process
|
|
598
|
+
self.step_flow.set_next_steps("hyperparameters_selector", ["training_process"])
|
|
294
599
|
|
|
295
|
-
#
|
|
296
|
-
|
|
297
|
-
# pass
|
|
600
|
+
# 7. Training process -> 8. Training logs
|
|
601
|
+
self.step_flow.set_next_steps("training_process", ["training_logs"])
|
|
298
602
|
|
|
603
|
+
# 8. Training logs -> 9. Training artifacts
|
|
604
|
+
self.step_flow.set_next_steps("training_logs", ["training_artifacts"])
|
|
299
605
|
# ------------------------------------------------- #
|
|
300
606
|
|
|
301
|
-
#
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
self.model_selector_cb()
|
|
314
|
-
set_stepper_step(
|
|
315
|
-
self.stepper,
|
|
316
|
-
self.model_selector.button,
|
|
317
|
-
next_pos=5,
|
|
318
|
-
)
|
|
319
|
-
|
|
320
|
-
@self.classes_selector.button.click
|
|
321
|
-
def select_classes():
|
|
322
|
-
self.classes_selector_cb()
|
|
323
|
-
set_stepper_step(
|
|
324
|
-
self.stepper,
|
|
325
|
-
self.classes_selector.button,
|
|
326
|
-
next_pos=4,
|
|
327
|
-
)
|
|
328
|
-
|
|
329
|
-
@self.train_val_splits_selector.button.click
|
|
330
|
-
def select_train_val_splits():
|
|
331
|
-
self.train_val_splits_selector_cb()
|
|
332
|
-
set_stepper_step(
|
|
333
|
-
self.stepper,
|
|
334
|
-
self.train_val_splits_selector.button,
|
|
335
|
-
next_pos=3,
|
|
336
|
-
)
|
|
337
|
-
|
|
338
|
-
@self.input_selector.button.click
|
|
339
|
-
def select_input():
|
|
340
|
-
self.input_selector_cb()
|
|
341
|
-
set_stepper_step(
|
|
342
|
-
self.stepper,
|
|
343
|
-
self.input_selector.button,
|
|
344
|
-
next_pos=2,
|
|
345
|
-
)
|
|
346
|
-
|
|
607
|
+
# Create all wrappers and set button handlers
|
|
608
|
+
wrappers = self.step_flow.build()
|
|
609
|
+
|
|
610
|
+
self.input_selector_cb = wrappers.get("input_selector")
|
|
611
|
+
self.train_val_splits_selector_cb = wrappers.get("train_val_splits")
|
|
612
|
+
self.classes_selector_cb = wrappers.get("classes_selector")
|
|
613
|
+
self.tags_selector_cb = wrappers.get("tags_selector")
|
|
614
|
+
self.model_selector_cb = wrappers.get("model_selector")
|
|
615
|
+
self.hyperparameters_selector_cb = wrappers.get("hyperparameters_selector")
|
|
616
|
+
self.training_process_cb = wrappers.get("training_process")
|
|
617
|
+
self.training_logs_cb = wrappers.get("training_logs")
|
|
618
|
+
self.training_artifacts_cb = wrappers.get("training_artifacts")
|
|
347
619
|
# ------------------------------------------------- #
|
|
348
620
|
|
|
349
|
-
# Other
|
|
621
|
+
# Other handlers
|
|
350
622
|
if self.app_options.get("show_logs_in_gui", False):
|
|
351
623
|
|
|
352
624
|
@self.training_logs.logs_button.click
|
|
353
625
|
def show_logs():
|
|
354
626
|
self.training_logs.toggle_logs()
|
|
355
627
|
|
|
356
|
-
# Other handlers
|
|
357
628
|
if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
|
|
358
629
|
|
|
359
630
|
@self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
|
|
@@ -365,6 +636,8 @@ class TrainGUI:
|
|
|
365
636
|
|
|
366
637
|
self.layout: Widget = self.stepper
|
|
367
638
|
|
|
639
|
+
# (дублирующийся блок был перемещён выше и здесь удалён)
|
|
640
|
+
|
|
368
641
|
def set_next_step(self):
|
|
369
642
|
current_step = self.stepper.get_active_step()
|
|
370
643
|
self.stepper.set_active_step(current_step + 1)
|
|
@@ -383,21 +656,35 @@ class TrainGUI:
|
|
|
383
656
|
"""
|
|
384
657
|
Makes all select buttons in the GUI available for interaction.
|
|
385
658
|
"""
|
|
386
|
-
self.input_selector
|
|
387
|
-
|
|
388
|
-
self.
|
|
389
|
-
|
|
390
|
-
self.
|
|
659
|
+
if self.input_selector is not None:
|
|
660
|
+
self.input_selector.button.enable()
|
|
661
|
+
if self.train_val_splits_selector is not None:
|
|
662
|
+
self.train_val_splits_selector.button.enable()
|
|
663
|
+
if self.classes_selector is not None:
|
|
664
|
+
self.classes_selector.button.enable()
|
|
665
|
+
if self.tags_selector is not None:
|
|
666
|
+
self.tags_selector.button.enable()
|
|
667
|
+
if self.model_selector is not None:
|
|
668
|
+
self.model_selector.button.enable()
|
|
669
|
+
if self.hyperparameters_selector is not None:
|
|
670
|
+
self.hyperparameters_selector.button.enable()
|
|
391
671
|
|
|
392
672
|
def disable_select_buttons(self):
|
|
393
673
|
"""
|
|
394
674
|
Makes all select buttons in the GUI unavailable for interaction.
|
|
395
675
|
"""
|
|
396
|
-
self.input_selector
|
|
397
|
-
|
|
398
|
-
self.
|
|
399
|
-
|
|
400
|
-
self.
|
|
676
|
+
if self.input_selector is not None:
|
|
677
|
+
self.input_selector.button.disable()
|
|
678
|
+
if self.train_val_splits_selector is not None:
|
|
679
|
+
self.train_val_splits_selector.button.disable()
|
|
680
|
+
if self.classes_selector is not None:
|
|
681
|
+
self.classes_selector.button.disable()
|
|
682
|
+
if self.tags_selector is not None:
|
|
683
|
+
self.tags_selector.button.disable()
|
|
684
|
+
if self.model_selector is not None:
|
|
685
|
+
self.model_selector.button.disable()
|
|
686
|
+
if self.hyperparameters_selector is not None:
|
|
687
|
+
self.hyperparameters_selector.button.disable()
|
|
401
688
|
|
|
402
689
|
# Set GUI from config
|
|
403
690
|
def validate_app_state(self, app_state: dict) -> dict:
|
|
@@ -410,29 +697,54 @@ class TrainGUI:
|
|
|
410
697
|
if not isinstance(app_state, dict):
|
|
411
698
|
raise ValueError("app_state must be a dictionary")
|
|
412
699
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
700
|
+
show_train_val = self.show_train_val_splits_selector
|
|
701
|
+
show_classes = self.show_classes_selector
|
|
702
|
+
show_tags = self.show_tags_selector
|
|
703
|
+
|
|
704
|
+
# Basic required keys always needed
|
|
705
|
+
base_required = {
|
|
416
706
|
"model": ["source"],
|
|
417
|
-
"hyperparameters": (dict, str),
|
|
707
|
+
"hyperparameters": (dict, str),
|
|
418
708
|
}
|
|
419
|
-
|
|
420
|
-
|
|
709
|
+
if show_train_val:
|
|
710
|
+
base_required["train_val_split"] = ["method"]
|
|
711
|
+
if show_classes:
|
|
712
|
+
base_required["classes"] = list
|
|
713
|
+
if show_tags:
|
|
714
|
+
base_required["tags"] = list
|
|
715
|
+
|
|
716
|
+
for key, subkeys_or_type in base_required.items():
|
|
421
717
|
if key not in app_state:
|
|
422
718
|
raise KeyError(f"Missing required key in app_state: {key}")
|
|
423
|
-
|
|
424
719
|
if isinstance(subkeys_or_type, list):
|
|
425
720
|
for subkey in subkeys_or_type:
|
|
426
721
|
if subkey not in app_state[key]:
|
|
427
722
|
raise KeyError(f"Missing required key in app_state['{key}']: {subkey}")
|
|
428
723
|
elif not isinstance(app_state[key], subkeys_or_type):
|
|
429
|
-
valid_types =
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
724
|
+
valid_types = ""
|
|
725
|
+
if isinstance(subkeys_or_type, tuple):
|
|
726
|
+
type_names = []
|
|
727
|
+
for t in subkeys_or_type:
|
|
728
|
+
if hasattr(t, "__name__"):
|
|
729
|
+
type_names.append(t.__name__)
|
|
730
|
+
else:
|
|
731
|
+
type_names.append(type(t).__name__)
|
|
732
|
+
valid_types = " or ".join(type_names)
|
|
733
|
+
else:
|
|
734
|
+
if hasattr(subkeys_or_type, "__name__"):
|
|
735
|
+
valid_types = subkeys_or_type.__name__
|
|
736
|
+
else:
|
|
737
|
+
valid_types = type(subkeys_or_type).__name__
|
|
434
738
|
raise ValueError(f"app_state['{key}'] must be of type {valid_types}")
|
|
435
739
|
|
|
740
|
+
# Provide defaults for optional sections when selectors are disabled
|
|
741
|
+
if not show_train_val:
|
|
742
|
+
app_state.setdefault("train_val_split", {"method": "random"})
|
|
743
|
+
if not show_classes:
|
|
744
|
+
app_state.setdefault("classes", [])
|
|
745
|
+
if not show_tags:
|
|
746
|
+
app_state.setdefault("tags", [])
|
|
747
|
+
|
|
436
748
|
model = app_state["model"]
|
|
437
749
|
if model["source"] == "Pretrained models":
|
|
438
750
|
if "model_name" not in model:
|
|
@@ -539,20 +851,21 @@ class TrainGUI:
|
|
|
539
851
|
"""
|
|
540
852
|
if isinstance(app_state, str):
|
|
541
853
|
app_state = sly_json.load_json_file(app_state)
|
|
542
|
-
|
|
543
854
|
app_state = self.validate_app_state(app_state)
|
|
544
855
|
|
|
545
856
|
options = app_state.get("options", {})
|
|
546
857
|
input_settings = app_state.get("input")
|
|
547
|
-
train_val_splits_settings = app_state
|
|
548
|
-
classes_settings = app_state
|
|
858
|
+
train_val_splits_settings = app_state.get("train_val_split", {})
|
|
859
|
+
classes_settings = app_state.get("classes", [])
|
|
860
|
+
tags_settings = app_state.get("tags", [])
|
|
549
861
|
model_settings = app_state["model"]
|
|
550
862
|
hyperparameters_settings = app_state["hyperparameters"]
|
|
551
863
|
|
|
552
864
|
self._init_input(input_settings, options)
|
|
553
|
-
self.
|
|
554
|
-
self.
|
|
555
|
-
self.
|
|
865
|
+
self._init_train_val_splits(train_val_splits_settings, options)
|
|
866
|
+
self._init_classes(classes_settings, options)
|
|
867
|
+
self._init_tags(tags_settings, options)
|
|
868
|
+
self._init_model(model_settings, options)
|
|
556
869
|
self._init_hyperparameters(hyperparameters_settings, options)
|
|
557
870
|
|
|
558
871
|
def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
|
|
@@ -569,13 +882,41 @@ class TrainGUI:
|
|
|
569
882
|
self.input_selector_cb()
|
|
570
883
|
# ----------------------------------------- #
|
|
571
884
|
|
|
572
|
-
def _init_train_val_splits(self, train_val_splits_settings: dict) -> None:
|
|
885
|
+
def _init_train_val_splits(self, train_val_splits_settings: dict, options: dict) -> None:
|
|
573
886
|
"""
|
|
574
887
|
Initialize the train/val splits selector with the given settings.
|
|
575
888
|
|
|
576
889
|
:param train_val_splits_settings: The train/val splits settings.
|
|
577
890
|
:type train_val_splits_settings: dict
|
|
891
|
+
:param options: The application options.
|
|
892
|
+
:type options: dict
|
|
578
893
|
"""
|
|
894
|
+
if self.train_val_splits_selector is None:
|
|
895
|
+
return # Selector disabled by app options
|
|
896
|
+
|
|
897
|
+
if train_val_splits_settings == {}:
|
|
898
|
+
available_methods = self.app_options.get("train_val_splits_methods", [])
|
|
899
|
+
if available_methods == []:
|
|
900
|
+
method = "random"
|
|
901
|
+
train_val_splits_settings = {"method": method, "split": "train", "percent": 80}
|
|
902
|
+
else:
|
|
903
|
+
method = available_methods[0]
|
|
904
|
+
if method == "random":
|
|
905
|
+
train_val_splits_settings = {"method": method, "split": "train", "percent": 80}
|
|
906
|
+
elif method == "tags":
|
|
907
|
+
train_val_splits_settings = {
|
|
908
|
+
"method": method,
|
|
909
|
+
"train_tag": "train",
|
|
910
|
+
"val_tag": "val",
|
|
911
|
+
"untagged_action": "ignore",
|
|
912
|
+
}
|
|
913
|
+
elif method == "datasets":
|
|
914
|
+
train_val_splits_settings = {
|
|
915
|
+
"method": method,
|
|
916
|
+
"train_datasets": [],
|
|
917
|
+
"val_datasets": [],
|
|
918
|
+
}
|
|
919
|
+
|
|
579
920
|
split_method = train_val_splits_settings["method"]
|
|
580
921
|
if split_method == "random":
|
|
581
922
|
split = train_val_splits_settings["split"]
|
|
@@ -596,24 +937,48 @@ class TrainGUI:
|
|
|
596
937
|
)
|
|
597
938
|
self.train_val_splits_selector_cb()
|
|
598
939
|
|
|
599
|
-
def _init_classes(self, classes_settings: list) -> None:
|
|
940
|
+
def _init_classes(self, classes_settings: list, options: dict) -> None:
|
|
600
941
|
"""
|
|
601
942
|
Initialize the classes selector with the given settings.
|
|
602
943
|
|
|
603
944
|
:param classes_settings: The classes settings.
|
|
604
945
|
:type classes_settings: list
|
|
946
|
+
:param options: The application options.
|
|
947
|
+
:type options: dict
|
|
605
948
|
"""
|
|
949
|
+
if self.classes_selector is None:
|
|
950
|
+
return # Selector disabled by app options
|
|
951
|
+
|
|
606
952
|
# Set Classes
|
|
607
953
|
self.classes_selector.set_classes(classes_settings)
|
|
608
954
|
self.classes_selector_cb()
|
|
609
955
|
# ----------------------------------------- #
|
|
610
956
|
|
|
611
|
-
def
|
|
957
|
+
def _init_tags(self, tags_settings: list, options: dict) -> None:
|
|
958
|
+
"""
|
|
959
|
+
Initialize the tags selector with the given settings.
|
|
960
|
+
|
|
961
|
+
:param tags_settings: The tags settings.
|
|
962
|
+
:type tags_settings: list
|
|
963
|
+
:param options: The application options.
|
|
964
|
+
:type options: dict
|
|
965
|
+
"""
|
|
966
|
+
if self.tags_selector is None:
|
|
967
|
+
return # Selector disabled by app options
|
|
968
|
+
|
|
969
|
+
# Set Tags
|
|
970
|
+
self.tags_selector.set_tags(tags_settings)
|
|
971
|
+
self.tags_selector_cb()
|
|
972
|
+
# ----------------------------------------- #
|
|
973
|
+
|
|
974
|
+
def _init_model(self, model_settings: dict, options: dict) -> None:
|
|
612
975
|
"""
|
|
613
976
|
Initialize the model selector with the given settings.
|
|
614
977
|
|
|
615
978
|
:param model_settings: The model settings.
|
|
616
979
|
:type model_settings: dict
|
|
980
|
+
:param options: The application options.
|
|
981
|
+
:type options: dict
|
|
617
982
|
"""
|
|
618
983
|
|
|
619
984
|
# Pretrained
|