supervisely 6.73.358__py3-none-any.whl → 6.73.360__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,13 +6,13 @@ training workflows in Supervisely.
6
6
  """
7
7
 
8
8
  from os import environ
9
- from typing import Union
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
10
 
11
11
  import supervisely.io.env as sly_env
12
12
  import supervisely.io.json as sly_json
13
13
  from supervisely import Api, ProjectMeta
14
14
  from supervisely._utils import is_production
15
- from supervisely.app.widgets import Stepper, Widget
15
+ from supervisely.app.widgets import Button, Card, Stepper, Widget
16
16
  from supervisely.geometry.bitmap import Bitmap
17
17
  from supervisely.geometry.graph import GraphNodes
18
18
  from supervisely.geometry.polygon import Polygon
@@ -22,6 +22,7 @@ from supervisely.nn.training.gui.classes_selector import ClassesSelector
22
22
  from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
23
23
  from supervisely.nn.training.gui.input_selector import InputSelector
24
24
  from supervisely.nn.training.gui.model_selector import ModelSelector
25
+ from supervisely.nn.training.gui.tags_selector import TagsSelector
25
26
  from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplitsSelector
26
27
  from supervisely.nn.training.gui.training_artifacts import TrainingArtifacts
27
28
  from supervisely.nn.training.gui.training_logs import TrainingLogs
@@ -30,6 +31,186 @@ from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_clic
30
31
  from supervisely.nn.utils import ModelSource, RuntimeType
31
32
 
32
33
 
34
class StepFlow:
    """
    Manages the flow of steps in the GUI, including wrappers and button handlers.

    Typical usage: register every step with :meth:`register_step`, link the steps
    with :meth:`set_next_steps`, optionally attach extra actions with
    :meth:`add_on_select_actions`, then call :meth:`build` once.

    Only steps registered with a ``position`` take part in wrapper creation and in
    button handler setup; steps without a button never get a wrapper.
    """

    def __init__(self, stepper: "Stepper", app_options: Dict[str, Any]):
        """
        Initializes the step manager.

        :param stepper: Stepper object for step navigation
        :param app_options: Application options. The ``"collapsable"`` key
            (default ``False``) is read here and later passed to every wrapper
            as part of ``collapse_card``.
        """
        self.stepper = stepper
        self.app_options = app_options
        self.collapsable = app_options.get("collapsable", False)
        self.steps: Dict[str, Dict[str, Any]] = {}  # Step configuration by step name
        self.step_sequence: List[Optional[str]] = []  # Step names by position, None = empty slot

    def register_step(
        self,
        name: str,
        card: "Card",
        button: Optional["Button"] = None,
        widgets_to_disable: Optional[List["Widget"]] = None,
        validation_text: Optional["Widget"] = None,
        validation_func: Optional[Callable] = None,
        position: Optional[int] = None,
    ) -> "StepFlow":
        """
        Registers a step in the GUI.

        Registering an already known ``name`` replaces its configuration (including
        previously set next steps and select actions). If another step already
        occupies ``position``, it is replaced in the sequence.

        :param name: Unique step name
        :param card: Step card widget
        :param button: Button for proceeding to the next step (optional)
        :param widgets_to_disable: Widgets to disable during validation (optional)
        :param validation_text: Widget for displaying validation text (optional)
        :param validation_func: Validation function (optional)
        :param position: Step position in the sequence (starting from 0)
        :return: Current StepFlow object for method chaining
        :raises ValueError: If ``position`` is negative
        """
        # Validate before touching any state so a bad call leaves the flow unchanged.
        # A negative index would otherwise silently overwrite the last slot.
        if position is not None and position < 0:
            raise ValueError(f"Step position must be non-negative, got {position}")

        self.steps[name] = {
            "card": card,
            "button": button,
            "widgets_to_disable": widgets_to_disable or [],
            "validation_text": validation_text,
            "validation_func": validation_func,
            "position": position,
            "next_steps": [],
            "on_select_click": [],
            "on_reselect_click": [],
            "wrapper": None,
            "has_button": button is not None,
        }

        if position is not None:
            # When a step is re-registered at a new position, free its old slot.
            # Otherwise the name would appear twice in the sequence: the step would
            # be wrapped twice and the ordinals of later steps would shift.
            for index, existing in enumerate(self.step_sequence):
                if existing == name:
                    self.step_sequence[index] = None

            # Pad with empty slots so that ``position`` is a valid index.
            missing = position + 1 - len(self.step_sequence)
            if missing > 0:
                self.step_sequence.extend([None] * missing)
            self.step_sequence[position] = name

        return self

    def set_next_steps(self, step_name: str, next_steps: List[str]) -> "StepFlow":
        """
        Sets the list of next steps for the given step.

        Unknown ``step_name`` values are ignored, so the call can be chained safely
        for steps that were not registered.

        :param step_name: Current step name
        :param next_steps: List of names of the next steps
        :return: Current StepFlow object for method chaining
        """
        if step_name in self.steps:
            # Store a copy so later changes to the caller's list do not alter the flow.
            self.steps[step_name]["next_steps"] = list(next_steps)
        return self

    def add_on_select_actions(
        self, step_name: str, actions: List[Callable], is_reselect: bool = False
    ) -> "StepFlow":
        """
        Adds actions to be executed when a step is selected/reselected.

        Actions are appended to the ones already stored for the step. Unknown
        ``step_name`` values are ignored.

        :param step_name: Step name
        :param actions: List of functions to execute
        :param is_reselect: True if these are actions for reselection, otherwise False
        :return: Current StepFlow object for method chaining
        """
        if step_name in self.steps:
            key = "on_reselect_click" if is_reselect else "on_select_click"
            self.steps[step_name][key].extend(actions)
        return self

    def build_wrappers(self) -> Dict[str, Callable]:
        """
        Creates wrappers for all steps based on established dependencies.

        For every positioned step that has a button, a wrapper is created with
        ``wrap_button_click``. The cards of its registered next steps are passed as
        ``cards_to_unlock`` and the wrapper of the first next step that has both a
        button and a wrapper is passed as ``callback`` (``None`` if there is none).

        :return: Dictionary with created wrappers by step name
        """
        valid_sequence = [s for s in self.step_sequence if s is not None and s in self.steps]

        # Walk the sequence backwards so that the wrapper of a later step already
        # exists when an earlier step looks it up as its callback.
        for step_name in reversed(valid_sequence):
            step = self.steps[step_name]
            if not step["has_button"]:
                continue  # Steps without a button never get a wrapper

            known_next_steps = [
                self.steps[next_name]
                for next_name in step["next_steps"]
                if next_name in self.steps
            ]
            cards_to_unlock = [next_step["card"] for next_step in known_next_steps]
            callback = next(
                (
                    next_step["wrapper"]
                    for next_step in known_next_steps
                    if next_step.get("wrapper") and next_step["has_button"]
                ),
                None,
            )

            step["wrapper"] = wrap_button_click(
                button=step["button"],
                cards_to_unlock=cards_to_unlock,
                widgets_to_disable=step["widgets_to_disable"],
                callback=callback,
                validation_text=step["validation_text"],
                validation_func=step["validation_func"],
                on_select_click=step["on_select_click"],
                on_reselect_click=step["on_reselect_click"],
                collapse_card=(step["card"], self.collapsable),
            )

        return {
            name: step["wrapper"]
            for name, step in self.steps.items()
            if step.get("wrapper") and step["has_button"]
        }

    def _make_click_handler(self, button: "Button", wrapper: Callable, next_pos: int) -> Callable:
        """
        Creates a click handler bound to one step.

        The values are bound through parameters so that handlers created in a loop
        do not share the loop variables.

        :param button: Button of the step
        :param wrapper: Wrapper of the step, called first on click
        :param next_pos: Value forwarded to ``set_stepper_step`` as ``next_pos``
        :return: Handler without arguments
        """

        def handler():
            wrapper()
            set_stepper_step(self.stepper, button, next_pos=next_pos)

        return handler

    def setup_button_handlers(self) -> None:
        """
        Sets up handlers for buttons of all steps.

        Each handler calls the step wrapper and then ``set_stepper_step`` with the
        1-based ordinal of the following step. Ordinals count only registered steps,
        so empty slots in the sequence are skipped.
        """
        valid_sequence = [s for s in self.step_sequence if s is not None and s in self.steps]
        positions = {step_name: pos for pos, step_name in enumerate(valid_sequence, start=1)}

        for step_name, step in self.steps.items():
            if step_name not in positions or not step.get("wrapper") or not step["has_button"]:
                continue

            button = step["button"]
            next_position = positions[step_name] + 1
            button.click(self._make_click_handler(button, step["wrapper"], next_position))

    def build(self) -> Dict[str, Callable]:
        """
        Performs the complete setup of the step system.

        Builds the wrappers first and then attaches the button handlers that use them.

        :return: Dictionary with created wrappers by step name
        """
        wrappers = self.build_wrappers()
        self.setup_button_handlers()
        return wrappers
212
+
213
+
33
214
  class TrainGUI:
34
215
  """
