supervisely 6.73.326__py3-none-any.whl → 6.73.328__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. supervisely/annotation/annotation.py +1 -1
  2. supervisely/annotation/tag_meta.py +1 -1
  3. supervisely/api/api.py +8 -5
  4. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +17 -14
  5. supervisely/app/widgets/pretrained_models_selector/template.html +2 -1
  6. supervisely/convert/image/yolo/yolo_helper.py +95 -25
  7. supervisely/io/env.py +15 -0
  8. supervisely/nn/inference/gui/serving_gui_template.py +2 -3
  9. supervisely/nn/inference/inference.py +33 -25
  10. supervisely/nn/training/gui/classes_selector.py +24 -19
  11. supervisely/nn/training/gui/gui.py +90 -37
  12. supervisely/nn/training/gui/hyperparameters_selector.py +32 -15
  13. supervisely/nn/training/gui/input_selector.py +13 -2
  14. supervisely/nn/training/gui/model_selector.py +16 -6
  15. supervisely/nn/training/gui/train_val_splits_selector.py +10 -1
  16. supervisely/nn/training/gui/training_artifacts.py +23 -4
  17. supervisely/nn/training/gui/training_logs.py +15 -3
  18. supervisely/nn/training/gui/training_process.py +14 -13
  19. supervisely/nn/training/train_app.py +59 -24
  20. supervisely/nn/utils.py +9 -0
  21. supervisely/project/project.py +16 -3
  22. {supervisely-6.73.326.dist-info → supervisely-6.73.328.dist-info}/METADATA +1 -1
  23. {supervisely-6.73.326.dist-info → supervisely-6.73.328.dist-info}/RECORD +27 -27
  24. {supervisely-6.73.326.dist-info → supervisely-6.73.328.dist-info}/WHEEL +1 -1
  25. {supervisely-6.73.326.dist-info → supervisely-6.73.328.dist-info}/LICENSE +0 -0
  26. {supervisely-6.73.326.dist-info → supervisely-6.73.328.dist-info}/entry_points.txt +0 -0
  27. {supervisely-6.73.326.dist-info → supervisely-6.73.328.dist-info}/top_level.txt +0 -0
@@ -4,21 +4,28 @@ from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
4
4
 
5
5
  class ClassesSelector:
6
6
  title = "Classes Selector"
7
- description = (
8
- "Select classes that will be used for training. "
9
- "Supported shapes are Bitmap, Polygon, Rectangle."
10
- )
7
+ description = "Select classes that will be used for training"
11
8
  lock_message = "Select training and validation splits to unlock"
12
9
 
13
10
  def __init__(self, project_id: int, classes: list, app_options: dict = {}):
11
+ # Init widgets
12
+ self.qa_stats_text = None
13
+ self.classes_table = None
14
+ self.validator_text = None
15
+ self.button = None
16
+ self.container = None
17
+ self.card = None
18
+ # -------------------------------- #
19
+
14
20
  self.display_widgets = []
21
+ self.app_options = app_options
15
22
 
16
23
  # GUI Components
17
24
  if is_development() or is_debug_with_sly_net():
18
25
  qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
19
26
  else:
20
27
  qa_stats_link = f"/projects/{project_id}/stats/datasets"
21
- qa_stats_text = Text(
28
+ self.qa_stats_text = Text(
22
29
  text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
23
30
  )
24
31
 
@@ -32,7 +39,7 @@ class ClassesSelector:
32
39
  self.validator_text.hide()
33
40
  self.button = Button("Select")
34
41
  self.display_widgets.extend(
35
- [qa_stats_text, self.classes_table, self.validator_text, self.button]
42
+ [self.qa_stats_text, self.classes_table, self.validator_text, self.button]
36
43
  )
37
44
  # -------------------------------- #
38
45
 
@@ -42,7 +49,7 @@ class ClassesSelector:
42
49
  description=self.description,
43
50
  content=self.container,
44
51
  lock_message=self.lock_message,
45
- collapsable=app_options.get("collapsable", False),
52
+ collapsable=self.app_options.get("collapsable", False),
46
53
  )
47
54
  self.card.lock()
48
55
 
@@ -62,14 +69,14 @@ class ClassesSelector:
62
69
  def validate_step(self) -> bool:
