supervisely 6.73.250__py3-none-any.whl → 6.73.252__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic.

Files changed (35):
  1. supervisely/api/dataset_api.py +17 -1
  2. supervisely/api/project_api.py +4 -1
  3. supervisely/api/volume/volume_annotation_api.py +7 -4
  4. supervisely/app/widgets/experiment_selector/experiment_selector.py +16 -8
  5. supervisely/nn/benchmark/base_benchmark.py +17 -2
  6. supervisely/nn/benchmark/base_evaluator.py +28 -6
  7. supervisely/nn/benchmark/instance_segmentation/benchmark.py +1 -1
  8. supervisely/nn/benchmark/instance_segmentation/evaluator.py +14 -0
  9. supervisely/nn/benchmark/object_detection/benchmark.py +1 -1
  10. supervisely/nn/benchmark/object_detection/evaluator.py +43 -13
  11. supervisely/nn/benchmark/object_detection/metric_provider.py +7 -0
  12. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +33 -7
  13. supervisely/nn/benchmark/utils/detection/utlis.py +6 -4
  14. supervisely/nn/experiments.py +23 -16
  15. supervisely/nn/inference/gui/serving_gui_template.py +2 -35
  16. supervisely/nn/inference/inference.py +71 -8
  17. supervisely/nn/training/__init__.py +2 -0
  18. supervisely/nn/training/gui/classes_selector.py +14 -14
  19. supervisely/nn/training/gui/gui.py +28 -13
  20. supervisely/nn/training/gui/hyperparameters_selector.py +90 -41
  21. supervisely/nn/training/gui/input_selector.py +8 -6
  22. supervisely/nn/training/gui/model_selector.py +7 -5
  23. supervisely/nn/training/gui/train_val_splits_selector.py +8 -9
  24. supervisely/nn/training/gui/training_logs.py +17 -17
  25. supervisely/nn/training/gui/training_process.py +41 -36
  26. supervisely/nn/training/loggers/__init__.py +22 -0
  27. supervisely/nn/training/loggers/base_train_logger.py +8 -5
  28. supervisely/nn/training/loggers/tensorboard_logger.py +4 -11
  29. supervisely/nn/training/train_app.py +276 -90
  30. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/METADATA +8 -3
  31. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/RECORD +35 -35
  32. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/LICENSE +0 -0
  33. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/WHEEL +0 -0
  34. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/entry_points.txt +0 -0
  35. {supervisely-6.73.250.dist-info → supervisely-6.73.252.dist-info}/top_level.txt +0 -0
@@ -90,7 +90,10 @@ class Inference:
90
90
  """Path to file with list of models"""
91
91
  APP_OPTIONS: str = None
92
92
  """Path to file with app options"""
93
- DEFAULT_BATCH_SIZE = 16
93
+ DEFAULT_BATCH_SIZE: str = 16
94
+ """Default batch size for inference"""
95
+ INFERENCE_SETTINGS: str = None
96
+ """Path to file with custom inference settings"""
94
97
 
95
98
  def __init__(
96
99
  self,
@@ -125,8 +128,12 @@ class Inference:
125
128
  self._autostart_delay_time = 5 * 60 # 5 min
126
129
  self._tracker = None
127
130
  self._hardware: str = None
131
+ self.pretrained_models = self._load_models_json(self.MODELS) if self.MODELS else None
128
132
  if custom_inference_settings is None:
129
- custom_inference_settings = {}
133
+ if self.INFERENCE_SETTINGS is not None:
134
+ custom_inference_settings = self.INFERENCE_SETTINGS
135
+ else:
136
+ custom_inference_settings = {}
130
137
  if isinstance(custom_inference_settings, str):
131
138
  if fs.file_exists(custom_inference_settings):
132
139
  with open(custom_inference_settings, "r") as f:
@@ -153,7 +160,7 @@ class Inference:
153
160
  if self.FRAMEWORK_NAME is None:
154
161
  raise ValueError("FRAMEWORK_NAME is not defined")
155
162
  self._gui = GUI.ServingGUITemplate(
156
- self.FRAMEWORK_NAME, self.MODELS, self.APP_OPTIONS
163
+ self.FRAMEWORK_NAME, self.pretrained_models, self.APP_OPTIONS
157
164
  )
158
165
  self._user_layout = self._gui.widgets
159
166
  self._user_layout_card = self._gui.card
@@ -239,6 +246,38 @@ class Inference:
239
246
  )