35
216
  A class representing the GUI for training workflows.
@@ -74,6 +255,9 @@ class TrainGUI:
74
255
  self.workspace_id = sly_env.workspace_id(raise_not_found=False)
75
256
  self.project_id = sly_env.project_id()
76
257
  self.project_info = self._api.project.get_info_by_id(self.project_id)
258
+ if self.project_info.type is None:
259
+ raise ValueError(f"Project with ID: '{self.project_id}' does not exist or was archived")
260
+
77
261
  self.project_meta = ProjectMeta.from_json(self._api.project.get_meta(self.project_id))
78
262
 
79
263
  if self.workspace_id is None:
@@ -83,51 +267,78 @@ class TrainGUI:
83
267
  self.team_id = self.project_info.team_id
84
268
  environ["TEAM_ID"] = str(self.team_id)
85
269
 
86
- # 1. Project selection + Train/val split
270
+ # ---------- Parse selector options ----------
271
+ self._classes_selector_opts = self.app_options.get("classes_selector", {})
272
+ self._tags_selector_opts = self.app_options.get("tags_selector", {})
273
+ self._train_val_splits_opts = self.app_options.get("train_val_splits_selector", {})
274
+ self._model_selector_opts = self.app_options.get("model_selector", {})
275
+
276
+ self.show_classes_selector = self._classes_selector_opts.get("enabled", True)
277
+ self.show_tags_selector = self._tags_selector_opts.get("enabled", False)
278
+ self.show_train_val_splits_selector = self._train_val_splits_opts.get("enabled", True)
279
+ self.show_model_selector = self._model_selector_opts.get("enabled", True)
280
+
281
+ # Ensure train_val_splits_methods compatibility
282
+ self._train_val_methods = self._train_val_splits_opts.get("methods", [])
283
+ # --------------------------------------------------------- #
284
+
285
+ # ------------------------------------------------- #
286
+ self.steps = []
287
+
288
+ # 1. Project selection
87
289
  self.input_selector = InputSelector(self.project_info, self.app_options)