63
70
  self.validator_text.hide()
64
71
 
65
- if len(self.classes_table.project_meta.obj_classes) == 0:
72
+ project_classes = self.classes_table.project_meta.obj_classes
73
+ if len(project_classes) == 0:
66
74
  self.validator_text.set(text="Project has no classes", status="error")
67
75
  self.validator_text.show()
68
76
  return False
69
77
 
70
78
  selected_classes = self.classes_table.get_selected_classes()
71
79
  table_data = self.classes_table._table_data
72
-
73
80
  empty_classes = [
74
81
  row[0]["data"]
75
82
  for row in table_data
@@ -78,23 +85,21 @@ class ClassesSelector:
78
85
 
79
86
  n_classes = len(selected_classes)
80
87
  if n_classes == 0:
81
- self.validator_text.set(text="Please select at least one class", status="error")
88
+ message = "Please select at least one class"
89
+ status = "error"
82
90
  else:
83
- warning_text = ""
91
+ class_text = "class" if n_classes == 1 else "classes"
92
+ message = f"Selected {n_classes} {class_text}"
84
93
  status = "success"
85
94
  if empty_classes:
86
95
  intersections = set(selected_classes).intersection(empty_classes)
87
96
  if intersections:
88
- warning_text = (
89
- f". Selected class has no annotations: {', '.join(intersections)}"
90
- if len(intersections) == 1
91
- else f". Selected classes have no annotations: {', '.join(intersections)}"
97
+ class_text = "class" if len(intersections) == 1 else "classes"
98
+ message += (
99
+ f". Selected {class_text} have no annotations: {', '.join(intersections)}"
92
100
  )
93
101
  status = "warning"
94
102
 
95
- class_text = "class" if n_classes == 1 else "classes"
96
- self.validator_text.set(
97
- text=f"Selected {n_classes} {class_text}{warning_text}", status=status
98
- )
103
+ self.validator_text.set(text=message, status=status)
99
104
  self.validator_text.show()
100
105
  return n_classes > 0
@@ -6,11 +6,15 @@ training workflows in Supervisely.
6
6
  """
7
7
 
8
8
  from os import environ
9
+ from typing import Union
9
10
 
10
11
  import supervisely.io.env as sly_env
12
+ import supervisely.io.json as sly_json
11
13
  from supervisely import Api, ProjectMeta
12
14
  from supervisely._utils import is_production
13
15
  from supervisely.app.widgets import Stepper, Widget
16
+ from supervisely.geometry.bitmap import Bitmap
17
+ from supervisely.geometry.graph import GraphNodes
14
18
  from supervisely.geometry.polygon import Polygon
15
19
  from supervisely.geometry.rectangle import Rectangle
16
20
  from supervisely.nn.task_type import TaskType
@@ -63,7 +67,7 @@ class TrainGUI:
63
67
  self.models = models
64
68
  self.hyperparameters = hyperparameters
65
69
  self.app_options = app_options
66
- self.collapsable = app_options.get("collapsable", False)
70
+ self.collapsable = self.app_options.get("collapsable", False)
67
71
  self.need_convert_shapes_for_bm = False
68
72
 
69
73
  self.team_id = sly_env.team_id(raise_not_found=False)
@@ -142,33 +146,73 @@ class TrainGUI:
142
146
  self.training_process.set_experiment_name(experiment_name)
143
147
 
144
148
  def need_convert_class_shapes() -> bool:
145
- if not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
146
- self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
147
- self.need_convert_shapes_for_bm = False
148
- else:
149
- task_type = self.model_selector.get_selected_task_type()
150
-
151
- def _need_convert(shape):
152
- if task_type == TaskType.OBJECT_DETECTION:
153
- return shape != Rectangle.geometry_name()
154
- elif task_type in [
155
- TaskType.INSTANCE_SEGMENTATION,
156
- TaskType.SEMANTIC_SEGMENTATION,
157
- ]:
158
- return shape == Polygon.geometry_name()
159
- return
160
-
161
- data = self.classes_selector.classes_table._table_data
162
- selected_classes = set(self.classes_selector.classes_table.get_selected_classes())
163
- empty = set(r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0)
164
- need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
165
-
166
- if need_convert.intersection(selected_classes - empty):
167
- self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
168
- self.need_convert_shapes_for_bm = True
169
- else:
149
+ if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
150
+ if not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
170
151
  self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
171
152
  self.need_convert_shapes_for_bm = False
153
+ else:
154
+ task_type = self.model_selector.get_selected_task_type()
155
+
156
+ def _need_convert(shape):
157
+ if task_type == TaskType.OBJECT_DETECTION:
158
+ return shape != Rectangle.geometry_name()
159
+ elif task_type in [
160
+ TaskType.INSTANCE_SEGMENTATION,
161
+ TaskType.SEMANTIC_SEGMENTATION,
162
+ ]:
163
+ return shape == Polygon.geometry_name()
164
+ return
165
+
166
+ data = self.classes_selector.classes_table._table_data
167
+ selected_classes = set(
168
+ self.classes_selector.classes_table.get_selected_classes()
169
+ )
170
+ empty = set(
171
+ r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
172
+ )
173
+ need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
174
+
175
+ if need_convert.intersection(selected_classes - empty):
176
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
177
+ self.need_convert_shapes_for_bm = True
178
+ else:
179
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
180
+ self.need_convert_shapes_for_bm = False
181
+ else:
182
+ self.need_convert_shapes_for_bm = False
183
+
184
+ def validate_class_shape_for_model_task():
185
+ task_type = self.model_selector.get_selected_task_type()
186
+ classes = self.classes_selector.get_selected_classes()
187
+
188
+ required_geometries = {
189
+ TaskType.INSTANCE_SEGMENTATION: {Polygon, Bitmap},
190
+ TaskType.SEMANTIC_SEGMENTATION: {Polygon, Bitmap},
191
+ TaskType.POSE_ESTIMATION: {GraphNodes},
192
+ }
193
+ task_specific_texts = {
194
+ TaskType.INSTANCE_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
195
+ TaskType.SEMANTIC_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
196
+ TaskType.POSE_ESTIMATION: "Only keypoint (graph) shape is supported for pose estimation task",
197
+ }
198
+
199
+ if task_type not in required_geometries:
200
+ return
201
+
202
+ wrong_shape_classes = [
203
+ class_name
204
+ for class_name in classes
205
+ if self.project_meta.get_obj_class(class_name).geometry_type
206
+ not in required_geometries[task_type]
207
+ ]
208
+
209
+ if wrong_shape_classes:
210
+ specific_text = task_specific_texts[task_type]
211
+ message_text = f"Model task type is {task_type}. {specific_text}. Selected classes have wrong shapes for the model task: {', '.join(wrong_shape_classes)}"
212
+ self.model_selector.validator_text.set(
213
+ text=message_text,
214
+ status="warning",
215
+ )
172
216
 
173
217
  # ------------------------------------------------- #
174
218
 
@@ -201,7 +245,11 @@ class TrainGUI:
201
245
  callback=self.hyperparameters_selector_cb,
202
246
  validation_text=self.model_selector.validator_text,
203
247
  validation_func=self.model_selector.validate_step,
204
- on_select_click=[set_experiment_name, need_convert_class_shapes],
248
+ on_select_click=[
249
+ set_experiment_name,
250
+ need_convert_class_shapes,
251
+ validate_class_shape_for_model_task,
252
+ ],
205
253
  collapse_card=(self.model_selector.card, self.collapsable),
206
254
  )
207
255
 
@@ -299,17 +347,19 @@ class TrainGUI:
299
347
  # ------------------------------------------------- #
300
348
 
301
349
  # Other Buttons
302
- if app_options.get("show_logs_in_gui", False):
350
+ if self.app_options.get("show_logs_in_gui", False):
303
351
 
304
352
  @self.training_logs.logs_button.click
305
353
  def show_logs():
306
354
  self.training_logs.toggle_logs()
307
355
 
308
356
  # Other handlers
309
- @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
310
- def show_mb_speedtest(is_checked: bool):
311
- self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
312
- need_convert_class_shapes()
357
+ if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
358
+
359
+ @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
360
+ def show_mb_speedtest(is_checked: bool):
361
+ self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
362
+ need_convert_class_shapes()
313
363
 
314
364
  # ------------------------------------------------- #
315
365
 
@@ -361,7 +411,6 @@ class TrainGUI:
361
411
  raise ValueError("app_state must be a dictionary")
362
412
 
363
413
  required_keys = {
364
- "input": ["project_id"],
365
414
  "train_val_split": ["method"],
366
415
  "classes": list,
367
416
  "model": ["source"],
@@ -453,7 +502,7 @@ class TrainGUI:
453
502
  raise ValueError("percent must be an integer in range 1 to 99")
454
503
  return app_state
455
504
 
456
- def load_from_app_state(self, app_state: dict) -> None:
505
+ def load_from_app_state(self, app_state: Union[str, dict]) -> None:
457
506
  """
458
507
  Load the GUI state from app state dictionary.
459
508
 
@@ -463,7 +512,6 @@ class TrainGUI:
463
512
  app_state example:
464
513
 
465
514
  app_state = {
466
- "input": {"project_id": 43192},
467
515
  "train_val_split": {
468
516
  "method": "random",
469
517
  "split": "train",
@@ -489,10 +537,13 @@ class TrainGUI:
489
537
  }
490
538
  }
491
539
  """
540
+ if isinstance(app_state, str):
541
+ app_state = sly_json.load_json_file(app_state)
542
+
492
543
  app_state = self.validate_app_state(app_state)
493
544
 
494
545
  options = app_state.get("options", {})
495
- input_settings = app_state["input"]
546
+ input_settings = app_state.get("input")
496
547
  train_val_splits_settings = app_state["train_val_split"]
497
548
  classes_settings = app_state["classes"]
498
549
  model_settings = app_state["model"]
@@ -504,7 +555,7 @@ class TrainGUI:
504
555
  self._init_model(model_settings)
505
556
  self._init_hyperparameters(hyperparameters_settings, options)
506
557
 
507
- def _init_input(self, input_settings: dict, options: dict) -> None:
558
+ def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
508
559
  """
509
560
  Initialize the input selector with the given settings.
510
561
 
@@ -604,6 +655,8 @@ class TrainGUI:
604
655
  )
605
656
  self.hyperparameters_selector.set_speedtest_checkbox_value(
606
657
  model_benchmark_settings["speed_test"]
658
+ if model_benchmark_settings["enable"]
659
+ else False
607
660
  )
608
661
  export_weights_settings = options.get("export", None)
609
662
  if export_weights_settings is not None:
@@ -9,7 +9,6 @@ from supervisely.app.widgets import (
9
9
  Field,
10
10
  Text,
11
11
  )
12
- from supervisely.nn.utils import RuntimeType
13
12
 
14
13
 
15
14
  class HyperparametersSelector:
@@ -18,6 +17,22 @@ class HyperparametersSelector:
18
17
  lock_message = "Select model to unlock"
19
18
 
20
19
  def __init__(self, hyperparameters: dict, app_options: dict = {}):
20
+ # Init widgets
21
+ self.editor = None
22
+ self.run_model_benchmark_checkbox = None
23
+ self.run_speedtest_checkbox = None
24
+ self.model_benchmark_field = None
25
+ self.model_benchmark_learn_more = None
26
+ self.model_benchmark_auto_convert_warning = None
27
+ self.export_onnx_checkbox = None
28
+ self.export_tensorrt_checkbox = None
29
+ self.export_field = None
30
+ self.validator_text = None
31
+ self.button = None
32
+ self.container = None
33
+ self.card = None
34
+ # -------------------------------- #
35
+
21
36
  self.display_widgets = []
22
37
  self.app_options = app_options
23
38
 
@@ -28,7 +43,7 @@ class HyperparametersSelector:
28
43
  self.display_widgets.extend([self.editor])
29
44
 
30
45
  # Optional Model Benchmark
31
- if app_options.get("model_benchmark", True):
46
+ if self.app_options.get("model_benchmark", True):
32
47
  # Model Benchmark
33
48
  self.run_model_benchmark_checkbox = Checkbox(
34
49
  content="Run Model Benchmark evaluation", checked=True
@@ -37,7 +52,7 @@ class HyperparametersSelector:
37
52
 
38
53
  self.model_benchmark_field = Field(
39
54
  title="Model Evaluation Benchmark",
40
- description=f"Generate evalutaion dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
55
+ description="Generate evaluation dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
41
56
  content=Container([self.run_model_benchmark_checkbox, self.run_speedtest_checkbox]),
42
57
  )
43
58
  docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
@@ -60,8 +75,8 @@ class HyperparametersSelector:
60
75
  # -------------------------------- #
61
76
 
62
77
  # Optional Export Weights
63
- export_onnx_supported = app_options.get("export_onnx_supported", False)
64
- export_tensorrt_supported = app_options.get("export_tensorrt_supported", False)
78
+ export_onnx_supported = self.app_options.get("export_onnx_supported", False)
79
+ export_tensorrt_supported = self.app_options.get("export_tensorrt_supported", False)
65
80
 
66
81
  onnx_name = "ONNX"
67
82
  tensorrt_name = "TensorRT engine"
@@ -120,26 +135,28 @@ class HyperparametersSelector:
120
135
  return self.editor.get_value()
121
136
 
122
137
  def get_model_benchmark_checkbox_value(self) -> bool:
123
- if self.app_options.get("model_benchmark", True):
138
+ if self.run_model_benchmark_checkbox is not None:
124
139
  return self.run_model_benchmark_checkbox.is_checked()
125
140
  return False
126
141
 
127
142
  def set_model_benchmark_checkbox_value(self, is_checked: bool) -> bool:
128
- if is_checked:
129
- self.run_model_benchmark_checkbox.check()
130
- else:
131
- self.run_model_benchmark_checkbox.uncheck()
143
+ if self.run_model_benchmark_checkbox is not None:
144
+ if is_checked:
145
+ self.run_model_benchmark_checkbox.check()
146
+ else:
147
+ self.run_model_benchmark_checkbox.uncheck()
132
148
 
133
149
  def get_speedtest_checkbox_value(self) -> bool:
134
- if self.app_options.get("model_benchmark", True):
150
+ if self.run_speedtest_checkbox is not None:
135
151
  return self.run_speedtest_checkbox.is_checked()
136
152
  return False
137
153
 
138
154
  def set_speedtest_checkbox_value(self, is_checked: bool) -> bool:
139
- if is_checked:
140
- self.run_speedtest_checkbox.check()
141
- else:
142
- self.run_speedtest_checkbox.uncheck()
155
+ if self.run_speedtest_checkbox is not None:
156
+ if is_checked:
157
+ self.run_speedtest_checkbox.check()
158
+ else:
159
+ self.run_speedtest_checkbox.uncheck()
143
160
 
144
161
  def toggle_mb_speedtest(self, is_checked: bool) -> None:
145
162
  if is_checked:
@@ -4,7 +4,6 @@ from supervisely.app.widgets import (
4
4
  Card,
5
5
  Checkbox,
6
6
  Container,
7
- Field,
8
7
  ProjectThumbnail,
9
8
  Text,
10
9
  )
@@ -17,7 +16,19 @@ class InputSelector:
17
16
  lock_message = None
18
17
 
19
18
  def __init__(self, project_info: ProjectInfo, app_options: dict = {}):
19
+ # Init widgets
20
+ self.project_thumbnail = None
21
+ self.use_cache_text = None
22
+ self.use_cache_checkbox = None
23
+ self.validator_text = None
24
+ self.button = None
25
+ self.container = None
26
+ self.card = None
27
+ # -------------------------------- #
28
+
20
29
  self.display_widgets = []
30
+ self.app_options = app_options
31
+
21
32
  self.project_id = project_info.id
22
33
  self.project_info = project_info
23
34
 
@@ -49,7 +60,7 @@ class InputSelector:
49
60
  title=self.title,
50
61
  description=self.description,
51
62
  content=self.container,
52
- collapsable=app_options.get("collapsable", False),
63
+ collapsable=self.app_options.get("collapsable", False),
53
64
  )
54
65
 
55
66
  @property
@@ -14,7 +14,7 @@ from supervisely.app.widgets import (
14
14
  )
15
15
  from supervisely.nn.artifacts.utils import FrameworkMapper
16
16
  from supervisely.nn.experiments import get_experiment_infos
17
- from supervisely.nn.utils import ModelSource
17
+ from supervisely.nn.utils import ModelSource, _get_model_name
18
18
 
19
19
 
20
20
  class ModelSelector:
@@ -23,15 +23,26 @@ class ModelSelector:
23
23
  lock_message = "Select classes to unlock"
24
24
 
25
25
  def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
26
+ # Init widgets
27
+ self.pretrained_models_table = None
28
+ self.experiment_selector = None
29
+ self.model_source_tabs = None
30
+ self.validator_text = None
31
+ self.button = None
32
+ self.container = None
33
+ self.card = None
34
+ # -------------------------------- #
35
+
26
36
  self.display_widgets = []
37
+ self.app_options = app_options
38
+
27
39
  self.team_id = sly_env.team_id()
28
40
  self.models = models
29
41
 
30
42
  # GUI Components
31
43
  self.pretrained_models_table = PretrainedModelsSelector(self.models)
32
-
33
44
  experiment_infos = get_experiment_infos(api, self.team_id, framework)
34
- if app_options.get("legacy_checkpoints", False):
45
+ if self.app_options.get("legacy_checkpoints", False):
35
46
  try:
36
47
  framework_cls = FrameworkMapper.get_framework_cls(framework, self.team_id)
37
48
  legacy_experiment_infos = framework_cls.get_list_experiment_info()
@@ -61,7 +72,7 @@ class ModelSelector:
61
72
  description=self.description,
62
73
  content=self.container,
63
74
  lock_message=self.lock_message,
64
- collapsable=app_options.get("collapsable", False),
75
+ collapsable=self.app_options.get("collapsable", False),
65
76
  )
