supervisely 6.73.325__py3-none-any.whl → 6.73.327__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28):
  1. supervisely/annotation/annotation.py +1 -1
  2. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +17 -14
  3. supervisely/app/widgets/pretrained_models_selector/template.html +2 -1
  4. supervisely/convert/image/yolo/yolo_helper.py +95 -25
  5. supervisely/convert/volume/nii/nii_planes_volume_converter.py +54 -6
  6. supervisely/convert/volume/nii/nii_volume_converter.py +7 -7
  7. supervisely/convert/volume/nii/nii_volume_helper.py +49 -0
  8. supervisely/nn/inference/gui/serving_gui_template.py +2 -3
  9. supervisely/nn/inference/inference.py +33 -25
  10. supervisely/nn/training/gui/classes_selector.py +24 -19
  11. supervisely/nn/training/gui/gui.py +90 -37
  12. supervisely/nn/training/gui/hyperparameters_selector.py +32 -15
  13. supervisely/nn/training/gui/input_selector.py +13 -2
  14. supervisely/nn/training/gui/model_selector.py +16 -6
  15. supervisely/nn/training/gui/train_val_splits_selector.py +10 -1
  16. supervisely/nn/training/gui/training_artifacts.py +23 -4
  17. supervisely/nn/training/gui/training_logs.py +15 -3
  18. supervisely/nn/training/gui/training_process.py +14 -13
  19. supervisely/nn/training/train_app.py +59 -24
  20. supervisely/nn/utils.py +9 -0
  21. supervisely/project/project.py +16 -3
  22. supervisely/volume/volume.py +19 -21
  23. {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/METADATA +1 -1
  24. {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/RECORD +28 -28
  25. {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/LICENSE +0 -0
  26. {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/WHEEL +0 -0
  27. {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/entry_points.txt +0 -0
  28. {supervisely-6.73.325.dist-info → supervisely-6.73.327.dist-info}/top_level.txt +0 -0
@@ -75,6 +75,7 @@ from supervisely.nn.utils import (
75
75
  ModelPrecision,
76
76
  ModelSource,
77
77
  RuntimeType,
78
+ _get_model_name,
78
79
  )
79
80
  from supervisely.project import ProjectType
80
81
  from supervisely.project.download import download_to_cache, read_from_cached_project
@@ -173,9 +174,7 @@ class Inference:
173
174
  self._use_gui = False
174
175
  deploy_params, need_download = self._get_deploy_params_from_args()
175
176
  if need_download:
176
- local_model_files = self._download_model_files(
177
- deploy_params["model_source"], deploy_params["model_files"], False
178
- )
177
+ local_model_files = self._download_model_files(deploy_params, False)
179
178
  deploy_params["model_files"] = local_model_files
180
179
  self._load_model_headless(**deploy_params)
181
180
 
@@ -210,14 +209,12 @@ class Inference:
210
209
  self.initialize_gui()
211
210
 
212
211
  def on_serve_callback(
213
- gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
212
+ gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
214
213
  ):
215
214
  Progress("Deploying model ...", 1)
216
215
  if isinstance(self.gui, GUI.ServingGUITemplate):
217
216
  deploy_params = self.get_params_from_gui()
218
- model_files = self._download_model_files(
219
- deploy_params["model_source"], deploy_params["model_files"]
220
- )
217
+ model_files = self._download_model_files(deploy_params)
221
218
  deploy_params["model_files"] = model_files
222
219
  self._load_model_headless(**deploy_params)
223
220
  elif isinstance(self.gui, GUI.ServingGUI):
@@ -230,7 +227,7 @@ class Inference:
230
227
  gui.show_deployed_model_info(self)
231
228
 
232
229
  def on_change_model_callback(
233
- gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
230
+ gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate],
234
231
  ):
235
232
  self.shutdown_model()
236
233
  if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
@@ -567,13 +564,23 @@ class Inference:
567
564
  def _checkpoints_cache_dir(self):
568
565
  return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
569
566
 
570
- def _download_model_files(
571
- self, model_source: str, model_files: List[str], log_progress: bool = True
572
- ) -> dict:
573
- if model_source == ModelSource.PRETRAINED:
574
- return self._download_pretrained_model(model_files, log_progress)
575
- elif model_source == ModelSource.CUSTOM:
576
- return self._download_custom_model(model_files, log_progress)
567
+ def _download_model_files(self, deploy_params: dict, log_progress: bool = True) -> dict:
568
+ if deploy_params["runtime"] != RuntimeType.PYTORCH:
569
+ export = deploy_params["model_info"].get("export", {})
570
+ export_model = export.get(deploy_params["runtime"], None)
571
+ if export_model is not None:
572
+ if sly_fs.get_file_name(export_model) == sly_fs.get_file_name(
573
+ deploy_params["model_files"]["checkpoint"]
574
+ ):
575
+ deploy_params["model_files"]["checkpoint"] = (
576
+ deploy_params["model_info"]["artifacts_dir"] + export_model
577
+ )
578
+ logger.info(f"Found model checkpoint for '{deploy_params['runtime']}'")
579
+
580
+ if deploy_params["model_source"] == ModelSource.PRETRAINED:
581
+ return self._download_pretrained_model(deploy_params["model_files"], log_progress)
582
+ elif deploy_params["model_source"] == ModelSource.CUSTOM:
583
+ return self._download_custom_model(deploy_params["model_files"], log_progress)
577
584
 
578
585
  def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
579
586
  """
@@ -2929,9 +2936,7 @@ class Inference:
2929
2936
  state = request.state.state
2930
2937
  deploy_params = state["deploy_params"]
2931
2938
  if isinstance(self.gui, GUI.ServingGUITemplate):
2932
- model_files = self._download_model_files(
2933
- deploy_params["model_source"], deploy_params["model_files"]
2934
- )
2939
+ model_files = self._download_model_files(deploy_params)
2935
2940
  deploy_params["model_files"] = model_files
2936
2941
  self._load_model_headless(**deploy_params)
2937
2942
  elif isinstance(self.gui, GUI.ServingGUI):
@@ -3061,7 +3066,7 @@ class Inference:
3061
3066
  raise ValueError("No pretrained models found.")
3062
3067
 
3063
3068
  model = self.pretrained_models[0]
3064
- model_name = model.get("meta", {}).get("model_name", None)
3069
+ model_name = _get_model_name(model)
3065
3070
  if model_name is None:
3066
3071
  raise ValueError("No model name found in the first pretrained model.")
3067
3072
 
@@ -3126,7 +3131,7 @@ class Inference:
3126
3131
  meta = m.get("meta", None)
3127
3132
  if meta is None:
3128
3133
  continue
3129
- model_name = meta.get("model_name", None)
3134
+ model_name = _get_model_name(m)
3130
3135
  if model_name is None:
3131
3136
  continue
3132
3137
  m_files = meta.get("model_files", None)
@@ -3135,7 +3140,7 @@ class Inference:
3135
3140
  checkpoint = m_files.get("checkpoint", None)
3136
3141
  if checkpoint is None:
3137
3142
  continue
3138
- if model == m["meta"]["model_name"]:
3143
+ if model.lower() == model_name.lower():
3139
3144
  model_info = m
3140
3145
  model_source = ModelSource.PRETRAINED
3141
3146
  model_files = {"checkpoint": checkpoint}
@@ -3153,8 +3158,6 @@ class Inference:
3153
3158
  model_meta_path = os.path.join(artifacts_dir, "model_meta.json")
3154
3159
  model_info["model_meta"] = self._load_json_file(model_meta_path)
3155
3160
  original_model_files = model_info.get("model_files")
3156
- if not original_model_files:
3157
- raise ValueError("Invalid 'experiment_info.json'. Missing 'model_files' key.")
3158
3161
  return model_info, original_model_files
3159
3162
 
3160
3163
  def _prepare_local_model_files(artifacts_dir, checkpoint_path, original_model_files):
@@ -3201,6 +3204,7 @@ class Inference:
3201
3204
  model_files = _prepare_local_model_files(
3202
3205
  artifacts_dir, checkpoint_path, original_model_files
3203
3206
  )
3207
+
3204
3208
  else:
3205
3209
  local_artifacts_dir = os.path.join(
3206
3210
  self.model_dir, "local_deploy", os.path.basename(artifacts_dir)
@@ -3298,7 +3302,11 @@ class Inference:
3298
3302
  if draw:
3299
3303
  raise ValueError("Draw visualization is not supported for project inference")
3300
3304
 
3301
- state = {"projectId": project_id, "dataset_ids": dataset_ids, "settings": settings}
3305
+ state = {
3306
+ "projectId": project_id,
3307
+ "dataset_ids": dataset_ids,
3308
+ "settings": settings,
3309
+ }
3302
3310
  if upload:
3303
3311
  source_project = api.project.get_info_by_id(project_id)
3304
3312
  workspace_id = source_project.workspace_id
@@ -3472,7 +3480,7 @@ class Inference:
3472
3480
  def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
3473
3481
  if model_source == ModelSource.PRETRAINED:
3474
3482
  checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
3475
- checkpoint_name = model_info["meta"]["model_name"]
3483
+ checkpoint_name = _get_model_name(model_info)
3476
3484
  else:
3477
3485
  checkpoint_name = sly_fs.get_file_name_with_ext(model_files["checkpoint"])
3478
3486
  checkpoint_url = os.path.join(
@@ -4,21 +4,28 @@ from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
4
4
 
5
5
  class ClassesSelector:
6
6
  title = "Classes Selector"
7
- description = (
8
- "Select classes that will be used for training. "
9
- "Supported shapes are Bitmap, Polygon, Rectangle."
10
- )
7
+ description = "Select classes that will be used for training"
11
8
  lock_message = "Select training and validation splits to unlock"
12
9
 
13
10
  def __init__(self, project_id: int, classes: list, app_options: dict = {}):
11
+ # Init widgets
12
+ self.qa_stats_text = None
13
+ self.classes_table = None
14
+ self.validator_text = None
15
+ self.button = None
16
+ self.container = None
17
+ self.card = None
18
+ # -------------------------------- #
19
+
14
20
  self.display_widgets = []
21
+ self.app_options = app_options
15
22
 
16
23
  # GUI Components
17
24
  if is_development() or is_debug_with_sly_net():
18
25
  qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
19
26
  else:
20
27
  qa_stats_link = f"/projects/{project_id}/stats/datasets"
21
- qa_stats_text = Text(
28
+ self.qa_stats_text = Text(
22
29
  text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
23
30
  )
24
31
 
@@ -32,7 +39,7 @@ class ClassesSelector:
32
39
  self.validator_text.hide()
33
40
  self.button = Button("Select")
34
41
  self.display_widgets.extend(
35
- [qa_stats_text, self.classes_table, self.validator_text, self.button]
42
+ [self.qa_stats_text, self.classes_table, self.validator_text, self.button]
36
43
  )
37
44
  # -------------------------------- #
38
45
 
@@ -42,7 +49,7 @@ class ClassesSelector:
42
49
  description=self.description,
43
50
  content=self.container,
44
51
  lock_message=self.lock_message,
45
- collapsable=app_options.get("collapsable", False),
52
+ collapsable=self.app_options.get("collapsable", False),
46
53
  )
47
54
  self.card.lock()
48
55
 
@@ -62,14 +69,14 @@ class ClassesSelector:
62
69
  def validate_step(self) -> bool:
63
70
  self.validator_text.hide()
64
71
 
65
- if len(self.classes_table.project_meta.obj_classes) == 0:
72
+ project_classes = self.classes_table.project_meta.obj_classes
73
+ if len(project_classes) == 0:
66
74
  self.validator_text.set(text="Project has no classes", status="error")
67
75
  self.validator_text.show()
68
76
  return False
69
77
 
70
78
  selected_classes = self.classes_table.get_selected_classes()
71
79
  table_data = self.classes_table._table_data
72
-
73
80
  empty_classes = [
74
81
  row[0]["data"]
75
82
  for row in table_data
@@ -78,23 +85,21 @@ class ClassesSelector:
78
85
 
79
86
  n_classes = len(selected_classes)
80
87
  if n_classes == 0:
81
- self.validator_text.set(text="Please select at least one class", status="error")
88
+ message = "Please select at least one class"
89
+ status = "error"
82
90
  else:
83
- warning_text = ""
91
+ class_text = "class" if n_classes == 1 else "classes"
92
+ message = f"Selected {n_classes} {class_text}"
84
93
  status = "success"
85
94
  if empty_classes:
86
95
  intersections = set(selected_classes).intersection(empty_classes)
87
96
  if intersections:
88
- warning_text = (
89
- f". Selected class has no annotations: {', '.join(intersections)}"
90
- if len(intersections) == 1
91
- else f". Selected classes have no annotations: {', '.join(intersections)}"
97
+ class_text = "class" if len(intersections) == 1 else "classes"
98
+ message += (
99
+ f". Selected {class_text} have no annotations: {', '.join(intersections)}"
92
100
  )
93
101
  status = "warning"
94
102
 
95
- class_text = "class" if n_classes == 1 else "classes"
96
- self.validator_text.set(
97
- text=f"Selected {n_classes} {class_text}{warning_text}", status=status
98
- )
103
+ self.validator_text.set(text=message, status=status)
99
104
  self.validator_text.show()
100
105
  return n_classes > 0
@@ -6,11 +6,15 @@ training workflows in Supervisely.
6
6
  """
7
7
 
8
8
  from os import environ
9
+ from typing import Union
9
10
 
10
11
  import supervisely.io.env as sly_env
12
+ import supervisely.io.json as sly_json
11
13
  from supervisely import Api, ProjectMeta
12
14
  from supervisely._utils import is_production
13
15
  from supervisely.app.widgets import Stepper, Widget
16
+ from supervisely.geometry.bitmap import Bitmap
17
+ from supervisely.geometry.graph import GraphNodes
14
18
  from supervisely.geometry.polygon import Polygon
15
19
  from supervisely.geometry.rectangle import Rectangle
16
20
  from supervisely.nn.task_type import TaskType
@@ -63,7 +67,7 @@ class TrainGUI:
63
67
  self.models = models
64
68
  self.hyperparameters = hyperparameters
65
69
  self.app_options = app_options
66
- self.collapsable = app_options.get("collapsable", False)
70
+ self.collapsable = self.app_options.get("collapsable", False)
67
71
  self.need_convert_shapes_for_bm = False
68
72
 
69
73
  self.team_id = sly_env.team_id(raise_not_found=False)
@@ -142,33 +146,73 @@ class TrainGUI:
142
146
  self.training_process.set_experiment_name(experiment_name)
143
147
 
144
148
  def need_convert_class_shapes() -> bool:
145
- if not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
146
- self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
147
- self.need_convert_shapes_for_bm = False
148
- else:
149
- task_type = self.model_selector.get_selected_task_type()
150
-
151
- def _need_convert(shape):
152
- if task_type == TaskType.OBJECT_DETECTION:
153
- return shape != Rectangle.geometry_name()
154
- elif task_type in [
155
- TaskType.INSTANCE_SEGMENTATION,
156
- TaskType.SEMANTIC_SEGMENTATION,
157
- ]:
158
- return shape == Polygon.geometry_name()
159
- return
160
-
161
- data = self.classes_selector.classes_table._table_data
162
- selected_classes = set(self.classes_selector.classes_table.get_selected_classes())
163
- empty = set(r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0)
164
- need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
165
-
166
- if need_convert.intersection(selected_classes - empty):
167
- self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
168
- self.need_convert_shapes_for_bm = True
169
- else:
149
+ if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
150
+ if not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked():
170
151
  self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
171
152
  self.need_convert_shapes_for_bm = False
153
+ else:
154
+ task_type = self.model_selector.get_selected_task_type()
155
+
156
+ def _need_convert(shape):
157
+ if task_type == TaskType.OBJECT_DETECTION:
158
+ return shape != Rectangle.geometry_name()
159
+ elif task_type in [
160
+ TaskType.INSTANCE_SEGMENTATION,
161
+ TaskType.SEMANTIC_SEGMENTATION,
162
+ ]:
163
+ return shape == Polygon.geometry_name()
164
+ return
165
+
166
+ data = self.classes_selector.classes_table._table_data
167
+ selected_classes = set(
168
+ self.classes_selector.classes_table.get_selected_classes()
169
+ )
170
+ empty = set(
171
+ r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
172
+ )
173
+ need_convert = set(r[0]["data"] for r in data if _need_convert(r[1]["data"]))
174
+
175
+ if need_convert.intersection(selected_classes - empty):
176
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.show()
177
+ self.need_convert_shapes_for_bm = True
178
+ else:
179
+ self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
180
+ self.need_convert_shapes_for_bm = False
181
+ else:
182
+ self.need_convert_shapes_for_bm = False
183
+
184
+ def validate_class_shape_for_model_task():
185
+ task_type = self.model_selector.get_selected_task_type()
186
+ classes = self.classes_selector.get_selected_classes()
187
+
188
+ required_geometries = {
189
+ TaskType.INSTANCE_SEGMENTATION: {Polygon, Bitmap},
190
+ TaskType.SEMANTIC_SEGMENTATION: {Polygon, Bitmap},
191
+ TaskType.POSE_ESTIMATION: {GraphNodes},
192
+ }
193
+ task_specific_texts = {
194
+ TaskType.INSTANCE_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
195
+ TaskType.SEMANTIC_SEGMENTATION: "Only polygon and bitmap shapes are supported for segmentation task",
196
+ TaskType.POSE_ESTIMATION: "Only keypoint (graph) shape is supported for pose estimation task",
197
+ }
198
+
199
+ if task_type not in required_geometries:
200
+ return
201
+
202
+ wrong_shape_classes = [
203
+ class_name
204
+ for class_name in classes
205
+ if self.project_meta.get_obj_class(class_name).geometry_type
206
+ not in required_geometries[task_type]
207
+ ]
208
+
209
+ if wrong_shape_classes:
210
+ specific_text = task_specific_texts[task_type]
211
+ message_text = f"Model task type is {task_type}. {specific_text}. Selected classes have wrong shapes for the model task: {', '.join(wrong_shape_classes)}"
212
+ self.model_selector.validator_text.set(
213
+ text=message_text,
214
+ status="warning",
215
+ )
172
216
 
173
217
  # ------------------------------------------------- #
174
218
 
@@ -201,7 +245,11 @@ class TrainGUI:
201
245
  callback=self.hyperparameters_selector_cb,
202
246
  validation_text=self.model_selector.validator_text,
203
247
  validation_func=self.model_selector.validate_step,
204
- on_select_click=[set_experiment_name, need_convert_class_shapes],
248
+ on_select_click=[
249
+ set_experiment_name,
250
+ need_convert_class_shapes,
251
+ validate_class_shape_for_model_task,
252
+ ],
205
253
  collapse_card=(self.model_selector.card, self.collapsable),
206
254
  )
207
255
 
@@ -299,17 +347,19 @@ class TrainGUI:
299
347
  # ------------------------------------------------- #
300
348
 
301
349
  # Other Buttons
302
- if app_options.get("show_logs_in_gui", False):
350
+ if self.app_options.get("show_logs_in_gui", False):
303
351
 
304
352
  @self.training_logs.logs_button.click
305
353
  def show_logs():
306
354
  self.training_logs.toggle_logs()
307
355
 
308
356
  # Other handlers
309
- @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
310
- def show_mb_speedtest(is_checked: bool):
311
- self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
312
- need_convert_class_shapes()
357
+ if self.hyperparameters_selector.run_model_benchmark_checkbox is not None:
358
+
359
+ @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
360
+ def show_mb_speedtest(is_checked: bool):
361
+ self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
362
+ need_convert_class_shapes()
313
363
 
314
364
  # ------------------------------------------------- #
315
365
 
@@ -361,7 +411,6 @@ class TrainGUI:
361
411
  raise ValueError("app_state must be a dictionary")
362
412
 
363
413
  required_keys = {
364
- "input": ["project_id"],
365
414
  "train_val_split": ["method"],
366
415
  "classes": list,
367
416
  "model": ["source"],
@@ -453,7 +502,7 @@ class TrainGUI:
453
502
  raise ValueError("percent must be an integer in range 1 to 99")
454
503
  return app_state
455
504
 
456
- def load_from_app_state(self, app_state: dict) -> None:
505
+ def load_from_app_state(self, app_state: Union[str, dict]) -> None:
457
506
  """
458
507
  Load the GUI state from app state dictionary.
459
508
 
@@ -463,7 +512,6 @@ class TrainGUI:
463
512
  app_state example:
464
513
 
465
514
  app_state = {
466
- "input": {"project_id": 43192},
467
515
  "train_val_split": {
468
516
  "method": "random",
469
517
  "split": "train",
@@ -489,10 +537,13 @@ class TrainGUI:
489
537
  }
490
538
  }
491
539
  """
540
+ if isinstance(app_state, str):
541
+ app_state = sly_json.load_json_file(app_state)
542
+
492
543
  app_state = self.validate_app_state(app_state)
493
544
 
494
545
  options = app_state.get("options", {})
495
- input_settings = app_state["input"]
546
+ input_settings = app_state.get("input")
496
547
  train_val_splits_settings = app_state["train_val_split"]
497
548
  classes_settings = app_state["classes"]
498
549
  model_settings = app_state["model"]
@@ -504,7 +555,7 @@ class TrainGUI:
504
555
  self._init_model(model_settings)
505
556
  self._init_hyperparameters(hyperparameters_settings, options)
506
557
 
507
- def _init_input(self, input_settings: dict, options: dict) -> None:
558
+ def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
508
559
  """
509
560
  Initialize the input selector with the given settings.
510
561
 
@@ -604,6 +655,8 @@ class TrainGUI:
604
655
  )
605
656
  self.hyperparameters_selector.set_speedtest_checkbox_value(
606
657
  model_benchmark_settings["speed_test"]
658
+ if model_benchmark_settings["enable"]
659
+ else False
607
660
  )
608
661
  export_weights_settings = options.get("export", None)
609
662
  if export_weights_settings is not None:
@@ -9,7 +9,6 @@ from supervisely.app.widgets import (
9
9
  Field,
10
10
  Text,
11
11
  )
12
- from supervisely.nn.utils import RuntimeType
13
12
 
14
13
 
15
14
  class HyperparametersSelector:
@@ -18,6 +17,22 @@ class HyperparametersSelector:
18
17
  lock_message = "Select model to unlock"
19
18
 
20
19
  def __init__(self, hyperparameters: dict, app_options: dict = {}):
20
+ # Init widgets
21
+ self.editor = None
22
+ self.run_model_benchmark_checkbox = None
23
+ self.run_speedtest_checkbox = None
24
+ self.model_benchmark_field = None
25
+ self.model_benchmark_learn_more = None
26
+ self.model_benchmark_auto_convert_warning = None
27
+ self.export_onnx_checkbox = None
28
+ self.export_tensorrt_checkbox = None
29
+ self.export_field = None
30
+ self.validator_text = None
31
+ self.button = None
32
+ self.container = None
33
+ self.card = None
34
+ # -------------------------------- #
35
+
21
36
  self.display_widgets = []
22
37
  self.app_options = app_options
23
38
 
@@ -28,7 +43,7 @@ class HyperparametersSelector:
28
43
  self.display_widgets.extend([self.editor])
29
44
 
30
45
  # Optional Model Benchmark
31
- if app_options.get("model_benchmark", True):
46
+ if self.app_options.get("model_benchmark", True):
32
47
  # Model Benchmark
33
48
  self.run_model_benchmark_checkbox = Checkbox(
34
49
  content="Run Model Benchmark evaluation", checked=True
@@ -37,7 +52,7 @@ class HyperparametersSelector:
37
52
 
38
53
  self.model_benchmark_field = Field(
39
54
  title="Model Evaluation Benchmark",
40
- description=f"Generate evalutaion dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
55
+ description="Generate evaluation dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
41
56
  content=Container([self.run_model_benchmark_checkbox, self.run_speedtest_checkbox]),
42
57
  )
43
58
  docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
@@ -60,8 +75,8 @@ class HyperparametersSelector:
60
75
  # -------------------------------- #
61
76
 
62
77
  # Optional Export Weights
63
- export_onnx_supported = app_options.get("export_onnx_supported", False)
64
- export_tensorrt_supported = app_options.get("export_tensorrt_supported", False)
78
+ export_onnx_supported = self.app_options.get("export_onnx_supported", False)
79
+ export_tensorrt_supported = self.app_options.get("export_tensorrt_supported", False)
65
80
 
66
81
  onnx_name = "ONNX"
67
82
  tensorrt_name = "TensorRT engine"
@@ -120,26 +135,28 @@ class HyperparametersSelector:
120
135
  return self.editor.get_value()
121
136
 
122
137
  def get_model_benchmark_checkbox_value(self) -> bool:
123
- if self.app_options.get("model_benchmark", True):
138
+ if self.run_model_benchmark_checkbox is not None:
124
139
  return self.run_model_benchmark_checkbox.is_checked()
125
140
  return False
126
141
 
127
142
  def set_model_benchmark_checkbox_value(self, is_checked: bool) -> bool:
128
- if is_checked:
129
- self.run_model_benchmark_checkbox.check()
130
- else:
131
- self.run_model_benchmark_checkbox.uncheck()
143
+ if self.run_model_benchmark_checkbox is not None:
144
+ if is_checked:
145
+ self.run_model_benchmark_checkbox.check()
146
+ else:
147
+ self.run_model_benchmark_checkbox.uncheck()
132
148
 
133
149
  def get_speedtest_checkbox_value(self) -> bool:
134
- if self.app_options.get("model_benchmark", True):
150
+ if self.run_speedtest_checkbox is not None:
135
151
  return self.run_speedtest_checkbox.is_checked()
136
152
  return False
137
153
 
138
154
  def set_speedtest_checkbox_value(self, is_checked: bool) -> bool:
139
- if is_checked:
140
- self.run_speedtest_checkbox.check()
141
- else:
142
- self.run_speedtest_checkbox.uncheck()
155
+ if self.run_speedtest_checkbox is not None:
156
+ if is_checked:
157
+ self.run_speedtest_checkbox.check()
158
+ else:
159
+ self.run_speedtest_checkbox.uncheck()
143
160
 
144
161
  def toggle_mb_speedtest(self, is_checked: bool) -> None:
145
162
  if is_checked:
@@ -4,7 +4,6 @@ from supervisely.app.widgets import (
4
4
  Card,
5
5
  Checkbox,
6
6
  Container,
7
- Field,
8
7
  ProjectThumbnail,
9
8
  Text,
10
9
  )
@@ -17,7 +16,19 @@ class InputSelector:
17
16
  lock_message = None
18
17
 
19
18
  def __init__(self, project_info: ProjectInfo, app_options: dict = {}):
19
+ # Init widgets
20
+ self.project_thumbnail = None
21
+ self.use_cache_text = None
22
+ self.use_cache_checkbox = None
23
+ self.validator_text = None
24
+ self.button = None
25
+ self.container = None
26
+ self.card = None
27
+ # -------------------------------- #
28
+
20
29
  self.display_widgets = []
30
+ self.app_options = app_options
31
+
21
32
  self.project_id = project_info.id
22
33
  self.project_info = project_info
23
34
 
@@ -49,7 +60,7 @@ class InputSelector:
49
60
  title=self.title,
50
61
  description=self.description,
51
62
  content=self.container,
52
- collapsable=app_options.get("collapsable", False),
63
+ collapsable=self.app_options.get("collapsable", False),
53
64
  )
54
65
 
55
66
  @property
@@ -14,7 +14,7 @@ from supervisely.app.widgets import (
14
14
  )
15
15
  from supervisely.nn.artifacts.utils import FrameworkMapper
16
16
  from supervisely.nn.experiments import get_experiment_infos
17
- from supervisely.nn.utils import ModelSource
17
+ from supervisely.nn.utils import ModelSource, _get_model_name
18
18
 
19
19
 
20
20
  class ModelSelector:
@@ -23,15 +23,26 @@ class ModelSelector:
23
23
  lock_message = "Select classes to unlock"
24
24
 
25
25
  def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
26
+ # Init widgets
27
+ self.pretrained_models_table = None
28
+ self.experiment_selector = None
29
+ self.model_source_tabs = None
30
+ self.validator_text = None
31
+ self.button = None
32
+ self.container = None
33
+ self.card = None
34
+ # -------------------------------- #
35
+
26
36
  self.display_widgets = []
37
+ self.app_options = app_options
38
+
27
39
  self.team_id = sly_env.team_id()
28
40
  self.models = models
29
41
 
30
42
  # GUI Components
31
43
  self.pretrained_models_table = PretrainedModelsSelector(self.models)
32
-
33
44
  experiment_infos = get_experiment_infos(api, self.team_id, framework)
34
- if app_options.get("legacy_checkpoints", False):
45
+ if self.app_options.get("legacy_checkpoints", False):
35
46
  try:
36
47
  framework_cls = FrameworkMapper.get_framework_cls(framework, self.team_id)
37
48
  legacy_experiment_infos = framework_cls.get_list_experiment_info()
@@ -61,7 +72,7 @@ class ModelSelector:
61
72
  description=self.description,
62
73
  content=self.container,
63
74
  lock_message=self.lock_message,
64
- collapsable=app_options.get("collapsable", False),
75
+ collapsable=self.app_options.get("collapsable", False),
65
76
  )
66
77
  self.card.lock()
67
78
 
@@ -82,8 +93,7 @@ class ModelSelector:
82
93
  def get_model_name(self) -> str:
83
94
  if self.get_model_source() == ModelSource.PRETRAINED:
84
95
  selected_row = self.pretrained_models_table.get_selected_row()
85
- model_meta = selected_row.get("meta", {})
86
- model_name = model_meta.get("model_name", None)
96
+ model_name = _get_model_name(selected_row)
87
97
  else:
88
98
  selected_row = self.experiment_selector.get_selected_experiment_info()
89
99
  model_name = selected_row.get("model_name", None)