240
247
  device = "cpu"
241
248
 
249
+ def _load_models_json(self, models: str) -> List[Dict[str, Any]]:
250
+ """
251
+ Loads models from the provided file or list of model configurations.
252
+ """
253
+ if isinstance(models, str):
254
+ if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
255
+ models = sly_json.load_json_file(models)
256
+ else:
257
+ raise ValueError("File not found or invalid file format.")
258
+ else:
259
+ raise ValueError(
260
+ "Invalid models file. Please provide a valid '.json' file with list of model configurations."
261
+ )
262
+
263
+ if not isinstance(models, list):
264
+ raise ValueError("models parameters must be a list of dicts")
265
+ for item in models:
266
+ if not isinstance(item, dict):
267
+ raise ValueError(f"Each item in models must be a dict.")
268
+ model_meta = item.get("meta")
269
+ if model_meta is None:
270
+ raise ValueError(
271
+ "Model metadata not found. Please update provided models parameter to include key 'meta'."
272
+ )
273
+ model_files = model_meta.get("model_files")
274
+ if model_files is None:
275
+ raise ValueError(
276
+ "Model files not found in model metadata. "
277
+ "Please update provided models oarameter to include key 'model_files' in 'meta' key."
278
+ )
279
+ return models
280
+
242
281
  def get_ui(self) -> Widget:
243
282
  if not self._use_gui:
244
283
  return None
@@ -487,6 +526,9 @@ class Inference:
487
526
  def load_model_meta(self, model_tab: str, local_weights_path: str):
488
527
  raise NotImplementedError("Have to be implemented in child class after inheritance")
489
528
 
529
+ def _checkpoints_cache_dir(self):
530
+ return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
531
+
490
532
  def _download_model_files(self, model_source: str, model_files: List[str]) -> dict:
491
533
  if model_source == ModelSource.PRETRAINED:
492
534
  return self._download_pretrained_model(model_files)
@@ -498,17 +540,28 @@ class Inference:
498
540
  Downloads the pretrained model data.