66
77
  self.card.lock()
67
78
 
@@ -82,8 +93,7 @@ class ModelSelector:
82
93
  def get_model_name(self) -> str:
83
94
  if self.get_model_source() == ModelSource.PRETRAINED:
84
95
  selected_row = self.pretrained_models_table.get_selected_row()
85
- model_meta = selected_row.get("meta", {})
86
- model_name = model_meta.get("model_name", None)
96
+ model_name = _get_model_name(selected_row)
87
97
  else:
88
98
  selected_row = self.experiment_selector.get_selected_experiment_info()
89
99
  model_name = selected_row.get("model_name", None)
@@ -10,7 +10,16 @@ class TrainValSplitsSelector:
10
10
  lock_message = "Select input options to unlock"
11
11
 
12
12
  def __init__(self, api: Api, project_id: int, app_options: dict = {}):
13
+ # Init widgets
14
+ self.train_val_splits = None
15
+ self.validator_text = None
16
+ self.button = None
17
+ self.container = None
18
+ self.card = None
19
+ # -------------------------------- #
20
+
13
21
  self.display_widgets = []
22
+ self.app_options = app_options
14
23
  self.api = api
15
24
  self.project_id = project_id
16
25
 
@@ -50,7 +59,7 @@ class TrainValSplitsSelector:
50
59
  description=self.description,
