supervisely 6.73.390__py3-none-any.whl → 6.73.391__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. supervisely/app/widgets/experiment_selector/experiment_selector.py +20 -3
  2. supervisely/app/widgets/experiment_selector/template.html +49 -70
  3. supervisely/app/widgets/report_thumbnail/report_thumbnail.py +19 -4
  4. supervisely/decorators/profile.py +20 -0
  5. supervisely/nn/benchmark/utils/detection/utlis.py +7 -0
  6. supervisely/nn/experiments.py +4 -0
  7. supervisely/nn/inference/gui/serving_gui_template.py +71 -11
  8. supervisely/nn/inference/inference.py +108 -6
  9. supervisely/nn/training/gui/classes_selector.py +246 -27
  10. supervisely/nn/training/gui/gui.py +318 -234
  11. supervisely/nn/training/gui/hyperparameters_selector.py +2 -2
  12. supervisely/nn/training/gui/model_selector.py +42 -1
  13. supervisely/nn/training/gui/tags_selector.py +1 -1
  14. supervisely/nn/training/gui/train_val_splits_selector.py +8 -7
  15. supervisely/nn/training/gui/training_artifacts.py +10 -1
  16. supervisely/nn/training/gui/training_process.py +17 -1
  17. supervisely/nn/training/train_app.py +227 -72
  18. supervisely/template/__init__.py +2 -0
  19. supervisely/template/base_generator.py +90 -0
  20. supervisely/template/experiment/__init__.py +0 -0
  21. supervisely/template/experiment/experiment.html.jinja +537 -0
  22. supervisely/template/experiment/experiment_generator.py +996 -0
  23. supervisely/template/experiment/header.html.jinja +154 -0
  24. supervisely/template/experiment/sidebar.html.jinja +240 -0
  25. supervisely/template/experiment/sly-style.css +397 -0
  26. supervisely/template/experiment/template.html.jinja +18 -0
  27. supervisely/template/extensions.py +172 -0
  28. supervisely/template/template_renderer.py +253 -0
  29. {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/METADATA +3 -1
  30. {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/RECORD +34 -23
  31. {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/LICENSE +0 -0
  32. {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/WHEEL +0 -0
  33. {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/entry_points.txt +0 -0
  34. {supervisely-6.73.390.dist-info → supervisely-6.73.391.dist-info}/top_level.txt +0 -0
@@ -7,10 +7,10 @@ training workflows in a Supervisely application.
7
7
 
8
8
  import shutil
9
9
  import subprocess
10
+ import time
10
11
  from datetime import datetime
11
12
  from os import getcwd, listdir, walk
12
13
  from os.path import basename, dirname, exists, expanduser, isdir, isfile, join
13
- from time import sleep
14
14
  from typing import Any, Dict, List, Literal, Optional, Union
15
15
  from urllib.request import urlopen
16
16
 
@@ -46,6 +46,7 @@ from supervisely._utils import abs_url, get_filename_from_headers
46
46
  from supervisely.api.file_api import FileInfo
47
47
  from supervisely.app import get_synced_data_dir, show_dialog
48
48
  from supervisely.app.widgets import Progress
49
+ from supervisely.decorators.profile import timeit_with_result
49
50
  from supervisely.nn.benchmark import (
50
51
  InstanceSegmentationBenchmark,
51
52
  InstanceSegmentationEvaluator,
@@ -70,6 +71,7 @@ from supervisely.project.download import (
70
71
  get_dataset_path,
71
72
  is_cached,
72
73
  )
74
+ from supervisely.template.experiment.experiment_generator import ExperimentGenerator
73
75
 
74
76
 
75
77
  class TrainApp:
@@ -184,6 +186,7 @@ class TrainApp:
184
186
  self.app = Application(layout=self.gui.layout)
185
187
  self._server = self.app.get_server()
186
188
  self._train_func = None
189
+ self._training_duration = None
187
190
 
188
191
  self._onnx_supported = self._app_options.get("export_onnx_supported", False)
189
192
  self._tensorrt_supported = self._app_options.get("export_tensorrt_supported", False)
@@ -396,6 +399,20 @@ class TrainApp:
396
399
  :rtype: str
397
400
  """
398
401
  return self.gui.training_process.get_device()
402
+
403
+ @property
404
+ def base_checkpoint(self) -> str:
405
+ """
406
+ Returns the name of the base checkpoint.
407
+ """
408
+ return self.gui.model_selector.get_checkpoint_name()
409
+
410
+ @property
411
+ def base_checkpoint_link(self) -> str:
412
+ """
413
+ Returns the link to the base checkpoint.
414
+ """
415
+ return self.gui.model_selector.get_checkpoint_link()
399
416
 
400
417
  # Classes
401
418
  @property
@@ -528,7 +545,7 @@ class TrainApp:
528
545
  """
529
546
 
530
547
  def decorator(func):
531
- self._train_func = func
548
+ self._train_func = timeit_with_result(func)
532
549
  self.gui.training_process.start_button.click(self._wrapped_start_training)
533
550
  return func
534
551
 
@@ -574,9 +591,17 @@ class TrainApp:
574
591
  self._workflow_input()
575
592
  # Step 2. Download Project
576
593
  self._download_project()
577
- # Step 3. Split Project
594
+ # Step 3. Convert Project to Task
595
+ if self.gui.need_convert_shapes:
596
+ if self.gui.classes_selector is not None:
597
+ if self.gui.classes_selector.is_convert_class_shapes_enabled():
598
+ self._convert_project_to_model_task()
599
+ # Step 4. Split Project
578
600
  self._split_project()
579
- # Step 4. Download Model files
601
+ # Step 5. Remove classes except selected
602
+ if self.sly_project.type == ProjectType.IMAGES.value:
603
+ self.sly_project.remove_classes_except(self.project_dir, self.classes, True)
604
+ # Step 6. Download Model files
580
605
  self._download_model()
581
606
 
582
607
  def _finalize(self, experiment_info: dict) -> None:
@@ -618,8 +643,8 @@ class TrainApp:
618
643
  # Convert GT project
619
644
  gt_project_id, bm_splits_data = None, train_splits_data
620
645
  # @TODO: check with anyshape classes
621
- if self._app_options.get("auto_convert_classes", True):
622
- if self.gui.need_convert_shapes_for_bm:
646
+ if self.gui.need_convert_shapes:
647
+ if self.gui.hyperparameters_selector.get_model_benchmark_checkbox_value():
623
648
  self._set_text_status("convert_gt_project")
624
649
  gt_project_id, bm_splits_data = self._convert_and_split_gt_project(
625
650
  experiment_info["task_type"]
@@ -675,18 +700,31 @@ class TrainApp:
675
700
  self._generate_model_meta(remote_dir, model_meta)
676
701
  self._upload_demo_files(remote_dir)
677
702
 
678
- # Step 10. Set output widgets
703
+ # Step 10. Generate training output
704
+ need_generate_report = self._app_options.get("generate_report", True)
705
+ if need_generate_report:
706
+ output_file_info = self._generate_experiment_report(experiment_info, model_meta)
707
+ else: # output artifacts directory
708
+ output_file_info = session_link_file_info
709
+
710
+ # Step 11. Set output widgets
679
711
  self._set_text_status("reset")
680
712
  self._set_training_output(
681
- experiment_info, remote_dir, session_link_file_info, mb_eval_report
713
+ experiment_info,
714
+ remote_dir,
715
+ output_file_info,
716
+ mb_eval_report,
682
717
  )
683
718
  self._set_ws_progress_status("completed")
684
719
 
685
- # Step 11. Workflow output
720
+ # Step 12. Workflow output
686
721
  if is_production():
687
722
  best_checkpoint_file_info = self._get_best_checkpoint_info(experiment_info, remote_dir)
688
723
  self._workflow_output(
689
- remote_dir, best_checkpoint_file_info, mb_eval_lnk_file_info, mb_eval_report_id
724
+ remote_dir,
725
+ best_checkpoint_file_info,
726
+ mb_eval_lnk_file_info,
727
+ mb_eval_report_id,
690
728
  )
691
729
 
692
730
  def _get_best_checkpoint_info(self, experiment_info: dict, remote_dir: str) -> FileInfo:
@@ -718,11 +756,6 @@ class TrainApp:
718
756
  :param inference_settings: Settings for the inference class.
719
757
  :type inference_settings: dict
720
758
  """
721
- # if not self.is_model_benchmark_enabled:
722
- # raise ValueError(
723
- # "Enable 'model_benchmark' in app_options.yaml to register an inference class."
724
- # )
725
-
726
759
  self._is_inference_class_regirested = True
727
760
  self._inference_class = inference_class
728
761
  self._inference_settings = None
@@ -746,15 +779,25 @@ class TrainApp:
746
779
  classes = self.classes
747
780
  tags = self.tags
748
781
 
749
- model = self._get_model_config_for_app_state(experiment_info)
782
+ convert_class_shapes = False
783
+ if self.gui.classes_selector is not None:
784
+ convert_class_shapes = self.gui.classes_selector.is_convert_class_shapes_enabled()
785
+
750
786
  options = {
787
+ "cache_project": self.gui.input_selector.get_cache_value(),
788
+ "convert_class_shapes": convert_class_shapes,
751
789
  "model_benchmark": {
752
790
  "enable": self.gui.hyperparameters_selector.get_model_benchmark_checkbox_value(),
753
791
  "speed_test": self.gui.hyperparameters_selector.get_speedtest_checkbox_value(),
754
792
  },
755
- "cache_project": self.gui.input_selector.get_cache_value(),
793
+ "export": {
794
+ "enable": self.gui.hyperparameters_selector.is_export_required(),
795
+ "ONNXRuntime": self.gui.hyperparameters_selector.get_export_onnx_checkbox_value(),
796
+ "TensorRT": self.gui.hyperparameters_selector.get_export_tensorrt_checkbox_value(),
797
+ },
756
798
  }
757
799
 
800
+ model = self._get_model_config_for_app_state(experiment_info)
758
801
  app_state = {
759
802
  "model": model,
760
803
  "hyperparameters": self.hyperparameters_yaml,
@@ -766,36 +809,48 @@ class TrainApp:
766
809
  app_state["tags"] = tags
767
810
  return app_state
768
811
 
769
- def load_app_state(self, app_state: dict) -> None:
812
+ def load_app_state(self, app_state: Union[str, dict]) -> None:
770
813
  """
771
814
  Load the GUI state from app state dictionary.
772
815
 
773
- :param app_state: The state dictionary.
774
- :type app_state: dict
816
+ :param app_state: The state dictionary or path to the state file.
817
+ :type app_state: Union[str, dict]
775
818
 
776
819
  app_state example:
777
820
 
778
821
  app_state = {
779
- "input": {"project_id": 55555},
780
822
  "train_val_split": {
781
823
  "method": "random",
782
824
  "split": "train",
783
825
  "percent": 90
784
826
  },
785
827
  "classes": ["apple"],
786
- "tags": ["green", "red"],
828
+ # Pretrained model
787
829
  "model": {
788
830
  "source": "Pretrained models",
789
831
  "model_name": "rtdetr_r50vd_coco_objects365"
790
832
  },
833
+ # Custom model
834
+ # "model": {
835
+ # "source": "Custom models",
836
+ # "task_id": 555,
837
+ # "checkpoint": "checkpoint_10.pth"
838
+ # },
791
839
  "hyperparameters": hyperparameters, # yaml string
792
840
  "options": {
841
+ "convert_class_shapes": True,
793
842
  "model_benchmark": {
794
843
  "enable": True,
795
844
  "speed_test": True
796
845
  },
797
- "cache_project": True
798
- }
846
+ "cache_project": True,
847
+ "export": {
848
+ "enable": True,
849
+ "ONNXRuntime": True,
850
+ "TensorRT": True
851
+ },
852
+ },
853
+ "experiment_name": "my_experiment",
799
854
  }
800
855
  """
801
856
  self.gui.load_from_app_state(app_state)
@@ -910,17 +965,12 @@ class TrainApp:
910
965
 
911
966
  # Preprocess
912
967
  # Download Project
913
- def _read_project(self, remove_unselected_classes: bool = True) -> None:
968
+ def _read_project(self) -> None:
914
969
  """
915
970
  Reads the project data from Supervisely.
916
-
917
- :param remove_unselected_classes: Whether to remove unselected classes from the project.
918
- :type remove_unselected_classes: bool
919
971
  """
920
972
  if self.project_info.type == ProjectType.IMAGES.value:
921
973
  self.sly_project = Project(self.project_dir, OpenMode.READ)
922
- if remove_unselected_classes:
923
- self.sly_project.remove_classes_except(self.project_dir, self.classes, True)
924
974
  elif self.project_info.type == ProjectType.VIDEOS.value:
925
975
  self.sly_project = VideoProject(self.project_dir, OpenMode.READ)
926
976
  else:
@@ -928,6 +978,48 @@ class TrainApp:
928
978
  f"Unsupported project type: {self.project_info.type}. Only images and videos are supported."
929
979
  )
930
980
 
981
+ def _convert_project_to_model_task(self) -> None:
982
+ """
983
+ Converts the project to the appropriate type.
984
+ """
985
+ if not self.project_info.type == ProjectType.IMAGES.value:
986
+ logger.info("Class shape conversion supported only for images projects")
987
+ return
988
+
989
+ task_type = self.gui.model_selector.get_selected_task_type()
990
+ if task_type not in [
991
+ TaskType.OBJECT_DETECTION,
992
+ TaskType.INSTANCE_SEGMENTATION,
993
+ TaskType.SEMANTIC_SEGMENTATION,
994
+ ]:
995
+ logger.info(f"Class shape conversion for {task_type} task is not supported")
996
+ return
997
+
998
+ logger.info(f"Converting project for {task_type} task")
999
+ with self.progress_bar_main(
1000
+ message=f"Converting project to {task_type} task",
1001
+ total=len(self.sly_project.datasets),
1002
+ ) as pbar:
1003
+ if task_type == TaskType.OBJECT_DETECTION:
1004
+ self.sly_project.to_detection_task(
1005
+ self.project_dir, inplace=True, progress_cb=pbar.update
1006
+ )
1007
+ elif task_type == TaskType.INSTANCE_SEGMENTATION:
1008
+ self.sly_project.to_segmentation_task(
1009
+ self.project_dir,
1010
+ inplace=True,
1011
+ segmentation_type="instance",
1012
+ progress_cb=pbar.update,
1013
+ )
1014
+ elif task_type == TaskType.SEMANTIC_SEGMENTATION:
1015
+ self.sly_project.to_segmentation_task(
1016
+ self.project_dir,
1017
+ inplace=True,
1018
+ segmentation_type="semantic",
1019
+ progress_cb=pbar.update,
1020
+ )
1021
+ self.sly_project = Project(self.project_dir, OpenMode.READ)
1022
+
931
1023
  def _download_project(self) -> None:
932
1024
  """
933
1025
  Downloads the project data from Supervisely.
@@ -1161,7 +1253,7 @@ class TrainApp:
1161
1253
 
1162
1254
  # Clean up temporary directory
1163
1255
  sly_fs.remove_dir(project_split_path)
1164
- self._read_project(False)
1256
+ self._read_project()
1165
1257
 
1166
1258
  # ----------------------------------------- #
1167
1259
 
@@ -1559,21 +1651,23 @@ class TrainApp:
1559
1651
  new_checkpoint_paths = []
1560
1652
  best_checkpoints_name = experiment_info["best_checkpoint"]
1561
1653
 
1562
- # Prepare model files
1654
+ # Prepare checkpoint files
1563
1655
  try:
1564
- model_files = {}
1656
+ # need to save original key names
1657
+ ckpt_files = {}
1565
1658
  for file in experiment_info["model_files"]:
1659
+ file_name = sly_fs.get_file_name_with_ext(experiment_info["model_files"][file])
1566
1660
  with open(experiment_info["model_files"][file], "r") as f:
1567
- model_files[file] = f.read()
1661
+ ckpt_files[file] = {"name": file_name, "content": f.read()}
1568
1662
  except Exception as e:
1569
1663
  logger.warning(f"Error loading model files: {e}")
1570
- model_files = {}
1664
+ ckpt_files = {}
1571
1665
 
1572
1666
  for checkpoint_path in checkpoint_paths:
1573
1667
  checkpoint_name = sly_fs.get_file_name_with_ext(checkpoint_path)
1574
1668
  new_checkpoint_path = join(self._output_checkpoints_dir, checkpoint_name)
1575
1669
  shutil.move(checkpoint_path, new_checkpoint_path)
1576
- if len(model_files) > 0:
1670
+ if len(ckpt_files) > 0:
1577
1671
  try:
1578
1672
  # pylint: disable=import-error
1579
1673
  import torch
@@ -1586,7 +1680,7 @@ class TrainApp:
1586
1680
  "experiment": self.gui.training_process.get_experiment_name(),
1587
1681
  }
1588
1682
  state_dict["model_meta"] = model_meta.to_json()
1589
- state_dict["model_files"] = model_files
1683
+ state_dict["model_files"] = ckpt_files
1590
1684
  torch.save(state_dict, new_checkpoint_path)
1591
1685
  except Exception as e:
1592
1686
  logger.warning(
@@ -1622,7 +1716,11 @@ class TrainApp:
1622
1716
  logger.debug(f"Uploading '{local_path}' to Supervisely")
1623
1717
  total_size = sly_fs.get_file_size(local_path)
1624
1718
  with self.progress_bar_main(
1625
- message=message, total=total_size, unit="B", unit_scale=True, unit_divisor=1024
1719
+ message=message,
1720
+ total=total_size,
1721
+ unit="B",
1722
+ unit_scale=True,
1723
+ unit_divisor=1024,
1626
1724
  ) as upload_artifacts_pbar:
1627
1725
  self.progress_bar_main.show()
1628
1726
  file_info = self._api.file.upload(
@@ -1728,6 +1826,8 @@ class TrainApp:
1728
1826
  "experiment_name": self.gui.training_process.get_experiment_name(),
1729
1827
  "framework_name": self.framework_name,
1730
1828
  "model_name": experiment_info["model_name"],
1829
+ "base_checkpoint": self.base_checkpoint,
1830
+ "base_checkpoint_link": self.base_checkpoint_link,
1731
1831
  "task_type": experiment_info["task_type"],
1732
1832
  "project_id": self.project_info.id,
1733
1833
  "task_id": self.task_id,
@@ -1743,7 +1843,10 @@ class TrainApp:
1743
1843
  "evaluation_report_id": evaluation_report_id,
1744
1844
  "evaluation_report_link": evaluation_report_link,
1745
1845
  "evaluation_metrics": eval_metrics,
1846
+ "primary_metric": primary_metric_name,
1746
1847
  "logs": {"type": "tensorboard", "link": f"{remote_dir}logs/"},
1848
+ "device": self.gui.training_process.get_device_name(),
1849
+ "training_duration": self._training_duration,
1747
1850
  }
1748
1851
 
1749
1852
  if self._has_splits_selector:
@@ -1778,10 +1881,32 @@ class TrainApp:
1778
1881
  )
1779
1882
 
1780
1883
  # Do not include this fields to uploaded file:
1781
- experiment_info["primary_metric"] = primary_metric_name
1782
1884
  experiment_info["project_preview"] = self.project_info.image_preview_url
1783
1885
  return experiment_info
1784
1886
 
1887
+ def _generate_experiment_report(
1888
+ self, experiment_info: dict, model_meta: ProjectMeta
1889
+ ) -> FileInfo:
1890
+ """
1891
+ Generates and uploads the experiment report to the output directory.
1892
+ """
1893
+ # @TODO: add report to workflow output
1894
+ experiment = ExperimentGenerator(
1895
+ api=self._api,
1896
+ experiment_info=experiment_info,
1897
+ hyperparameters=self.hyperparameters_yaml,
1898
+ model_meta=model_meta,
1899
+ serving_class=self._inference_class,
1900
+ team_id=self.team_id,
1901
+ output_dir=join(self.work_dir, "experiment_report"),
1902
+ app_options=self._app_options,
1903
+ )
1904
+ experiment.generate()
1905
+
1906
+ experiment.upload_to_artifacts()
1907
+ file_info = experiment.get_report()
1908
+ return file_info
1909
+
1785
1910
  def _generate_hyperparameters(self, remote_dir: str, experiment_info: Dict) -> None:
1786
1911
  """
1787
1912
  Generates and uploads the hyperparameters.yaml file to the output directory.
@@ -2001,7 +2126,7 @@ class TrainApp:
2001
2126
  experiment_info: dict,
2002
2127
  remote_dir: str,
2003
2128
  file_info: FileInfo,
2004
- mb_eval_report=None,
2129
+ mb_eval_report: Optional[FileInfo] = None,
2005
2130
  ) -> None:
2006
2131
  """
2007
2132
  Sets the training output in the GUI.
@@ -2016,8 +2141,10 @@ class TrainApp:
2016
2141
  if is_production():
2017
2142
  self._api.task.set_output_experiment(self.task_id, experiment_info)
2018
2143
  set_directory(remote_dir)
2144
+ # Set artifacts thumbnail to GUI
2019
2145
  self.gui.training_artifacts.artifacts_thumbnail.set(file_info)
2020
2146
  self.gui.training_artifacts.artifacts_thumbnail.show()
2147
+ # Set experiment report thumbnail to GUI
2021
2148
  self.gui.training_artifacts.artifacts_field.show()
2022
2149
  # ---------------------------- #
2023
2150
 
@@ -2343,7 +2470,13 @@ class TrainApp:
2343
2470
  if diff_project_info:
2344
2471
  self._api.project.remove(diff_project_info.id)
2345
2472
  except Exception as e2:
2346
- return lnk_file_info, report, report_id, eval_metrics, primary_metric_name
2473
+ return (
2474
+ lnk_file_info,
2475
+ report,
2476
+ report_id,
2477
+ eval_metrics,
2478
+ primary_metric_name,
2479
+ )
2347
2480
  return lnk_file_info, report, report_id, eval_metrics, primary_metric_name
2348
2481
 
2349
2482
  # ----------------------------------------- #
@@ -2445,7 +2578,7 @@ class TrainApp:
2445
2578
  self._api.app.workflow.add_output_folder(remote_checkpoint_dir, meta=meta)
2446
2579
  else:
2447
2580
  logger.debug(
2448
- f"File with checkpoints not found in Team Files. Cannot set workflow output."
2581
+ "File with checkpoints not found in Team Files. Cannot set workflow output."
2449
2582
  )
2450
2583
 
2451
2584
  if self.is_model_benchmark_enabled:
@@ -2466,7 +2599,7 @@ class TrainApp:
2466
2599
  self._api.app.workflow.add_output_file(model_benchmark_report, meta=meta)
2467
2600
  else:
2468
2601
  logger.debug(
2469
- f"File with model benchmark report not found in Team Files. Cannot set workflow output."
2602
+ "File with model benchmark report not found in Team Files. Cannot set workflow output."
2470
2603
  )
2471
2604
  except Exception as e:
2472
2605
  logger.debug(f"Failed to add output to the workflow: {repr(e)}")
@@ -2603,6 +2736,7 @@ class TrainApp:
2603
2736
  message = f"Error occurred during training initialization. {check_logs_text}"
2604
2737
  self._show_error(message, e)
2605
2738
  self._set_ws_progress_status("reset")
2739
+ self.app.shutdown()
2606
2740
  raise e
2607
2741
 
2608
2742
  try:
@@ -2619,7 +2753,9 @@ class TrainApp:
2619
2753
  self._set_text_status("training")
2620
2754
  if self._app_options.get("train_logger", None) is None:
2621
2755
  self._set_ws_progress_status("training")
2756
+
2622
2757
  experiment_info = self._train_func()
2758
+ self._training_duration = self._train_func.elapsed
2623
2759
  except ZeroDivisionError as e:
2624
2760
  message = (
2625
2761
  "'ZeroDivisionError' occurred during training. "
@@ -2645,13 +2781,14 @@ class TrainApp:
2645
2781
  self.gui.training_logs.tensorboard_button.hide()
2646
2782
  self.gui.training_logs.tensorboard_offline_button.show()
2647
2783
 
2648
- sleep(1)
2649
- self.app.shutdown()
2784
+ time.sleep(1)
2650
2785
  except Exception as e:
2651
2786
  message = f"Error occurred during finalizing and uploading training artifacts. {check_logs_text}"
2652
2787
  self._show_error(message, e)
2653
2788
  self._set_ws_progress_status("reset")
2654
2789
  raise e
2790
+ finally:
2791
+ self.app.shutdown()
2655
2792
 
2656
2793
  def _show_error(self, message: str, e=None):
2657
2794
  if e is not None:
@@ -2718,41 +2855,53 @@ class TrainApp:
2718
2855
  "export_onnx",
2719
2856
  "export_trt",
2720
2857
  "convert_gt_project",
2858
+ "experiment_report",
2721
2859
  ],
2722
2860
  ):
2723
-
2724
2861
  if status == "reset":
2725
- self.gui.training_process.validator_text.set("", "text")
2862
+ message = ""
2863
+ status = "text"
2864
+ self.gui.training_process.validator_text.set(message, status)
2865
+ return
2726
2866
  elif status == "completed":
2727
- self.gui.training_process.validator_text.set("Training completed", "success")
2867
+ message = "Training completed"
2868
+ status = "success"
2728
2869
  elif status == "training":
2729
- self.gui.training_process.validator_text.set("Training is in progress...", "info")
2870
+ message = "Training is in progress..."
2871
+ status = "info"
2730
2872
  elif status == "finalizing":
2731
- self.gui.training_process.validator_text.set(
2732
- "Finalizing and preparing training artifacts...", "info"
2733
- )
2873
+ message = "Finalizing and preparing training artifacts..."
2874
+ status = "info"
2734
2875
  elif status == "preparing":
2735
- self.gui.training_process.validator_text.set("Preparing data for training...", "info")
2876
+ message = "Preparing data for training..."
2877
+ status = "info"
2736
2878
  elif status == "export_onnx":
2737
- self.gui.training_process.validator_text.set(
2738
- f"Converting to {RuntimeType.ONNXRUNTIME}", "info"
2739
- )
2879
+ message = f"Converting to {RuntimeType.ONNXRUNTIME}"
2880
+ status = "info"
2740
2881
  elif status == "export_trt":
2741
- self.gui.training_process.validator_text.set(
2742
- f"Converting to {RuntimeType.TENSORRT}", "info"
2743
- )
2882
+ message = f"Converting to {RuntimeType.TENSORRT}"
2883
+ status = "info"
2744
2884
  elif status == "uploading":
2745
- self.gui.training_process.validator_text.set("Uploading training artifacts...", "info")
2885
+ message = "Uploading training artifacts..."
2886
+ status = "info"
2746
2887
  elif status == "benchmark":
2747
- self.gui.training_process.validator_text.set(
2748
- "Running Model Benchmark evaluation...", "info"
2749
- )
2888
+ message = "Running Model Benchmark evaluation..."
2889
+ status = "info"
2750
2890
  elif status == "validating":
2751
- self.gui.training_process.validator_text.set("Validating experiment...", "info")
2891
+ message = "Validating experiment..."
2892
+ status = "info"
2752
2893
  elif status == "metadata":
2753
- self.gui.training_process.validator_text.set("Generating training metadata...", "info")
2894
+ message = "Generating training metadata..."
2895
+ status = "info"
2754
2896
  elif status == "convert_gt_project":
2755
- self.gui.training_process.validator_text.set("Converting GT project...", "info")
2897
+ message = "Converting GT project..."
2898
+ status = "info"
2899
+ elif status == "experiment_report":
2900
+ message = "Generating experiment report..."
2901
+ status = "info"
2902
+
2903
+ message = f"Status: {message}"
2904
+ self.gui.training_process.validator_text.set(message, status)
2756
2905
 
2757
2906
  def _set_ws_progress_status(
2758
2907
  self,
@@ -2861,17 +3010,22 @@ class TrainApp:
2861
3010
  save_image_info=True,
2862
3011
  )
2863
3012
  project = Project("tmp_project", OpenMode.READ)
3013
+ project.remove_classes_except(project.directory, self.classes, True)
2864
3014
 
2865
3015
  pr_prefix = ""
2866
3016
  if task_type == TaskType.OBJECT_DETECTION:
2867
3017
  Project.to_detection_task(project.directory, inplace=True)
2868
3018
  pr_prefix = "[detection]: "
2869
- elif (
2870
- task_type == TaskType.INSTANCE_SEGMENTATION
2871
- or task_type == TaskType.SEMANTIC_SEGMENTATION
2872
- ):
2873
- Project.to_segmentation_task(project.directory, inplace=True)
2874
- pr_prefix = "[segmentation]: "
3019
+ elif task_type == TaskType.INSTANCE_SEGMENTATION:
3020
+ Project.to_segmentation_task(
3021
+ project.directory, segmentation_type="instance", inplace=True
3022
+ )
3023
+ pr_prefix = "[instance segmentation]: "
3024
+ elif task_type == TaskType.SEMANTIC_SEGMENTATION:
3025
+ Project.to_segmentation_task(
3026
+ project.directory, segmentation_type="semantic", inplace=True
3027
+ )
3028
+ pr_prefix = "[semantic segmentation]: "
2875
3029
 
2876
3030
  gt_project_info = self._api.project.create(
2877
3031
  self.workspace_id,
@@ -2905,3 +3059,4 @@ class TrainApp:
2905
3059
  # 4. Match splits with original project
2906
3060
  gt_split_data = self._postprocess_splits(gt_project_info.id)
2907
3061
  return gt_project_info.id, gt_split_data
3062
+ return gt_project_info.id, gt_split_data
@@ -0,0 +1,2 @@
1
+ from supervisely.template.experiment.experiment_generator import ExperimentGenerator
2
+ from supervisely.template.template_renderer import TemplateRenderer