499
541
  """
500
542
  local_model_files = {}
543
+ cache_dir = self._checkpoints_cache_dir()
501
544
 
502
545
  for file in model_files:
503
546
  file_url = model_files[file]
504
- file_path = os.path.join(self.model_dir, file)
547
+ file_name = sly_fs.get_file_name_with_ext(file_url)
505
548
  if file_url.startswith("http"):
506
549
  with urlopen(file_url) as f:
507
550
  file_size = f.length
508
551
  file_name = get_filename_from_headers(file_url)
509
- file_path = os.path.join(self.model_dir, file_name)
510
552
  if file_name is None:
511
553
  file_name = file
554
+ file_path = os.path.join(self.model_dir, file_name)
555
+ cached_path = os.path.join(cache_dir, file_name)
556
+ if os.path.exists(cached_path):
557
+ local_model_files[file] = cached_path
558
+ logger.debug(f"Model: '{file_name}' was found in checkpoint cache")
559
+ continue
560
+ if os.path.exists(file_path):
561
+ local_model_files[file] = file_path
562
+ logger.debug(f"Model: '{file_name}' was found in model dir")
563
+ continue
564
+
512
565
  with self.gui.download_progress(
513
566
  message=f"Downloading: '{file_name}'",
514
567
  total=file_size,
@@ -614,13 +667,23 @@ class Inference:
614
667
  model_files = deploy_params.get("model_files", {})
615
668
  if model_info:
616
669
  checkpoint_name = os.path.basename(model_files.get("checkpoint"))
670
+ checkpoint_file_path = os.path.join(
671
+ model_info.get("artifacts_dir"), "checkpoints", checkpoint_name
672
+ )
673
+ checkpoint_file_info = self.api.file.get_info_by_path(
674
+ env.team_id(), checkpoint_file_path
675
+ )
676
+ if checkpoint_file_info is None:
677
+ checkpoint_url = None
678
+ else:
679
+ checkpoint_url = self.api.file.get_url(checkpoint_file_info.id)
680
+
617
681
  self.checkpoint_info = CheckpointInfo(
618
682
  checkpoint_name=checkpoint_name,
619
683
  model_name=model_info.get("model_name"),
620
684
  architecture=model_info.get("framework_name"),
621
- custom_checkpoint_path=os.path.join(
622
- model_info.get("artifacts_dir"), checkpoint_name
623
- ),
685
+ checkpoint_url=checkpoint_url,
686
+ custom_checkpoint_path=checkpoint_file_path,
624
687
  model_source=ModelSource.CUSTOM,
625
688
  )
626
689
 
@@ -0,0 +1,2 @@
1
+ from supervisely.nn.training.train_app import TrainApp
2
+ from supervisely.nn.training.loggers import train_logger
@@ -11,36 +11,36 @@ class ClassesSelector:
11
11
  lock_message = "Select training and validation splits to unlock"
12
12
 
13
13
  def __init__(self, project_id: int, classes: list, app_options: dict = {}):
14
- self.classes_table = ClassesTable(project_id=project_id) # use dataset_ids
15
- if len(classes) > 0:
16
- self.classes_table.select_classes(classes) # from app options
17
- else:
18
- self.classes_table.select_all()
14
+ self.display_widgets = []
19
15
 
16
+ # GUI Components
20
17
  if is_development() or is_debug_with_sly_net():
21
18
  qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
22
19
  else:
23
20
  qa_stats_link = f"/projects/{project_id}/stats/datasets"
24
-
25
21
  qa_stats_text = Text(
26
22
  text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
27
23
  )
28
24
 
25
+ self.classes_table = ClassesTable(project_id=project_id)
26
+ if len(classes) > 0:
27
+ self.classes_table.select_classes(classes)
28
+ else:
29
+ self.classes_table.select_all()
30
+
29
31
  self.validator_text = Text("")
30
32
  self.validator_text.hide()
31
33
  self.button = Button("Select")
32
- container = Container(
33
- [
34
- qa_stats_text,
35
- self.classes_table,
36
- self.validator_text,
37
- self.button,
38
- ]
34
+ self.display_widgets.extend(
35
+ [qa_stats_text, self.classes_table, self.validator_text, self.button]
39
36
  )
37
+ # -------------------------------- #
38
+
39
+ self.container = Container(self.display_widgets)
40
40
  self.card = Card(
41
41
  title=self.title,
42
42
  description=self.description,
43
- content=container,
43
+ content=self.container,
44
44
  lock_message=self.lock_message,
45
45
  collapsable=app_options.get("collapsable", False),
46
46
  )
@@ -6,7 +6,7 @@ training workflows in Supervisely.
6
6
  """
7
7
 
8
8
  import supervisely.io.env as sly_env
9
- from supervisely import Api
9
+ from supervisely import Api, ProjectMeta
10
10
  from supervisely._utils import is_production
11
11
  from supervisely.app.widgets import Stepper, Widget
12
12
  from supervisely.nn.training.gui.classes_selector import ClassesSelector
@@ -17,7 +17,7 @@ from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplits
17
17
  from supervisely.nn.training.gui.training_logs import TrainingLogs
18
18
  from supervisely.nn.training.gui.training_process import TrainingProcess
19
19
  from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
20
- from supervisely.nn.utils import ModelSource
20
+ from supervisely.nn.utils import ModelSource, RuntimeType
21
21
 
22
22
 
23
23
  class TrainGUI:
@@ -62,6 +62,7 @@ class TrainGUI:
62
62
  self.workspace_id = sly_env.workspace_id()
63
63
  self.project_id = sly_env.project_id() # from app options?
64
64
  self.project_info = self._api.project.get_info_by_id(self.project_id)
65
+ self.project_meta = ProjectMeta.from_json(self._api.project.get_meta(self.project_id))
65
66
 
66
67
  # 1. Project selection + Train/val split
67
68
  self.input_selector = InputSelector(self.project_info, self.app_options)
@@ -399,7 +400,7 @@ class TrainGUI:
399
400
 