51
60
  content=self.container,
52
61
  lock_message=self.lock_message,
53
- collapsable=app_options.get("collapsable", False),
62
+ collapsable=self.app_options.get("collapsable", False),
54
63
  )
55
64
  self.card.lock()
56
65
 
@@ -5,7 +5,6 @@ import supervisely.io.env as sly_env
5
5
  import supervisely.nn.training.gui.utils as gui_utils
6
6
  from supervisely import Api, logger
7
7
  from supervisely._utils import is_production
8
- from supervisely.api.api import ApiField
9
8
  from supervisely.app.widgets import (
10
9
  Card,
11
10
  Container,
@@ -34,13 +33,30 @@ class TrainingArtifacts:
34
33
  lock_message = "Artifacts will be available after training is completed"
35
34
 
36
35
  def __init__(self, api: Api, app_options: Dict[str, Any]):
36
+ # Init widgets
37
+ self.artifacts_thumbnail = None
38
+ self.artifacts_field = None
39
+ self.model_benchmark_report_thumbnail = None
40
+ self.model_benchmark_fail_text = None
41
+ self.model_benchmark_widgets = None
42
+ self.model_benchmark_report_field = None
43
+ self.pytorch_instruction = None
44
+ self.onnx_instruction = None
45
+ self.trt_instruction = None
46
+ self.inference_demo_field = None
47
+ self.validator_text = None
48
+ self.container = None
49
+ self.card = None
50
+ # -------------------------------- #
51
+
37
52
  self.display_widgets = []
53
+ self.app_options = app_options
54
+
38
55
  self.success_message_text = (
39
56
  "Training completed. Training artifacts were uploaded to Team Files. "
40
57
  "You can find and open tensorboard logs in the artifacts folder via the "
41
58
  "<a href='https://ecosystem.supervisely.com/apps/tensorboard-experiments-viewer' target='_blank'>Tensorboard Experiment Viewer</a> app."
42
59
  )
43
- self.app_options = app_options
44
60
 
45
61
  # GUI Components
46
62
  self.validator_text = Text("")
@@ -60,7 +76,7 @@ class TrainingArtifacts:
60
76
  self.display_widgets.extend([self.artifacts_field])
61
77
 
62
78
  # Optional Model Benchmark
63
- if app_options.get("model_benchmark", False):
79
+ if self.app_options.get("model_benchmark", False):
64
80
  self.model_benchmark_report_thumbnail = ReportThumbnail()
65
81
  self.model_benchmark_report_thumbnail.hide()
66
82
 
@@ -86,6 +102,8 @@ class TrainingArtifacts:
86
102
  # PyTorch, ONNX, TensorRT demo
87
103
  self.inference_demo_widgets = []
88
104
 
105
+ # Demo display works only for released apps
106
+ self.need_upload_demo = False
89
107
  model_demo = self.app_options.get("demo", None)
90
108
  if model_demo is not None:
91
109
  model_demo_path = model_demo.get("path", None)
@@ -111,7 +129,7 @@ class TrainingArtifacts:
111
129
  )
