supervisely 6.73.389__py3-none-any.whl → 6.73.391__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. supervisely/app/widgets/experiment_selector/experiment_selector.py +20 -3
  2. supervisely/app/widgets/experiment_selector/template.html +49 -70
  3. supervisely/app/widgets/report_thumbnail/report_thumbnail.py +19 -4
  4. supervisely/decorators/profile.py +20 -0
  5. supervisely/nn/benchmark/utils/detection/utlis.py +7 -0
  6. supervisely/nn/experiments.py +4 -0
  7. supervisely/nn/inference/gui/serving_gui_template.py +71 -11
  8. supervisely/nn/inference/inference.py +108 -6
  9. supervisely/nn/training/gui/classes_selector.py +246 -27
  10. supervisely/nn/training/gui/gui.py +318 -234
  11. supervisely/nn/training/gui/hyperparameters_selector.py +2 -2
  12. supervisely/nn/training/gui/model_selector.py +42 -1
  13. supervisely/nn/training/gui/tags_selector.py +1 -1
  14. supervisely/nn/training/gui/train_val_splits_selector.py +8 -7
  15. supervisely/nn/training/gui/training_artifacts.py +10 -1
  16. supervisely/nn/training/gui/training_process.py +17 -1
  17. supervisely/nn/training/train_app.py +227 -72
  18. supervisely/template/__init__.py +2 -0
  19. supervisely/template/base_generator.py +90 -0
  20. supervisely/template/experiment/__init__.py +0 -0
  21. supervisely/template/experiment/experiment.html.jinja +537 -0
  22. supervisely/template/experiment/experiment_generator.py +996 -0
  23. supervisely/template/experiment/header.html.jinja +154 -0
  24. supervisely/template/experiment/sidebar.html.jinja +240 -0
  25. supervisely/template/experiment/sly-style.css +397 -0
  26. supervisely/template/experiment/template.html.jinja +18 -0
  27. supervisely/template/extensions.py +172 -0
  28. supervisely/template/template_renderer.py +253 -0
  29. {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/METADATA +3 -1
  30. {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/RECORD +34 -23
  31. {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/LICENSE +0 -0
  32. {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/WHEEL +0 -0
  33. {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/entry_points.txt +0 -0
  34. {supervisely-6.73.389.dist-info → supervisely-6.73.391.dist-info}/top_level.txt +0 -0
@@ -5,9 +5,12 @@ This module provides the `TrainGUI` class that handles the graphical user interf
5
5
  training workflows in Supervisely.
6
6
  """
7
7
 
8
- from os import environ
8
+ import os
9
+ from os import environ, getenv
9
10
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
11
 
12
+ from supervisely import logger
13
+ import supervisely.io.fs as sly_fs
11
14
  import supervisely.io.env as sly_env
12
15
  import supervisely.io.json as sly_json
13
16
  from supervisely import Api, ProjectMeta
@@ -29,6 +32,7 @@ from supervisely.nn.training.gui.training_logs import TrainingLogs
29
32
  from supervisely.nn.training.gui.training_process import TrainingProcess
30
33
  from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
31
34
  from supervisely.nn.utils import ModelSource, RuntimeType
35
+ from supervisely.nn.experiments import ExperimentInfo
32
36
 
33
37
 
34
38
  class StepFlow:
@@ -249,7 +253,7 @@ class TrainGUI:
249
253
  self.hyperparameters = hyperparameters
250
254
  self.app_options = app_options
251
255
  self.collapsable = self.app_options.get("collapsable", False)
252
- self.need_convert_shapes_for_bm = False
256
+ self.need_convert_shapes = False
253
257
 
254
258
  self.team_id = sly_env.team_id(raise_not_found=False)
255
259
  self.workspace_id = sly_env.workspace_id(raise_not_found=False)
@@ -289,32 +293,32 @@ class TrainGUI:
289
293
  self.input_selector = InputSelector(self.project_info, self.app_options)
290
294
  self.steps.append(self.input_selector.card)
291
295
 
292
- # 2. Train/val split
293
- self.train_val_splits_selector = None
294
- if self.show_train_val_splits_selector:
295
- self.train_val_splits_selector = TrainValSplitsSelector(
296
- self._api, self.project_id, self.app_options
297
- )
298
- self.steps.append(self.train_val_splits_selector.card)
296
+ # 2. Model selector
297
+ self.model_selector = ModelSelector(
298
+ self._api, self.framework_name, self.models, self.app_options
299
+ )
300
+ if self.show_model_selector:
301
+ self.steps.append(self.model_selector.card)
299
302
 
300
- # 3. Select Classes
303
+ # 3. Classes selector
301
304
  self.classes_selector = None
302
305
  if self.show_classes_selector:
303
- self.classes_selector = ClassesSelector(self.project_id, [], self.app_options)
306
+ self.classes_selector = ClassesSelector(self.project_id, [], self.model_selector, self.app_options)
304
307
  self.steps.append(self.classes_selector.card)
305
308
 
306
- # 4. Select Tags
309
+ # 4. Tags selector
307
310
  self.tags_selector = None
308
311
  if self.show_tags_selector:
309
312
  self.tags_selector = TagsSelector(self.project_id, [], self.app_options)
310
313
  self.steps.append(self.tags_selector.card)
311
314
 
312
- # 5. Model selection
313
- self.model_selector = ModelSelector(
314
- self._api, self.framework_name, self.models, self.app_options
315
- )
316
- if self.show_model_selector:
317
- self.steps.append(self.model_selector.card)
315
+ # 5. Train/Val splits selector
316
+ self.train_val_splits_selector = None
317
+ if self.show_train_val_splits_selector:
318
+ self.train_val_splits_selector = TrainValSplitsSelector(
319
+ self._api, self.project_id, self.app_options
320
+ )
321
+ self.steps.append(self.train_val_splits_selector.card)
318
322
 
319
323
  # 6. Training parameters (yaml)
320
324
  self.hyperparameters_selector = HyperparametersSelector(
@@ -360,89 +364,36 @@ class TrainGUI:
360
364
  self.training_process.set_experiment_name(experiment_name)
361
365
 
362
366
  def need_convert_class_shapes() -> bool:
363
- if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
364
- if not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
365
- self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
366
- self.need_convert_shapes_for_bm = False
367
- else:
368
- task_type = self.model_selector.get_selected_task_type()
369
-
370
- def _need_convert(shape):
371
- if task_type == TaskType.OBJECT_DETECTION:
372
- return shape != Rectangle.geometry_name()
373
- elif task_type in [
374
- TaskType.INSTANCE_SEGMENTATION,
375
- TaskType.SEMANTIC_SEGMENTATION,
376
- ]:
377
- return shape == Polygon.geometry_name()
378
- return False
379
-
380
- if self.classes_selector is not None:
381
- data = self.classes_selector.classes_table._table_data
382
- selected_classes = set(
383
- self.classes_selector.classes_table.get_selected_classes()
384
- )
385
- empty = set(
386
- r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
387
- )
388
- need_convert = set(
389
- r[0]["data"] for r in data if _need_convert(r[1]["data"])
390
- )
391
- else:
392
- # Set project meta classes when classes selector is disabled
393
- selected_classes = set(cls.name for cls in self.project_meta.obj_classes)
394
- need_convert = set(
395
- obj_class.name
396
- for obj_class in self.project_meta.obj_classes
397
- if _need_convert(obj_class.geometry_type)
398
- )
399
- empty = set()
400
-
401
- if need_convert.intersection(selected_classes - empty):
402
- self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
403
- self.need_convert_shapes_for_bm = True
404
- else:
405
- self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
406
- self.need_convert_shapes_for_bm = False
407
- else:
408
- self.need_convert_shapes_for_bm = False
367
+ if self.hyperparameters_selector.run_model_benchmark_checkbox is None or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
368
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
369
+ self.need_convert_shapes = False
370
+ return False
409
371
 
410
- def validate_class_shape_for_model_task():
411
372
  task_type = self.model_selector.get_selected_task_type()
412
- if self.classes_selector is not None:
413
- classes = self.classes_selector.get_selected_classes()
414
- else:
415
- classes = list(self.project_meta.obj_classes.keys())
416
373
 
417
- required_geometries = {
418
- TaskType.INSTANCE_SEGMENTATION: {Polygon, Bitmap},
419
- TaskType.SEMANTIC_SEGMENTATION: {Polygon, Bitmap},
420
- TaskType.POSE_ESTIMATION: {GraphNodes},
421
- }
422
- task_specific_texts = {
423
- TaskType.INSTANCE_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
424
- TaskType.SEMANTIC_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
425
- TaskType.POSE_ESTIMATION: "Only keypoint (graph) shape is supported for pose estimation task",
426
- }
374
+ if self.classes_selector is not None:
375
+ wrong_shapes = set(self.classes_selector.get_wrong_shape_classes(task_type))
427
376
 
428
- if task_type not in required_geometries:
429
- return
377
+ # Exclude classes with no annotations to avoid unnecessary conversion
378
+ data = self.classes_selector.classes_table._table_data
379
+ empty_classes = {r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0}
380
+ need_conversion = bool(wrong_shapes - empty_classes)
381
+ else:
382
+ # Classes selector disabled – check entire project meta
383
+ if task_type == TaskType.OBJECT_DETECTION:
384
+ need_conversion = any(obj_cls.geometry_type != Rectangle for obj_cls in self.project_meta.obj_classes)
385
+ elif task_type in [TaskType.INSTANCE_SEGMENTATION, TaskType.SEMANTIC_SEGMENTATION]:
386
+ need_conversion = any(obj_cls.geometry_type == Polygon for obj_cls in self.project_meta.obj_classes)
387
+ else:
388
+ need_conversion = False
430
389
 
431
- wrong_shape_classes = [
432
- class_name
433
- for class_name in classes
434
- if self.project_meta.get_obj_class(class_name).geometry_type
435
- not in required_geometries[task_type]
436
- ]
437
-
438
- if wrong_shape_classes:
439
- specific_text = task_specific_texts[task_type]
440
- message_text = f"Model task type is {task_type}. {specific_text}. Selected classes have wrong shapes for the model task: {', '.join(wrong_shape_classes)}"
441
- self.model_selector.validator_text.set(
442
- text=message_text,
443
- status="warning",
444
- )
390
+ if need_conversion:
391
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
392
+ else:
393
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
445
394
 
395
+ self.need_convert_shapes = need_conversion
396
+ return need_conversion
446
397
  # ------------------------------------------------- #
447
398
 
448
399
  self.step_flow = StepFlow(self.stepper, self.app_options)
@@ -460,17 +411,17 @@ class TrainGUI:
460
411
  )
461
412
  position += 1
462
413
 
463
- # 2. Train/Val splits selector
464
- if self.show_train_val_splits_selector and self.train_val_splits_selector is not None:
414
+ # 2. Model selector
415
+ if self.show_model_selector:
465
416
  self.step_flow.register_step(
466
- "train_val_splits",
467
- self.train_val_splits_selector.card,
468
- self.train_val_splits_selector.button,
469
- self.train_val_splits_selector.widgets_to_disable,
470
- self.train_val_splits_selector.validator_text,
471
- self.train_val_splits_selector.validate_step,
472
- position=position,
473
- )
417
+ "model_selector",
418
+ self.model_selector.card,
419
+ self.model_selector.button,
420
+ self.model_selector.widgets_to_disable,
421
+ self.model_selector.validator_text,
422
+ self.model_selector.validate_step,
423
+ position=position
424
+ ).add_on_select_actions("model_selector", [set_experiment_name])
474
425
  position += 1
475
426
 
476
427
  # 3. Classes selector
@@ -483,7 +434,7 @@ class TrainGUI:
483
434
  self.classes_selector.validator_text,
484
435
  self.classes_selector.validate_step,
485
436
  position=position,
486
- )
437
+ ).add_on_select_actions("classes_selector", [need_convert_class_shapes])
487
438
  position += 1
488
439
 
489
440
  # 4. Tags selector
@@ -499,23 +450,16 @@ class TrainGUI:
499
450
  )
500
451
  position += 1
501
452
 
502
- # 5. Model selector
503
- if self.show_model_selector:
453
+ # 5. Train/Val splits selector
454
+ if self.show_train_val_splits_selector and self.train_val_splits_selector is not None:
504
455
  self.step_flow.register_step(
505
- "model_selector",
506
- self.model_selector.card,
507
- self.model_selector.button,
508
- self.model_selector.widgets_to_disable,
509
- self.model_selector.validator_text,
510
- self.model_selector.validate_step,
456
+ "train_val_splits",
457
+ self.train_val_splits_selector.card,
458
+ self.train_val_splits_selector.button,
459
+ self.train_val_splits_selector.widgets_to_disable,
460
+ self.train_val_splits_selector.validator_text,
461
+ self.train_val_splits_selector.validate_step,
511
462
  position=position,
512
- ).add_on_select_actions(
513
- "model_selector",
514
- [
515
- set_experiment_name,
516
- need_convert_class_shapes,
517
- validate_class_shape_for_model_task,
518
- ],
519
463
  )
520
464
  position += 1
521
465
 
@@ -570,34 +514,28 @@ class TrainGUI:
570
514
  )
571
515
 
572
516
  # Set dependencies between steps
573
- has_train_val_splits = (
574
- self.show_train_val_splits_selector and self.train_val_splits_selector is not None
575
- )
517
+ has_model_selector = self.show_model_selector and self.model_selector is not None
576
518
  has_classes_selector = self.show_classes_selector and self.classes_selector is not None
577
519
  has_tags_selector = self.show_tags_selector and self.tags_selector is not None
520
+ has_train_val_splits = self.show_train_val_splits_selector and self.train_val_splits_selector is not None
578
521
 
579
522
  # Set step dependency chain
580
- # 1. Input selector
581
523
  prev_step = "input_selector"
582
- if has_train_val_splits:
583
- self.step_flow.set_next_steps(prev_step, ["train_val_splits"])
584
- prev_step = "train_val_splits"
524
+ if has_model_selector:
525
+ self.step_flow.set_next_steps(prev_step, ["model_selector"])
526
+ prev_step = "model_selector"
585
527
  if has_classes_selector:
586
528
  self.step_flow.set_next_steps(prev_step, ["classes_selector"])
587
529
  prev_step = "classes_selector"
588
530
  if has_tags_selector:
589
531
  self.step_flow.set_next_steps(prev_step, ["tags_selector"])
590
532
  prev_step = "tags_selector"
591
-
592
- if self.show_model_selector and self.model_selector is not None:
593
- self.step_flow.set_next_steps(prev_step, ["model_selector"])
594
- # Model selector -> hyperparameters
595
- self.step_flow.set_next_steps("model_selector", ["hyperparameters_selector"])
596
- prev_step = "model_selector"
597
- else:
598
- self.step_flow.set_next_steps(prev_step, ["hyperparameters_selector"])
533
+ if has_train_val_splits:
534
+ self.step_flow.set_next_steps(prev_step, ["train_val_splits"])
535
+ prev_step = "train_val_splits"
599
536
 
600
537
  # 6. Hyperparameters selector -> 7. Training process
538
+ self.step_flow.set_next_steps(prev_step, ["hyperparameters_selector"])
601
539
  self.step_flow.set_next_steps("hyperparameters_selector", ["training_process"])
602
540
 
603
541
  # 7. Training process -> 8. Training logs
@@ -633,13 +571,17 @@ class TrainGUI:
633
571
  @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
634
572
  def show_mb_speedtest(is_checked: bool):
635
573
  self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
636
- need_convert_class_shapes()
637
-
638
574
  # ------------------------------------------------- #
639
575
 
640
576
  self.layout: Widget = self.stepper
641
577
 
642
- # (дублирующийся блок был перемещён выше и здесь удалён)
578
+ # Run from experiment page
579
+ train_task_id = getenv("modal.state.trainTaskId", None)
580
+ train_mode = getenv("modal.state.trainMode", None)
581
+ if train_task_id is not None and train_mode is not None:
582
+ self._run_from_experiment(train_task_id, train_mode)
583
+ # ----------------------------------------- #
584
+
643
585
 
644
586
  def set_next_step(self):
645
587
  current_step = self.stepper.get_active_step()
@@ -831,12 +773,16 @@ class TrainGUI:
831
773
  )
832
774
  return app_state
833
775
 
834
- def load_from_app_state(self, app_state: Union[str, dict]) -> None:
776
+ def load_from_app_state(self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True) -> None:
835
777
  """
836
- Load the GUI state from app state dictionary.
778
+ Load the GUI state from app state dictionary or path to the state file.
837
779
 
838
- :param app_state: The state dictionary.
839
- :type app_state: dict
780
+ :param app_state: The state dictionary or path to the state file.
781
+ :type app_state: Union[str, dict]
782
+ :param click_cb: Automatically click the callback functions to set the GUI state.
783
+ :type click_cb: bool
784
+ :param validate_steps: Validate the steps. If False, the steps will not be validated.
785
+ :type validate_steps: bool
840
786
 
841
787
  app_state example:
842
788
 
@@ -847,21 +793,29 @@ class TrainGUI:
847
793
  "percent": 90
848
794
  },
849
795
  "classes": ["apple"],
796
+ # Pretrained model
850
797
  "model": {
851
798
  "source": "Pretrained models",
852
799
  "model_name": "rtdetr_r50vd_coco_objects365"
853
800
  },
801
+ # Custom model
802
+ # "model": {
803
+ # "source": "Custom models",
804
+ # "task_id": 555,
805
+ # "checkpoint": "checkpoint_10.pth"
806
+ # },
854
807
  "hyperparameters": hyperparameters, # yaml string
855
808
  "options": {
809
+ "convert_class_shapes": True,
856
810
  "model_benchmark": {
857
811
  "enable": True,
858
812
  "speed_test": True
859
813
  },
860
814
  "cache_project": True,
861
- "export": {
862
- "enable": True,
863
- "ONNXRuntime": True,
864
- "TensorRT": True
815
+ "export": {
816
+ "enable": True,
817
+ "ONNXRuntime": True,
818
+ "TensorRT": True
865
819
  },
866
820
  },
867
821
  "experiment_name": "my_experiment",
@@ -869,27 +823,43 @@ class TrainGUI:
869
823
  """
870
824
  if isinstance(app_state, str):
871
825
  app_state = sly_json.load_json_file(app_state)
826
+
872
827
  app_state = self.validate_app_state(app_state)
873
-
874
828
  options = app_state.get("options", {})
875
- input_settings = app_state.get("input")
876
- train_val_splits_settings = app_state.get("train_val_split", {})
877
- classes_settings = app_state.get("classes", [])
878
- tags_settings = app_state.get("tags", [])
879
- model_settings = app_state["model"]
880
- hyperparameters_settings = app_state["hyperparameters"]
881
- experiment_name = app_state.get("experiment_name", None)
882
-
883
- self._init_input(input_settings, options)
884
- self._init_train_val_splits(train_val_splits_settings, options)
885
- self._init_classes(classes_settings, options)
886
- self._init_tags(tags_settings, options)
887
- self._init_model(model_settings, options)
888
- self._init_hyperparameters(hyperparameters_settings, options)
829
+
830
+ # Set experiment name
831
+ experiment_name = app_state.get("experiment_name")
889
832
  if experiment_name is not None:
890
833
  self.training_process.set_experiment_name(experiment_name)
891
834
 
892
- def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
835
+ # Run init-steps and stop on validation failure
836
+ def _run_step(init_fn, settings) -> bool:
837
+ if not init_fn(settings, options, click_cb, validate_steps):
838
+ return False
839
+ return True
840
+
841
+ # GUI init steps
842
+ _steps = [
843
+ (self._init_input, app_state.get("input"), "Input project"),
844
+ (self._init_model, app_state["model"], "Select Model"),
845
+ (self._init_classes, app_state.get("classes", []), "Classes Selector"),
846
+ (self._init_tags, app_state.get("tags", []), "Tags Selector"),
847
+ (self._init_train_val_splits, app_state.get("train_val_split", {}), "Train/Val Splits"),
848
+ (self._init_hyperparameters, app_state["hyperparameters"], "Hyperparameters"),
849
+ ]
850
+
851
+ for idx, (init_fn, settings, step_name) in enumerate(_steps, start=1):
852
+ if not _run_step(init_fn, settings):
853
+ if validate_steps:
854
+ logger.warning(f"Step '{step_name}' {idx}/{len(_steps)} failed to validate")
855
+ return
856
+ if validate_steps:
857
+ logger.info(f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully")
858
+ if validate_steps:
859
+ logger.info(f"All steps have been validated successfully")
860
+ # ------------------------------------------------------------------ #
861
+
862
+ def _init_input(self, input_settings: Union[dict, None], options: dict, click_cb: bool = True, validate: bool = True) -> bool:
893
863
  """
894
864
  Initialize the input selector with the given settings.
895
865
 
@@ -897,13 +867,123 @@ class TrainGUI:
897
867
  :type input_settings: dict
898
868
  :param options: The application options.
899
869
  :type options: dict
870
+ :param click_cb: Click the callback function.
871
+ :type click_cb: bool
872
+ :param validate: Validate the step.
873
+ :type validate: bool
900
874
  """
901
875
  # Set Input
902
876
  self.input_selector.set_cache(options.get("cache_project", True))
903
- self.input_selector_cb()
877
+ is_valid = True
878
+ if validate:
879
+ is_valid = self.input_selector.validate_step()
880
+ if is_valid and click_cb:
881
+ self.input_selector_cb()
882
+ self.set_next_step()
883
+ return is_valid
904
884
  # ----------------------------------------- #
905
885
 
906
- def _init_train_val_splits(self, train_val_splits_settings: dict, options: dict) -> None:
886
+ def _init_model(self, model_settings: dict, options: dict = None, click_cb: bool = True, validate: bool = True) -> bool:
887
+ """
888
+ Initialize the model selector with the given settings.
889
+
890
+ :param model_settings: The model settings.
891
+ :type model_settings: dict
892
+ :param options: The application options.
893
+ :type options: dict
894
+ :param click_cb: Click the callback function.
895
+ :type click_cb: bool
896
+ :param validate: Validate the step.
897
+ :type validate: bool
898
+ """
899
+
900
+ # Pretrained
901
+ if model_settings["source"] == ModelSource.PRETRAINED:
902
+ self.model_selector.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
903
+ self.model_selector.pretrained_models_table.set_by_model_name(
904
+ model_settings["model_name"]
905
+ )
906
+
907
+ # Custom
908
+ elif model_settings["source"] == ModelSource.CUSTOM:
909
+ self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
910
+ self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
911
+ active_row = self.model_selector.experiment_selector.get_selected_row()
912
+ if model_settings["checkpoint"] not in active_row.checkpoints_names:
913
+ raise ValueError(
914
+ f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
915
+ )
916
+
917
+ active_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])
918
+
919
+ is_valid = True
920
+ if validate:
921
+ is_valid = self.model_selector.validate_step()
922
+ if is_valid and click_cb:
923
+ self.model_selector_cb()
924
+ self.set_next_step()
925
+ return is_valid
926
+ # ----------------------------------------- #
927
+
928
+ def _init_classes(self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
929
+ """
930
+ Initialize the classes selector with the given settings.
931
+
932
+ :param classes_settings: The classes settings.
933
+ :type classes_settings: list
934
+ :param options: The application options.
935
+ :type options: dict
936
+ :param click_cb: Click the callback function.
937
+ :type click_cb: bool
938
+ :param validate: Validate the step.
939
+ :type validate: bool
940
+ """
941
+ if self.classes_selector is None:
942
+ return True # Selector disabled by app options
943
+
944
+ convert_class_shapes = options.get("convert_class_shapes", True)
945
+ if convert_class_shapes:
946
+ self.classes_selector.convert_class_shapes_checkbox.check()
947
+
948
+ # Set Classes
949
+ self.classes_selector.set_classes(classes_settings)
950
+ is_valid = True
951
+ if validate:
952
+ is_valid = self.classes_selector.validate_step()
953
+ if is_valid and click_cb:
954
+ self.classes_selector_cb()
955
+ self.set_next_step()
956
+ return is_valid
957
+ # ----------------------------------------- #
958
+
959
+ def _init_tags(self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
960
+ """
961
+ Initialize the tags selector with the given settings.
962
+
963
+ :param tags_settings: The tags settings.
964
+ :type tags_settings: list
965
+ :param options: The application options.
966
+ :type options: dict
967
+ :param click_cb: Click the callback function.
968
+ :type click_cb: bool
969
+ :param validate: Validate the step.
970
+ :type validate: bool
971
+ """
972
+ if self.tags_selector is None:
973
+ return True # Selector disabled by app options
974
+
975
+ # Set Tags
976
+ self.tags_selector.set_tags(tags_settings)
977
+ is_valid = True
978
+ if validate:
979
+ is_valid = self.tags_selector.validate_step()
980
+ if is_valid and click_cb:
981
+ self.tags_selector_cb()
982
+ self.set_next_step()
983
+ return is_valid
984
+ # ----------------------------------------- #
985
+
986
+ def _init_train_val_splits(self, train_val_splits_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
907
987
  """
908
988
  Initialize the train/val splits selector with the given settings.
909
989
 
@@ -911,9 +991,13 @@ class TrainGUI:
911
991
  :type train_val_splits_settings: dict
912
992
  :param options: The application options.
913
993
  :type options: dict
994
+ :param click_cb: Click the callback function.
995
+ :type click_cb: bool
996
+ :param validate: Validate the step.
997
+ :type validate: bool
914
998
  """
915
999
  if self.train_val_splits_selector is None:
916
- return # Selector disabled by app options
1000
+ return True # Selector disabled by app options
917
1001
 
918
1002
  if train_val_splits_settings == {}:
919
1003
  available_methods = self.app_options.get("train_val_splits_methods", [])
@@ -965,74 +1049,16 @@ class TrainGUI:
965
1049
  self.train_val_splits_selector.train_val_splits.set_collections_splits(
966
1050
  train_collections, val_collections
967
1051
  )
968
- self.train_val_splits_selector_cb()
969
1052
 
970
- def _init_classes(self, classes_settings: list, options: dict) -> None:
971
- """
972
- Initialize the classes selector with the given settings.
973
-
974
- :param classes_settings: The classes settings.
975
- :type classes_settings: list
976
- :param options: The application options.
977
- :type options: dict
978
- """
979
- if self.classes_selector is None:
980
- return # Selector disabled by app options
981
-
982
- # Set Classes
983
- self.classes_selector.set_classes(classes_settings)
984
- self.classes_selector_cb()
985
- # ----------------------------------------- #
986
-
987
- def _init_tags(self, tags_settings: list, options: dict) -> None:
988
- """
989
- Initialize the tags selector with the given settings.
990
-
991
- :param tags_settings: The tags settings.
992
- :type tags_settings: list
993
- :param options: The application options.
994
- :type options: dict
995
- """
996
- if self.tags_selector is None:
997
- return # Selector disabled by app options
998
-
999
- # Set Tags
1000
- self.tags_selector.set_tags(tags_settings)
1001
- self.tags_selector_cb()
1002
- # ----------------------------------------- #
1003
-
1004
- def _init_model(self, model_settings: dict, options: dict) -> None:
1005
- """
1006
- Initialize the model selector with the given settings.
1007
-
1008
- :param model_settings: The model settings.
1009
- :type model_settings: dict
1010
- :param options: The application options.
1011
- :type options: dict
1012
- """
1013
-
1014
- # Pretrained
1015
- if model_settings["source"] == ModelSource.PRETRAINED:
1016
- self.model_selector.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
1017
- self.model_selector.pretrained_models_table.set_by_model_name(
1018
- model_settings["model_name"]
1019
- )
1020
-
1021
- # Custom
1022
- elif model_settings["source"] == ModelSource.CUSTOM:
1023
- self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
1024
- self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
1025
- active_row = self.model_selector.experiment_selector.get_selected_row()
1026
- if model_settings["checkpoint"] not in active_row.checkpoints_names:
1027
- raise ValueError(
1028
- f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
1029
- )
1030
-
1031
- active_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])
1032
- self.model_selector_cb()
1033
- # ----------------------------------------- #
1034
-
1035
- def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict) -> None:
1053
+ is_valid = True
1054
+ if validate:
1055
+ is_valid = self.train_val_splits_selector.validate_step()
1056
+ if is_valid and click_cb:
1057
+ self.train_val_splits_selector_cb()
1058
+ self.set_next_step()
1059
+ return is_valid
1060
+
1061
+ def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict, click_cb: bool = True, validate: bool = True) -> bool:
1036
1062
  """
1037
1063
  Initialize the hyperparameters selector with the given settings.
1038
1064
 
@@ -1040,6 +1066,10 @@ class TrainGUI:
1040
1066
  :type hyperparameters_settings: dict
1041
1067
  :param options: The application options.
1042
1068
  :type options: dict
1069
+ :param click_cb: Click the callback function.
1070
+ :type click_cb: bool
1071
+ :param validate: Validate the step.
1072
+ :type validate: bool
1043
1073
  """
1044
1074
  self.hyperparameters_selector.set_hyperparameters(hyperparameters_settings)
1045
1075
 
@@ -1061,6 +1091,60 @@ class TrainGUI:
1061
1091
  self.hyperparameters_selector.set_export_tensorrt_checkbox_value(
1062
1092
  export_weights_settings.get(RuntimeType.TENSORRT, False)
1063
1093
  )
1064
- self.hyperparameters_selector_cb()
1065
1094
 
1095
+ is_valid = True
1096
+ if validate:
1097
+ is_valid = self.hyperparameters_selector.validate_step()
1098
+ if is_valid and click_cb:
1099
+ self.hyperparameters_selector_cb()
1100
+ self.set_next_step()
1101
+ return is_valid
1102
+ # ----------------------------------------- #
1103
+
1104
+ # Run from experiment page
1105
+ def _download_experiment_state(self, experiment_info: ExperimentInfo) -> dict:
1106
+ local_app_state_path = f"./app_state.json"
1107
+ remote_app_state_path = os.path.join(experiment_info.artifacts_dir, "app_state.json")
1108
+ self._api.file.download(self.team_id, remote_app_state_path, local_app_state_path)
1109
+ app_state = sly_json.load_json_file(local_app_state_path)
1110
+ sly_fs.silent_remove(local_app_state_path)
1111
+ return app_state
1112
+
1113
+ def _download_experiment_hparams(self, experiment_info: ExperimentInfo) -> dict:
1114
+ local_hparams_path = f"./{experiment_info.hyperparameters}"
1115
+ remote_hparams_path = os.path.join(experiment_info.artifacts_dir, experiment_info.hyperparameters)
1116
+ self._api.file.download(self.team_id, remote_hparams_path, local_hparams_path)
1117
+ with open(local_hparams_path, "r") as f:
1118
+ hparams = f.read()
1119
+ sly_fs.silent_remove(local_hparams_path)
1120
+ return hparams
1121
+
1122
+ def _run_from_experiment(self, train_task_id: int, train_mode: str):
1123
+ experiment_info = self._api.nn.get_experiment_info(train_task_id)
1124
+ experiment_state = experiment_info.app_state
1125
+
1126
+ if train_mode == "continue":
1127
+ model_settings = {
1128
+ "source": ModelSource.CUSTOM,
1129
+ "task_id": train_task_id,
1130
+ "checkpoint": experiment_info.best_checkpoint
1131
+ }
1132
+
1133
+ if experiment_state is not None:
1134
+ self.input_selector.validator_text.set(f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.", "success")
1135
+ self.input_selector.validator_text.show()
1136
+ experiment_state = self._download_experiment_state(experiment_info)
1137
+ if train_mode == "continue":
1138
+ experiment_state["model"] = model_settings
1139
+ self.load_from_app_state(experiment_state, click_cb=False, validate_steps=False)
1140
+ else:
1141
+ self.input_selector.validator_text.set(
1142
+ f"Couldn't load full training configuration from the experiment: {experiment_info.experiment_name}. Only model and hyperparameters are loaded.",
1143
+ "warning"
1144
+ )
1145
+ self.input_selector.validator_text.show()
1146
+ hparams = self._download_experiment_hparams(experiment_info)
1147
+ self.hyperparameters_selector.set_hyperparameters(hparams)
1148
+ if train_mode == "continue":
1149
+ self._init_model(model_settings, {}, click_cb=False, validate=False)
1066
1150
  # ----------------------------------------- #