400
401
  app_state = {
401
402
  "input": {"project_id": 43192},
402
- "train_val_splits": {
403
+ "train_val_split": {
403
404
  "method": "random",
404
405
  "split": "train",
405
406
  "percent": 90
@@ -415,13 +416,18 @@ class TrainGUI:
415
416
  "enable": True,
416
417
  "speed_test": True
417
418
  },
418
- "cache_project": True
419
+ "cache_project": True,
420
+ "export": {
421
+ "enable": True,
422
+ "ONNXRuntime": True,
423
+ "TensorRT": True
424
+ },
419
425
  }
420
426
  }
421
427
  """
422
428
  app_state = self.validate_app_state(app_state)
423
429
 
424
- options = app_state["options"]
430
+ options = app_state.get("options", {})
425
431
  input_settings = app_state["input"]
426
432
  train_val_splits_settings = app_state["train_val_split"]
427
433
  classes_settings = app_state["classes"]
@@ -444,7 +450,7 @@ class TrainGUI:
444
450
  :type options: dict
445
451
  """
446
452
  # Set Input
447
- self.input_selector.set_cache(options["cache_project"])
453
+ self.input_selector.set_cache(options.get("cache_project", True))
448
454
  self.input_selector_cb()
449
455
  # ----------------------------------------- #
450
456
 