112
130
 
113
131
  if model_demo_gh_link is not None:
114
- gh_branch = "blob/main"
132
+ gh_branch = f"blob/{model_demo.get('branch', 'master')}"
115
133
  link_to_demo = f"{model_demo_gh_link}/{gh_branch}/{model_demo_path}"
116
134
 
117
135
  if model_demo_gh_link is not None and model_demo_path is not None:
@@ -186,6 +204,7 @@ class TrainingArtifacts:
186
204
  )
187
205
  self.inference_demo_field.hide()
188
206
  self.display_widgets.extend([self.inference_demo_field])
207
+ self.need_upload_demo = True
189
208
  # -------------------------------- #
190
209
 
191
210
  self.container = Container(self.display_widgets)
@@ -21,9 +21,21 @@ class TrainingLogs:
21
21
  lock_message = "Start training to unlock"
22
22
 
23
23
  def __init__(self, app_options: Dict[str, Any]):
24
+ # Init widgets
25
+ self.tensorboard_button = None
26
+ self.tensorboard_offline_button = None
27
+ self.logs_button = None
28
+ self.task_logs = None
29
+ self.progress_bar_main = None
30
+ self.progress_bar_secondary = None
31
+ self.validator_text = None
32
+ self.container = None
33
+ self.card = None
34
+ # -------------------------------- #
35
+
24
36
  self.display_widgets = []