88
- # 2. Select train val splits
89
- self.train_val_splits_selector = TrainValSplitsSelector(
90
- self._api, self.project_id, self.app_options
91
- )
92
- # 3. Select classes
93
- self.classes_selector = ClassesSelector(self.project_id, [], self.app_options)
94
- # 4. Model selection
290
+ self.steps.append(self.input_selector.card)
291
+
292
+ # 2. Train/val split
293
+ self.train_val_splits_selector = None
294
+ if self.show_train_val_splits_selector:
295
+ self.train_val_splits_selector = TrainValSplitsSelector(
296
+ self._api, self.project_id, self.app_options
297
+ )
298
+ self.steps.append(self.train_val_splits_selector.card)
299
+
300
+ # 3. Select Classes
301
+ self.classes_selector = None
302
+ if self.show_classes_selector:
303
+ self.classes_selector = ClassesSelector(self.project_id, [], self.app_options)
304
+ self.steps.append(self.classes_selector.card)
305
+
306
+ # 4. Select Tags
307
+ self.tags_selector = None
308
+ if self.show_tags_selector:
309
+ self.tags_selector = TagsSelector(self.project_id, [], self.app_options)
310
+ self.steps.append(self.tags_selector.card)
311
+
312
+ # 5. Model selection
95
313
  self.model_selector = ModelSelector(
96
314
  self._api, self.framework_name, self.models, self.app_options
97
315
  )
98
- # 5. Training parameters (yaml), scheduler preview
316
+ if self.show_model_selector:
317
+ self.steps.append(self.model_selector.card)
318
+
319
+ # 6. Training parameters (yaml)
99
320
  self.hyperparameters_selector = HyperparametersSelector(
100
321
  self.hyperparameters, self.app_options
101
322
  )
102
- # 6. Start Train
323
+ self.steps.append(self.hyperparameters_selector.card)
324
+
325
+ # 7. Start Training
103
326
  self.training_process = TrainingProcess(self.app_options)
327
+ self.steps.append(self.training_process.card)
104
328
 
105
- # 7. Training logs
329
+ # 8. Training logs
106
330
  self.training_logs = TrainingLogs(self.app_options)
331
+ self.steps.append(self.training_logs.card)
107
332
 
108
- # 8. Training Artifacts
333
+ # 9. Training Artifacts
109
334
  self.training_artifacts = TrainingArtifacts(self._api, self.app_options)
335
+ self.steps.append(self.training_artifacts.card)
110
336
 
111
337
  # Stepper layout
112
- self.steps = [
113
- self.input_selector.card,
114
- self.train_val_splits_selector.card,
115
- self.classes_selector.card,
116
- self.model_selector.card,
117
- self.hyperparameters_selector.card,
118
- self.training_process.card,
119
- self.training_logs.card,
120
- self.training_artifacts.card,
121
- ]
122
- self.stepper = Stepper(
123
- widgets=self.steps,
124
- )
338
+ self.stepper = Stepper(widgets=self.steps)
125
339
  # ------------------------------------------------- #
126
340
 
127
341
  # Button utils
128
- def update_classes_table():
129
- pass
130
-
131
342
  def disable_hyperparams_editor():
132
343
  if self.hyperparameters_selector.editor.readonly:
133
344
  self.hyperparameters_selector.editor.readonly = False
@@ -161,16 +372,28 @@ class TrainGUI:
161
372
  TaskType.SEMANTIC_SEGMENTATION,