@@ -527,13 +533,22 @@ class TrainGUI:
527
533
  """
528
534
  self.hyperparameters_selector.set_hyperparameters(hyperparameters_settings)
529
535
 
530
- model_benchmark_settings = options["model_benchmark"]
531
- self.hyperparameters_selector.set_model_benchmark_checkbox_value(
532
- model_benchmark_settings["enable"]
533
- )
534
- self.hyperparameters_selector.set_speedtest_checkbox_value(
535
- model_benchmark_settings["speed_test"]
536
- )
536
+ model_benchmark_settings = options.get("model_benchmark", None)
537
+ if model_benchmark_settings is not None:
538
+ self.hyperparameters_selector.set_model_benchmark_checkbox_value(
539
+ model_benchmark_settings["enable"]
540
+ )
541
+ self.hyperparameters_selector.set_speedtest_checkbox_value(
542
+ model_benchmark_settings["speed_test"]
543
+ )
544
+ export_weights_settings = options.get("export", None)
545
+ if export_weights_settings is not None:
546
+ self.hyperparameters_selector.set_export_onnx_checkbox_value(
547
+ export_weights_settings.get(RuntimeType.ONNXRUNTIME, False)
548
+ )
549
+ self.hyperparameters_selector.set_export_tensorrt_checkbox_value(
550
+ export_weights_settings.get(RuntimeType.TENSORRT, False)
551
+ )
537
552
  self.hyperparameters_selector_cb()
538
553
 
539
554
  # ----------------------------------------- #
@@ -9,6 +9,7 @@ from supervisely.app.widgets import (
9
9
  Field,
10
10
  Text,
11
11
  )
12
+ from supervisely.nn.utils import RuntimeType
12
13
 
13
14
 
14
15
  class HyperparametersSelector:
@@ -17,55 +18,75 @@ class HyperparametersSelector:
17
18
  lock_message = "Select model to unlock"
18
19
 
19
20
  def __init__(self, hyperparameters: dict, app_options: dict = {}):
21
+ self.display_widgets = []
20
22
  self.app_options = app_options
23
+
24
+ # GUI Components
21
25
  self.editor = Editor(
22
26
  hyperparameters, height_lines=50, language_mode="yaml", auto_format=True
23
27
  )
28
+ self.display_widgets.extend([self.editor])
24
29
 
25
- # Model Benchmark
26
- self.run_model_benchmark_checkbox = Checkbox(
27
- content="Run Model Benchmark evaluation", checked=True
28
- )
29
- self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=True)
30
-
31
- self.model_benchmark_field = Field(
32
- Container(
33
- widgets=[
34
- self.run_model_benchmark_checkbox,
35
- self.run_speedtest_checkbox,
36
- ]
37
- ),
38
- title="Model Evaluation Benchmark",
39
- description=f"Generate evalutaion dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
40
- )
41
- docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
42
- self.model_benchmark_learn_more = Text(
43
- f"Learn more about Model Benchmark in the {docs_link}.", status="info"
44
- )
45
-
30
+ # Optional Model Benchmark
46
31
  if app_options.get("model_benchmark", True):
47
- self.model_benchmark_field.show()
48
- self.model_benchmark_learn_more.show()
49
- else:
50
- self.model_benchmark_field.hide()
51
- self.model_benchmark_learn_more.hide()
32
+ # Model Benchmark
33
+ self.run_model_benchmark_checkbox = Checkbox(
34
+ content="Run Model Benchmark evaluation", checked=True
35
+ )
36
+ self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=True)
37
+
38
+ self.model_benchmark_field = Field(
39
+ title="Model Evaluation Benchmark",
40
+ description=f"Generate evalutaion dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
41
+ content=Container([self.run_model_benchmark_checkbox, self.run_speedtest_checkbox]),
42
+ )
43
+ docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
44
+ self.model_benchmark_learn_more = Text(
45
+ f"Learn more about Model Benchmark in the {docs_link}.", status="info"
46
+ )
47
+ self.display_widgets.extend(
48
+ [self.model_benchmark_field, self.model_benchmark_learn_more]
49
+ )
50
+ # -------------------------------- #
51
+
52
+ # Optional Export Weights
53
+ export_onnx_supported = app_options.get("export_onnx_supported", False)
54
+ export_tensorrt_supported = app_options.get("export_tensorrt_supported", False)
55
+
56
+ onnx_name = "ONNX"
57
+ tensorrt_name = "TensorRT engine"
58
+ export_runtimes = []
59
+ export_runtime_names = []
60
+ if export_onnx_supported:
61
+ self.export_onnx_checkbox = Checkbox(content=f"Export to {onnx_name}")
62
+ export_runtimes.append(self.export_onnx_checkbox)
63
+ export_runtime_names.append(onnx_name)
64
+ if export_tensorrt_supported:
65
+ self.export_tensorrt_checkbox = Checkbox(content=f"Export to {tensorrt_name}")
66
+ export_runtimes.append(self.export_tensorrt_checkbox)
67
+ export_runtime_names.append(tensorrt_name)
68
+ if export_onnx_supported or export_tensorrt_supported:
69
+ export_field_description = ", ".join(export_runtime_names)
70
+ runtime_container = Container(export_runtimes)
71
+ self.export_field = Field(
72
+ title="Export model",
73
+ description=f"Export best checkpoint to the following formats: {export_field_description}.",
74
+ content=runtime_container,
75
+ )
76
+ self.display_widgets.extend([self.export_field])
77
+ # -------------------------------- #
52
78
 
53
79
  self.validator_text = Text("")
54
80
  self.validator_text.hide()
55
81
  self.button = Button("Select")
56
- container = Container(
57
- [
58
- self.editor,
59
- self.model_benchmark_field,
60
- self.model_benchmark_learn_more,
61
- self.validator_text,
62
- self.button,
63
- ]
64
- )
82
+ self.display_widgets.extend([self.validator_text, self.button])
83
+ # -------------------------------- #
84
+
85
+ self.container = Container(self.display_widgets)
65
86
  self.card = Card(
66
87
  title=self.title,
67
88
  description=self.description,
68
- content=container,
89
+ content=self.container,
69
90
  lock_message=self.lock_message,
70
91
  collapsable=app_options.get("collapsable", False),
71
92
  )
@@ -73,11 +94,14 @@ class HyperparametersSelector:
73
94
 
74
95
  @property
75
96
  def widgets_to_disable(self) -> list:
76
- return [
77
- self.editor,
78
- self.run_model_benchmark_checkbox,
79
- self.run_speedtest_checkbox,
80
- ]
97
+ widgets = [self.editor]
98
+ if self.app_options.get("model_benchmark", True):
99
+ widgets.extend([self.run_model_benchmark_checkbox, self.run_speedtest_checkbox])
100
+ if self.app_options.get("export_onnx_supported", False):
101
+ widgets.append(self.export_onnx_checkbox)
102
+ if self.app_options.get("export_tensorrt_supported", False):
103
+ widgets.append(self.export_tensorrt_checkbox)
104
+ return widgets
81
105
 
82
106
  def set_hyperparameters(self, hyperparameters: Union[str, dict]) -> None:
83
107
  self.editor.set_text(hyperparameters)
@@ -113,5 +137,30 @@ class HyperparametersSelector:
113
137
  else:
114
138
  self.run_speedtest_checkbox.hide()
115
139
 
140
+ def get_export_onnx_checkbox_value(self) -> bool:
141
+ if self.app_options.get("export_onnx_supported", False):
142
+ return self.export_onnx_checkbox.is_checked()
143
+ return False
144
+
145
+ def set_export_onnx_checkbox_value(self, value: bool) -> None:
146
+ if value:
147
+ self.export_onnx_checkbox.check()
148
+ else:
149
+ self.export_onnx_checkbox.uncheck()
150
+
151
+ def get_export_tensorrt_checkbox_value(self) -> bool:
152
+ if self.app_options.get("export_tensorrt_supported", False):
153
+ return self.export_tensorrt_checkbox.is_checked()
154
+ return False
155
+
156
+ def set_export_tensorrt_checkbox_value(self, value: bool) -> None:
157
+ if value:
158
+ self.export_tensorrt_checkbox.check()
159
+ else:
160
+ self.export_tensorrt_checkbox.uncheck()
161
+
162
+ def is_export_required(self) -> bool:
163
+ return self.get_export_onnx_checkbox_value() or self.get_export_tensorrt_checkbox_value()
164
+
116
165
  def validate_step(self) -> bool:
117
166
  return True
@@ -17,9 +17,11 @@ class InputSelector:
17
17
  lock_message = None
18
18
 
19
19
  def __init__(self, project_info: ProjectInfo, app_options: dict = {}):
20
+ self.display_widgets = []
20
21
  self.project_id = project_info.id
21
22
  self.project_info = project_info
22
23
 
24
+ # GUI Components
23
25
  self.project_thumbnail = ProjectThumbnail(self.project_info)
24
26
 
25
27
  if is_cached(self.project_id):
@@ -32,27 +34,27 @@ class InputSelector:
32
34
  self.validator_text = Text("")
33
35
  self.validator_text.hide()
34
36
  self.button = Button("Select")
35
- container = Container(
36
- widgets=[
37
+ self.display_widgets.extend(
38
+ [
37
39
  self.project_thumbnail,
38
40
  self.use_cache_checkbox,
39
41
  self.validator_text,
40
42
  self.button,
41
43
  ]
42
44
  )
45
+ # -------------------------------- #
43
46
 
47
+ self.container = Container(self.display_widgets)
44
48
  self.card = Card(
45
49
  title=self.title,
46
50
  description=self.description,
47
- content=container,
51
+ content=self.container,
48
52
  collapsable=app_options.get("collapsable", False),
49
53
  )
50
54
 
51
55
  @property
52
56
  def widgets_to_disable(self) -> list:
53
- return [
54
- self.use_cache_checkbox,
55
- ]
57
+ return [self.use_cache_checkbox]
56
58
 
57
59
  def get_project_id(self) -> int:
58
60
  return self.project_id
@@ -21,15 +21,14 @@ class ModelSelector:
21
21
  lock_message = "Select classes to unlock"
22
22
 
23
23
  def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
24
+ self.display_widgets = []
24
25
  self.team_id = sly_env.team_id() # get from project id
25
26
  self.models = models
26
27
 
27
- # Pretrained models
28
+ # GUI Components
28
29
  self.pretrained_models_table = PretrainedModelsSelector(self.models)
29
-
30
30
  experiment_infos = get_experiment_infos(api, self.team_id, framework)
31
31
  self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
32
- # Model source tabs
33
32
  self.model_source_tabs = RadioTabs(
34
33
  titles=[ModelSource.PRETRAINED, ModelSource.CUSTOM],
35
34
  descriptions=[
@@ -42,11 +41,14 @@ class ModelSelector:
42
41
  self.validator_text = Text("")
43
42
  self.validator_text.hide()
44
43
  self.button = Button("Select")
45
- container = Container([self.model_source_tabs, self.validator_text, self.button])
44
+ self.display_widgets.extend([self.model_source_tabs, self.validator_text, self.button])
45
+ # -------------------------------- #
46
+
47
+ self.container = Container(self.display_widgets)
46
48
  self.card = Card(
47
49
  title=self.title,
48
50
  description=self.description,
49
- content=container,
51
+ content=self.container,
50
52
  lock_message=self.lock_message,
51
53
  collapsable=app_options.get("collapsable", False),
52
54
  )
@@ -10,10 +10,12 @@ class TrainValSplitsSelector:
10
10
  lock_message = "Select input options to unlock"
11
11
 
12
12
  def __init__(self, api: Api, project_id: int, app_options: dict = {}):
13
+ self.display_widgets = []
13
14
  self.api = api
14
15
  self.project_id = project_id
15
- self.train_val_splits = TrainValSplits(project_id)
16
16
 
17
+ # GUI Components
18
+ self.train_val_splits = TrainValSplits(project_id)
17
19
  train_val_dataset_ids = {"train": [], "val": []}
18
20
  for _, dataset in api.dataset.tree(project_id):
19
21
  if dataset.name.lower() == "train" or dataset.name.lower() == "training":
@@ -39,17 +41,14 @@ class TrainValSplitsSelector:
39
41
  self.validator_text.hide()
40
42
 
41
43
  self.button = Button("Select")
42
- container = Container(
43
- [
44
- self.train_val_splits,
45
- self.validator_text,
46
- self.button,
47
- ]
48
- )
44
+ self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
45
+ # -------------------------------- #
46
+
47
+ self.container = Container(self.display_widgets)
49
48
  self.card = Card(
50
49
  title=self.title,
51
50
  description=self.description,
52
- content=container,
51
+ content=self.container,
53
52
  lock_message=self.lock_message,
54
53
  collapsable=app_options.get("collapsable", False),
55
54
  )
@@ -12,14 +12,13 @@ class TrainingLogs:
12
12
  lock_message = "Start training to unlock"
13
13
 
14
14
  def __init__(self, app_options: Dict[str, Any]):
15
+ self.display_widgets = []
15
16
  api = Api.from_env()
16
17
  self.app_options = app_options
17
18
 
18
- self.progress_bar_main = Progress(hide_on_finish=False)
19
- self.progress_bar_main.hide()
20
-
21
- self.progress_bar_secondary = Progress(hide_on_finish=False)
22
- self.progress_bar_secondary.hide()
19
+ # GUI Components
20
+ self.validator_text = Text("")
21
+ self.validator_text.hide()
23
22
 
24
23
  if is_production():
25
24
  task_id = get_task_id(raise_not_found=False)
@@ -43,16 +42,9 @@ class TrainingLogs:
43
42
  )
44
43
  self.tensorboard_button.disable()
45
44
 
46
- self.validator_text = Text("")
47
- self.validator_text.hide()
48
-
49
- container_widgets = [
50
- self.validator_text,
51
- self.tensorboard_button,
52
- self.progress_bar_main,
53
- self.progress_bar_secondary,
54
- ]
45
+ self.display_widgets.extend([self.validator_text, self.tensorboard_button])
55
46
 
47
+ # Optional Show logs button
56
48
  if app_options.get("show_logs_in_gui", False):
57
49
  self.logs_button = Button(
58
50
  text="Show logs",
@@ -63,14 +55,22 @@ class TrainingLogs:
63
55
  self.task_logs = TaskLogs(task_id)
64
56
  self.task_logs.hide()
65
57
  logs_container = Container([self.logs_button, self.task_logs])
66
- container_widgets.insert(2, logs_container)
58
+ self.display_widgets.extend([logs_container])
59
+ # -------------------------------- #
67
60
 
68
- container = Container(container_widgets)
61
+ # Progress bars
62
+ self.progress_bar_main = Progress(hide_on_finish=False)
63
+ self.progress_bar_main.hide()
64
+ self.progress_bar_secondary = Progress(hide_on_finish=False)
65
+ self.progress_bar_secondary.hide()
66
+ self.display_widgets.extend([self.progress_bar_main, self.progress_bar_secondary])
67
+ # -------------------------------- #
69
68
 
69
+ self.container = Container(self.display_widgets)
70
70
  self.card = Card(
71
71
  title=self.title,
72
72
  description=self.description,
73
- content=container,
73
+ content=self.container,
74
74
  lock_message=self.lock_message,
75
75
  )
76
76
  self.card.lock()