25
- api = Api.from_env()
26
37
  self.app_options = app_options
38
+ api = Api.from_env()
27
39
 
28
40
  # GUI Components
29
41
  self.validator_text = Text("")
@@ -42,6 +54,7 @@ class TrainingLogs:
42
54
  self.tensorboard_link = f"{api.server_address}{sly_url_prefix}/tensorboard/"
43
55
  else:
44
56
  self.tensorboard_link = "http://localhost:8000/tensorboard"
57
+
45
58
  self.tensorboard_button = Button(
46
59
  "Open Tensorboard",
47
60
  button_type="info",
@@ -54,7 +67,6 @@ class TrainingLogs:
54
67
  self.display_widgets.extend([self.validator_text, self.tensorboard_button])
55
68
 
56
69
  # Offline session Tensorboard button
57
- self.tensorboard_offline_button = None
58
70
  if is_production():
59
71
  workspace_id = sly_env.workspace_id()
60
72
  app_name = "Tensorboard Experiments Viewer"
@@ -79,7 +91,7 @@ class TrainingLogs:
79
91
  )
80
92
 
81
93
  # Optional Show logs button
82
- if app_options.get("show_logs_in_gui", False):
94
+ if self.app_options.get("show_logs_in_gui", False):
83
95
  self.logs_button = Button(
84
96
  text="Show logs",
85
97
  plain=True,