162
373
  ]:
163
374
  return shape == Polygon.geometry_name()
164
- return
165
-
166
- data = self.classes_selector.classes_table._table_data
167
- selected_classes = set(
168
- self.classes_selector.classes_table.get_selected_classes()
169
- )
170
- empty = set(
171
- r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
172
- )
173
- need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
375
+ return False
376
+
377
+ if self.classes_selector is not None:
378
+ data = self.classes_selector.classes_table._table_data
379
+ selected_classes = set(
380
+ self.classes_selector.classes_table.get_selected_classes()
381
+ )
382
+ empty = set(
383
+ r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
384
+ )
385
+ need_convert = set(
386
+ r[0]["data"] for r in data if _need_convert(r[1]["data"])
387
+ )
388
+ else:
389
+ # Set project meta classes when classes selector is disabled
390
+ selected_classes = set(cls.name for cls in self.project_meta.obj_classes)
391
+ need_convert = set(
392
+ obj_class.name
393
+ for obj_class in self.project_meta.obj_classes
394
+ if _need_convert(obj_class.geometry_type)
395
+ )
396
+ empty = set()
174
397
 
175
398
  if need_convert.intersection(selected_classes - empty):
176
399
  self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
@@ -183,7 +406,10 @@ class TrainGUI:
183
406
 
184
407
  def validate_class_shape_for_model_task():
185
408
  task_type = self.model_selector.get_selected_task_type()
186
- classes = self.classes_selector.get_selected_classes()
409
+ if self.classes_selector is not None:
410
+ classes = self.classes_selector.get_selected_classes()
411
+ else:
412
+ classes = list(self.project_meta.obj_classes.keys())
187
413
 
188
414
  required_geometries = {
189
415
  TaskType.INSTANCE_SEGMENTATION: {Polygon, Bitmap},
@@ -216,144 +442,189 @@ class TrainGUI:
216
442
 
217
443
  # ------------------------------------------------- #
218
444
 
219
- # Wrappers
220
- self.training_process_cb = wrap_button_click(
221
- button=self.hyperparameters_selector.button,
222
- cards_to_unlock=[self.training_logs.card],
223
- widgets_to_disable=self.training_process.widgets_to_disable,
224
- callback=None,
225
- validation_text=self.training_process.validator_text,
226
- validation_func=self.training_process.validate_step,
227
- )
445
+ self.step_flow = StepFlow(self.stepper, self.app_options)
446
+ position = 0
228
447
 
229
- self.hyperparameters_selector_cb = wrap_button_click(
230
- button=self.hyperparameters_selector.button,
231
- cards_to_unlock=[self.training_process.card],
232
- widgets_to_disable=self.hyperparameters_selector.widgets_to_disable,
233
- callback=self.training_process_cb,
234
- validation_text=self.hyperparameters_selector.validator_text,
235
- validation_func=self.hyperparameters_selector.validate_step,
236
- on_select_click=[disable_hyperparams_editor],
237
- on_reselect_click=[disable_hyperparams_editor],
238
- collapse_card=(self.hyperparameters_selector.card, self.collapsable),
448
+ # 1. Input selector
449
+ self.step_flow.register_step(
450
+ "input_selector",
451
+ self.input_selector.card,
452
+ self.input_selector.button,
453
+ self.input_selector.widgets_to_disable,
454
+ self.input_selector.validator_text,
455
+ self.input_selector.validate_step,
456
+ position=position,
239
457
  )
458
+ position += 1
459
+
460
+ # 2. Train/Val splits selector
461
+ if self.show_train_val_splits_selector and self.train_val_splits_selector is not None:
462
+ self.step_flow.register_step(
463
+ "train_val_splits",
464
+ self.train_val_splits_selector.card,
465
+ self.train_val_splits_selector.button,
466
+ self.train_val_splits_selector.widgets_to_disable,
467
+ self.train_val_splits_selector.validator_text,
468
+ self.train_val_splits_selector.validate_step,
469
+ position=position,
470
+ )
471
+ position += 1
472
+
473
+ # 3. Classes selector
474
+ if self.show_classes_selector and self.classes_selector is not None:
475
+ self.step_flow.register_step(
476
+ "classes_selector",
477
+ self.classes_selector.card,
478
+ self.classes_selector.button,
479
+ self.classes_selector.widgets_to_disable,
480
+ self.classes_selector.validator_text,
481
+ self.classes_selector.validate_step,
482
+ position=position,
483
+ )
484
+ position += 1
485
+
486
+ # 4. Tags selector
487
+ if self.show_tags_selector and self.tags_selector is not None:
488
+ self.step_flow.register_step(
489
+ "tags_selector",
490
+ self.tags_selector.card,
491
+ self.tags_selector.button,
492
+ self.tags_selector.widgets_to_disable,
493
+ self.tags_selector.validator_text,
494
+ self.tags_selector.validate_step,
495
+ position=position,
496
+ )
497
+ position += 1
498
+
499
+ # 5. Model selector
500
+ if self.show_model_selector:
501
+ self.step_flow.register_step(
502
+ "model_selector",
503
+ self.model_selector.card,
504
+ self.model_selector.button,
505
+ self.model_selector.widgets_to_disable,
506
+ self.model_selector.validator_text,
507
+ self.model_selector.validate_step,
508
+ position=position,
509
+ ).add_on_select_actions(
510
+ "model_selector",
511
+ [
512
+ set_experiment_name,
513
+ need_convert_class_shapes,
514
+ validate_class_shape_for_model_task,
515
+ ],
516
+ )
517
+ position += 1
240
518
 
241
- self.model_selector_cb = wrap_button_click(
242
- button=self.model_selector.button,
243
- cards_to_unlock=[self.hyperparameters_selector.card],
244
- widgets_to_disable=self.model_selector.widgets_to_disable,
245
- callback=self.hyperparameters_selector_cb,
246
- validation_text=self.model_selector.validator_text,
247
- validation_func=self.model_selector.validate_step,
248
- on_select_click=[
249
- set_experiment_name,
250
- need_convert_class_shapes,
251
- validate_class_shape_for_model_task,
252
- ],
253
- collapse_card=(self.model_selector.card, self.collapsable),
519
+ # 6. Hyperparameters selector
520
+ self.step_flow.register_step(
521
+ "hyperparameters_selector",
522
+ self.hyperparameters_selector.card,
523
+ self.hyperparameters_selector.button,
524
+ self.hyperparameters_selector.widgets_to_disable,
525
+ self.hyperparameters_selector.validator_text,
526
+ self.hyperparameters_selector.validate_step,
527
+ position=position,
528
+ ).add_on_select_actions("hyperparameters_selector", [disable_hyperparams_editor])
529
+ self.step_flow.add_on_select_actions(
530
+ "hyperparameters_selector", [disable_hyperparams_editor], is_reselect=True
254
531
  )
532
+ position += 1
255
533
 
256
- self.classes_selector_cb = wrap_button_click(
257
- button=self.classes_selector.button,
258
- cards_to_unlock=[self.model_selector.card],
259
- widgets_to_disable=self.classes_selector.widgets_to_disable,
260
- callback=self.model_selector_cb,
261
- validation_text=self.classes_selector.validator_text,
262
- validation_func=self.classes_selector.validate_step,
263
- collapse_card=(self.classes_selector.card, self.collapsable),
534
+ # 7. Training process
535
+ self.step_flow.register_step(
536
+ "training_process",
537
+ self.training_process.card,
538
+ None,
539
+ self.training_process.widgets_to_disable,
540
+ self.training_process.validator_text,
541
+ self.training_process.validate_step,
542
+ position=position,
264
543
  )
544
+ position += 1
265
545
 
266
- self.train_val_splits_selector_cb = wrap_button_click(
267
- button=self.train_val_splits_selector.button,
268
- cards_to_unlock=[self.classes_selector.card],
269
- widgets_to_disable=self.train_val_splits_selector.widgets_to_disable,
270
- callback=self.classes_selector_cb,
271
- validation_text=self.train_val_splits_selector.validator_text,
272
- validation_func=self.train_val_splits_selector.validate_step,
273
- collapse_card=(self.train_val_splits_selector.card, self.collapsable),
546
+ # 8. Training logs
547
+ self.step_flow.register_step(
548
+ "training_logs",
549
+ self.training_logs.card,
550
+ None,
551
+ self.training_logs.widgets_to_disable,
552
+ self.training_logs.validator_text,
553
+ self.training_logs.validate_step,
554
+ position=position,
274
555
  )
556
+ position += 1
275
557
 
276
- self.input_selector_cb = wrap_button_click(
277
- button=self.input_selector.button,
278
- cards_to_unlock=[self.train_val_splits_selector.card],
279
- widgets_to_disable=self.input_selector.widgets_to_disable,
280
- callback=self.train_val_splits_selector_cb,
281
- validation_text=self.input_selector.validator_text,
282
- validation_func=self.input_selector.validate_step,
283
- on_select_click=[update_classes_table],
284
- collapse_card=(self.input_selector.card, self.collapsable),
558
+ # 9. Training artifacts
559
+ self.step_flow.register_step(
560
+ "training_artifacts",
561
+ self.training_artifacts.card,
562
+ None,
563
+ self.training_artifacts.widgets_to_disable,
564
+ self.training_artifacts.validator_text,
565
+ self.training_artifacts.validate_step,
566
+ position=position,
285
567
  )
286
- # ------------------------------------------------- #
287
568
 
288
- # Main Buttons
569
+ # Set dependencies between steps
570
+ has_train_val_splits = (
571
+ self.show_train_val_splits_selector and self.train_val_splits_selector is not None
572
+ )
573
+ has_classes_selector = self.show_classes_selector and self.classes_selector is not None
574
+ has_tags_selector = self.show_tags_selector and self.tags_selector is not None
575
+
576
+ # Set step dependency chain
577
+ # 1. Input selector
578
+ prev_step = "input_selector"
579
+ if has_train_val_splits:
580
+ self.step_flow.set_next_steps(prev_step, ["train_val_splits"])
581
+ prev_step = "train_val_splits"
582
+ if has_classes_selector:
583
+ self.step_flow.set_next_steps(prev_step, ["classes_selector"])
584
+ prev_step = "classes_selector"
585
+ if has_tags_selector:
586
+ self.step_flow.set_next_steps(prev_step, ["tags_selector"])
587
+ prev_step = "tags_selector"
588
+
589
+ if self.show_model_selector and self.model_selector is not None:
590
+ self.step_flow.set_next_steps(prev_step, ["model_selector"])
591
+ # Model selector -> hyperparameters
592
+ self.step_flow.set_next_steps("model_selector", ["hyperparameters_selector"])
593
+ prev_step = "model_selector"
594
+ else:
595
+ self.step_flow.set_next_steps(prev_step, ["hyperparameters_selector"])
289
596
 
290
- # Define outside. Used by user in app
291
- # @self.training_process.start_button.click
292
- # def start_training():
293
- # pass
597
+ # 6. Hyperparameters selector -> 7. Training process
598
+ self.step_flow.set_next_steps("hyperparameters_selector", ["training_process"])
294
599
 
295
- # @self.training_process.stop_button.click
296
- # def stop_training():
297
- # pass
600
+ # 7. Training process -> 8. Training logs
601
+ self.step_flow.set_next_steps("training_process", ["training_logs"])
298
602
 
603
+ # 8. Training logs -> 9. Training artifacts
604
+ self.step_flow.set_next_steps("training_logs", ["training_artifacts"])
299
605
  # ------------------------------------------------- #
300
606
 
301
- # Select Buttons
302
- @self.hyperparameters_selector.button.click
303
- def select_hyperparameters():
304
- self.hyperparameters_selector_cb()
305
- set_stepper_step(
306
- self.stepper,
307
- self.hyperparameters_selector.button,
308
- next_pos=6,
309
- )
310
-
311
- @self.model_selector.button.click
312
- def select_model():
313
- self.model_selector_cb()
314
- set_stepper_step(
315
- self.stepper,
316
- self.model_selector.button,
317
- next_pos=5,
318
- )
319
-
320
- @self.classes_selector.button.click
321
- def select_classes():
322
- self.classes_selector_cb()
323
- set_stepper_step(
324
- self.stepper,
325
- self.classes_selector.button,
326
- next_pos=4,
327
- )
328
-
329
- @self.train_val_splits_selector.button.click
330
- def select_train_val_splits():
331
- self.train_val_splits_selector_cb()
332
- set_stepper_step(
333
- self.stepper,
334
- self.train_val_splits_selector.button,
335
- next_pos=3,
336
- )
337
-
338
- @self.input_selector.button.click
339
- def select_input():
340
- self.input_selector_cb()
341
- set_stepper_step(
342
- self.stepper,
343
- self.input_selector.button,
344
- next_pos=2,
345
- )
346
-
607
+ # Create all wrappers and set button handlers
608
+ wrappers = self.step_flow.build()
609
+
610
+ self.input_selector_cb = wrappers.get("input_selector")
611
+ self.train_val_splits_selector_cb = wrappers.get("train_val_splits")
612
+ self.classes_selector_cb = wrappers.get("classes_selector")
613
+ self.tags_selector_cb = wrappers.get("tags_selector")
614
+ self.model_selector_cb = wrappers.get("model_selector")
615
+ self.hyperparameters_selector_cb = wrappers.get("hyperparameters_selector")
616
+ self.training_process_cb = wrappers.get("training_process")
617
+ self.training_logs_cb = wrappers.get("training_logs")
618
+ self.training_artifacts_cb = wrappers.get("training_artifacts")
347
619
  # ------------------------------------------------- #
348
620
 
349
- # Other Buttons
621
+ # Other handlers
350
622
  if self.app_options.get("show_logs_in_gui", False):
351
623
 
352
624
  @self.training_logs.logs_button.click
353
625
  def show_logs():
354
626
  self.training_logs.toggle_logs()
355
627
 
356
- # Other handlers
357
628
  if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
358
629
 
359
630
  @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
@@ -365,6 +636,8 @@ class TrainGUI:
365
636
 
366
637
  self.layout: Widget = self.stepper
367
638
 
639
+ # (the duplicate block was moved above and removed from here)
640
+
368
641
  def set_next_step(self):
369
642
  current_step = self.stepper.get_active_step()
370
643
  self.stepper.set_active_step(current_step + 1)
@@ -383,21 +656,35 @@ class TrainGUI:
383
656
  """
384
657
  Makes all select buttons in the GUI available for interaction.
385
658
  """
386
- self.input_selector.button.enable()
387
- self.train_val_splits_selector.button.enable()
388
- self.classes_selector.button.enable()
389
- self.model_selector.button.enable()
390
- self.hyperparameters_selector.button.enable()
659
+ if self.input_selector is not None:
660
+ self.input_selector.button.enable()
661
+ if self.train_val_splits_selector is not None:
662
+ self.train_val_splits_selector.button.enable()
663
+ if self.classes_selector is not None:
664
+ self.classes_selector.button.enable()
665
+ if self.tags_selector is not None:
666
+ self.tags_selector.button.enable()
667
+ if self.model_selector is not None:
668
+ self.model_selector.button.enable()
669
+ if self.hyperparameters_selector is not None:
670
+ self.hyperparameters_selector.button.enable()
391
671
 
392
672
  def disable_select_buttons(self):
393
673
  """
394
674
  Makes all select buttons in the GUI unavailable for interaction.
395
675
  """
396
- self.input_selector.button.disable()
397
- self.train_val_splits_selector.button.disable()
398
- self.classes_selector.button.disable()
399
- self.model_selector.button.disable()
400
- self.hyperparameters_selector.button.disable()
676
+ if self.input_selector is not None:
677
+ self.input_selector.button.disable()
678
+ if self.train_val_splits_selector is not None:
679
+ self.train_val_splits_selector.button.disable()
680
+ if self.classes_selector is not None:
681
+ self.classes_selector.button.disable()
682
+ if self.tags_selector is not None:
683
+ self.tags_selector.button.disable()
684
+ if self.model_selector is not None:
685
+ self.model_selector.button.disable()
686
+ if self.hyperparameters_selector is not None:
687
+ self.hyperparameters_selector.button.disable()
401
688
 
402
689
  # Set GUI from config
403
690
  def validate_app_state(self, app_state: dict) -> dict:
@@ -410,29 +697,54 @@ class TrainGUI:
410
697
  if not isinstance(app_state, dict):
411
698
  raise ValueError("app_state must be a dictionary")
412
699
 
413
- required_keys = {
414
- "train_val_split": ["method"],
415
- "classes": list,
700
+ show_train_val = self.show_train_val_splits_selector
701
+ show_classes = self.show_classes_selector
702
+ show_tags = self.show_tags_selector
703
+
704
+ # Basic required keys always needed
705
+ base_required = {
416
706
  "model": ["source"],
417
- "hyperparameters": (dict, str), # Allowing dict or str for hyperparameters
707
+ "hyperparameters": (dict, str),
418
708
  }
419
-
420
- for key, subkeys_or_type in required_keys.items():
709
+ if show_train_val:
710
+ base_required["train_val_split"] = ["method"]
711
+ if show_classes:
712
+ base_required["classes"] = list
713
+ if show_tags:
714
+ base_required["tags"] = list
715
+
716
+ for key, subkeys_or_type in base_required.items():
421
717
  if key not in app_state:
422
718
  raise KeyError(f"Missing required key in app_state: {key}")
423
-
424
719
  if isinstance(subkeys_or_type, list):
425
720
  for subkey in subkeys_or_type:
426
721
  if subkey not in app_state[key]:
427
722
  raise KeyError(f"Missing required key in app_state['{key}']: {subkey}")
428
723
  elif not isinstance(app_state[key], subkeys_or_type):
429
- valid_types = (
430
- " or ".join([t.__name__ for t in subkeys_or_type])
431
- if isinstance(subkeys_or_type, tuple)
432
- else subkeys_or_type.__name__
433
- )
724
+ valid_types = ""
725
+ if isinstance(subkeys_or_type, tuple):
726
+ type_names = []
727
+ for t in subkeys_or_type:
728
+ if hasattr(t, "__name__"):
729
+ type_names.append(t.__name__)
730
+ else:
731
+ type_names.append(type(t).__name__)
732
+ valid_types = " or ".join(type_names)
733
+ else:
734
+ if hasattr(subkeys_or_type, "__name__"):
735
+ valid_types = subkeys_or_type.__name__
736
+ else:
737
+ valid_types = type(subkeys_or_type).__name__
434
738
  raise ValueError(f"app_state['{key}'] must be of type {valid_types}")
435
739
 
740
+ # Provide defaults for optional sections when selectors are disabled
741
+ if not show_train_val:
742
+ app_state.setdefault("train_val_split", {"method": "random"})
743
+ if not show_classes:
744
+ app_state.setdefault("classes", [])
745
+ if not show_tags:
746
+ app_state.setdefault("tags", [])
747
+
436
748
  model = app_state["model"]
437
749
  if model["source"] == "Pretrained models":
438
750
  if "model_name" not in model:
@@ -539,20 +851,21 @@ class TrainGUI:
539
851
  """
540
852
  if isinstance(app_state, str):
541
853
  app_state = sly_json.load_json_file(app_state)
542
-
543
854
  app_state = self.validate_app_state(app_state)
544
855
 
545
856
  options = app_state.get("options", {})
546
857
  input_settings = app_state.get("input")
547
- train_val_splits_settings = app_state["train_val_split"]
548
- classes_settings = app_state["classes"]
858
+ train_val_splits_settings = app_state.get("train_val_split", {})
859
+ classes_settings = app_state.get("classes", [])
860
+ tags_settings = app_state.get("tags", [])
549
861
  model_settings = app_state["model"]
550
862
  hyperparameters_settings = app_state["hyperparameters"]
551
863
 
552
864
  self._init_input(input_settings, options)
553
- self._init_classes(classes_settings)
554
- self._init_train_val_splits(train_val_splits_settings)
555
- self._init_model(model_settings)
865
+ self._init_train_val_splits(train_val_splits_settings, options)
866
+ self._init_classes(classes_settings, options)
867
+ self._init_tags(tags_settings, options)
868
+ self._init_model(model_settings, options)
556
869
  self._init_hyperparameters(hyperparameters_settings, options)
557
870
 
558
871
  def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
@@ -569,13 +882,41 @@ class TrainGUI:
569
882
  self.input_selector_cb()
570
883
  # ----------------------------------------- #
571
884
 
572
- def _init_train_val_splits(self, train_val_splits_settings: dict) -> None:
885
+ def _init_train_val_splits(self, train_val_splits_settings: dict, options: dict) -> None:
573
886
  """
574
887
  Initialize the train/val splits selector with the given settings.
575
888
 
576
889
  :param train_val_splits_settings: The train/val splits settings.
577
890
  :type train_val_splits_settings: dict
891
+ :param options: The application options.
892
+ :type options: dict
578
893
  """
894
+ if self.train_val_splits_selector is None:
895
+ return # Selector disabled by app options
896
+
897
+ if train_val_splits_settings == {}:
898
+ available_methods = self.app_options.get("train_val_splits_methods", [])
899
+ if available_methods == []:
900
+ method = "random"
901
+ train_val_splits_settings = {"method": method, "split": "train", "percent": 80}
902
+ else:
903
+ method = available_methods[0]
904
+ if method == "random":
905
+ train_val_splits_settings = {"method": method, "split": "train", "percent": 80}
906
+ elif method == "tags":
907
+ train_val_splits_settings = {
908
+ "method": method,
909
+ "train_tag": "train",
910
+ "val_tag": "val",
911
+ "untagged_action": "ignore",
912
+ }
913
+ elif method == "datasets":
914
+ train_val_splits_settings = {
915
+ "method": method,
916
+ "train_datasets": [],
917
+ "val_datasets": [],
918
+ }
919
+
579
920
  split_method = train_val_splits_settings["method"]
580
921
  if split_method == "random":
581
922
  split = train_val_splits_settings["split"]
@@ -596,24 +937,48 @@ class TrainGUI:
596
937
  )
597
938
  self.train_val_splits_selector_cb()
598
939
 
599
- def _init_classes(self, classes_settings: list) -> None:
940
+ def _init_classes(self, classes_settings: list, options: dict) -> None:
600
941
  """
601
942
  Initialize the classes selector with the given settings.
602
943
 
603
944
  :param classes_settings: The classes settings.
604
945
  :type classes_settings: list
946
+ :param options: The application options.
947
+ :type options: dict
605
948
  """
949
+ if self.classes_selector is None:
950
+ return # Selector disabled by app options
951
+
606
952
  # Set Classes
607
953
  self.classes_selector.set_classes(classes_settings)
608
954
  self.classes_selector_cb()
609
955
  # ----------------------------------------- #
610
956
 
611
- def _init_model(self, model_settings: dict) -> None:
957
+ def _init_tags(self, tags_settings: list, options: dict) -> None:
958
+ """
959
+ Initialize the tags selector with the given settings.
960
+
961
+ :param tags_settings: The tags settings.
962
+ :type tags_settings: list
963
+ :param options: The application options.
964
+ :type options: dict
965
+ """
966
+ if self.tags_selector is None:
967
+ return # Selector disabled by app options
968
+
969
+ # Set Tags
970
+ self.tags_selector.set_tags(tags_settings)
971
+ self.tags_selector_cb()
972
+ # ----------------------------------------- #
973
+
974
+ def _init_model(self, model_settings: dict, options: dict) -> None:
612
975
  """
613
976
  Initialize the model selector with the given settings.
614
977
 
615
978
  :param model_settings: The model settings.
616
979
  :type model_settings: dict
980
+ :param options: The application options.
981
+ :type options: dict
617
982
  """
618
983
 
619
984
  # Pretrained