supervisely 6.73.243__py3-none-any.whl → 6.73.245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (56) hide show
  1. supervisely/__init__.py +1 -1
  2. supervisely/_utils.py +18 -0
  3. supervisely/app/widgets/__init__.py +1 -0
  4. supervisely/app/widgets/card/card.py +3 -0
  5. supervisely/app/widgets/classes_table/classes_table.py +15 -1
  6. supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
  7. supervisely/app/widgets/custom_models_selector/template.html +1 -1
  8. supervisely/app/widgets/experiment_selector/__init__.py +0 -0
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
  10. supervisely/app/widgets/experiment_selector/style.css +27 -0
  11. supervisely/app/widgets/experiment_selector/template.html +82 -0
  12. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
  13. supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
  15. supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
  16. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  17. supervisely/nn/__init__.py +3 -1
  18. supervisely/nn/artifacts/artifacts.py +10 -0
  19. supervisely/nn/artifacts/detectron2.py +2 -0
  20. supervisely/nn/artifacts/hrda.py +3 -0
  21. supervisely/nn/artifacts/mmclassification.py +2 -0
  22. supervisely/nn/artifacts/mmdetection.py +6 -3
  23. supervisely/nn/artifacts/mmsegmentation.py +2 -0
  24. supervisely/nn/artifacts/ritm.py +3 -1
  25. supervisely/nn/artifacts/rtdetr.py +2 -0
  26. supervisely/nn/artifacts/unet.py +2 -0
  27. supervisely/nn/artifacts/yolov5.py +3 -0
  28. supervisely/nn/artifacts/yolov8.py +7 -1
  29. supervisely/nn/experiments.py +113 -0
  30. supervisely/nn/inference/gui/__init__.py +3 -1
  31. supervisely/nn/inference/gui/gui.py +31 -232
  32. supervisely/nn/inference/gui/serving_gui.py +223 -0
  33. supervisely/nn/inference/gui/serving_gui_template.py +240 -0
  34. supervisely/nn/inference/inference.py +225 -24
  35. supervisely/nn/training/__init__.py +0 -0
  36. supervisely/nn/training/gui/__init__.py +1 -0
  37. supervisely/nn/training/gui/classes_selector.py +100 -0
  38. supervisely/nn/training/gui/gui.py +539 -0
  39. supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
  40. supervisely/nn/training/gui/input_selector.py +70 -0
  41. supervisely/nn/training/gui/model_selector.py +95 -0
  42. supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
  43. supervisely/nn/training/gui/training_logs.py +93 -0
  44. supervisely/nn/training/gui/training_process.py +114 -0
  45. supervisely/nn/training/gui/utils.py +128 -0
  46. supervisely/nn/training/loggers/__init__.py +0 -0
  47. supervisely/nn/training/loggers/base_train_logger.py +58 -0
  48. supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
  49. supervisely/nn/training/train_app.py +2038 -0
  50. supervisely/nn/utils.py +5 -0
  51. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/METADATA +3 -1
  52. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/RECORD +56 -34
  53. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/LICENSE +0 -0
  54. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/WHEEL +0 -0
  55. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/entry_points.txt +0 -0
  56. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,539 @@
1
+ """
2
+ GUI module for training application.
3
+
4
+ This module provides the `TrainGUI` class that handles the graphical user interface (GUI) for managing
5
+ training workflows in Supervisely.
6
+ """
7
+
8
+ import supervisely.io.env as sly_env
9
+ from supervisely import Api
10
+ from supervisely._utils import is_production
11
+ from supervisely.app.widgets import Stepper, Widget
12
+ from supervisely.nn.training.gui.classes_selector import ClassesSelector
13
+ from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
14
+ from supervisely.nn.training.gui.input_selector import InputSelector
15
+ from supervisely.nn.training.gui.model_selector import ModelSelector
16
+ from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplitsSelector
17
+ from supervisely.nn.training.gui.training_logs import TrainingLogs
18
+ from supervisely.nn.training.gui.training_process import TrainingProcess
19
+ from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
20
+ from supervisely.nn.utils import ModelSource
21
+
22
+
23
+ class TrainGUI:
24
+ """
25
+ A class representing the GUI for training workflows.
26
+
27
+ This class sets up and manages GUI components such as project selection,
28
+ train/validation split selection, model selection, hyperparameters selection,
29
+ and the training process.
30
+
31
+ :param framework_name: Name of the ML framework being used.
32
+ :type framework_name: str
33
+ :param models: List of available models.
34
+ :type models: list
35
+ :param hyperparameters: Hyperparameters for training.
36
+ :type hyperparameters: dict
37
+ :param app_options: Application options for customization.
38
+ :type app_options: dict, optional
39
+ """
40
+
41
def __init__(
    self,
    framework_name: str,
    models: list,
    hyperparameters: dict,
    app_options: dict = None,
):
    """
    Build every step widget of the training GUI and wire the steps together.

    :param framework_name: Name of the ML framework being used.
    :type framework_name: str
    :param models: List of available models.
    :type models: list
    :param hyperparameters: Hyperparameters for training.
    :type hyperparameters: dict
    :param app_options: Application options for customization. ``None`` is
        treated as an empty dictionary.
    :type app_options: dict, optional
    """
    self._api = Api.from_env()

    # Outside production a fixed placeholder is used instead of a task id.
    self.task_id = sly_env.task_id() if is_production() else "debug-session"

    self.framework_name = framework_name
    self.models = models
    self.hyperparameters = hyperparameters
    # The parameter defaults to None; normalise it so that the `.get(...)`
    # lookups below do not raise AttributeError when it is omitted.
    self.app_options = {} if app_options is None else app_options
    self.collapsable = self.app_options.get("collapsable", False)

    self.team_id = sly_env.team_id()
    self.workspace_id = sly_env.workspace_id()
    self.project_id = sly_env.project_id()  # from app options?
    self.project_info = self._api.project.get_info_by_id(self.project_id)

    # 1. Project selection + Train/val split
    self.input_selector = InputSelector(self.project_info, self.app_options)
    # 2. Select train val splits
    self.train_val_splits_selector = TrainValSplitsSelector(
        self._api, self.project_id, self.app_options
    )
    # 3. Select classes
    self.classes_selector = ClassesSelector(self.project_id, [], self.app_options)
    # 4. Model selection
    self.model_selector = ModelSelector(
        self._api, self.framework_name, self.models, self.app_options
    )
    # 5. Training parameters (yaml), scheduler preview
    self.hyperparameters_selector = HyperparametersSelector(
        self.hyperparameters, self.app_options
    )
    # 6. Start Train
    self.training_process = TrainingProcess(self.app_options)

    # 7. Training logs
    self.training_logs = TrainingLogs(self.app_options)

    # Stepper layout: cards are listed in the order of the steps above.
    self.stepper = Stepper(
        widgets=[
            self.input_selector.card,
            self.train_val_splits_selector.card,
            self.classes_selector.card,
            self.model_selector.card,
            self.hyperparameters_selector.card,
            self.training_process.card,
            self.training_logs.card,
        ],
    )
    # ------------------------------------------------- #

    # Button utils
    def update_classes_table():
        # Placeholder hook: currently does nothing.
        pass

    def disable_hyperparams_editor():
        # Despite the name this is a toggle: it flips the editor's
        # read-only flag on every call (used for both select and reselect).
        editor = self.hyperparameters_selector.editor
        editor.readonly = not editor.readonly

    def set_experiment_name():
        model_name = self.model_selector.get_model_name()
        if model_name is None:
            experiment_name = "Enter experiment name"
        else:
            experiment_name = f"{self.task_id}_{self.project_info.name}_{model_name}"

        # Skip the update when the name is already the one shown.
        if experiment_name == self.training_process.get_experiment_name():
            return
        self.training_process.set_experiment_name(experiment_name)

    # ------------------------------------------------- #

    # Wrappers
    # The callbacks are created from the last step to the first because each
    # one is handed the callback of the step that follows it via `callback=`.
    self.training_process_cb = wrap_button_click(
        button=self.hyperparameters_selector.button,
        cards_to_unlock=[self.training_logs.card],
        widgets_to_disable=self.training_process.widgets_to_disable,
        callback=None,
        validation_text=self.training_process.validator_text,
        validation_func=self.training_process.validate_step,
    )

    self.hyperparameters_selector_cb = wrap_button_click(
        button=self.hyperparameters_selector.button,
        cards_to_unlock=[self.training_process.card],
        widgets_to_disable=self.hyperparameters_selector.widgets_to_disable,
        callback=self.training_process_cb,
        validation_text=self.hyperparameters_selector.validator_text,
        validation_func=self.hyperparameters_selector.validate_step,
        on_select_click=[disable_hyperparams_editor],
        on_reselect_click=[disable_hyperparams_editor],
        collapse_card=(self.hyperparameters_selector.card, self.collapsable),
    )

    self.model_selector_cb = wrap_button_click(
        button=self.model_selector.button,
        cards_to_unlock=[self.hyperparameters_selector.card],
        widgets_to_disable=self.model_selector.widgets_to_disable,
        callback=self.hyperparameters_selector_cb,
        validation_text=self.model_selector.validator_text,
        validation_func=self.model_selector.validate_step,
        on_select_click=[set_experiment_name],
        collapse_card=(self.model_selector.card, self.collapsable),
    )

    self.classes_selector_cb = wrap_button_click(
        button=self.classes_selector.button,
        cards_to_unlock=[self.model_selector.card],
        widgets_to_disable=self.classes_selector.widgets_to_disable,
        callback=self.model_selector_cb,
        validation_text=self.classes_selector.validator_text,
        validation_func=self.classes_selector.validate_step,
        collapse_card=(self.classes_selector.card, self.collapsable),
    )

    self.train_val_splits_selector_cb = wrap_button_click(
        button=self.train_val_splits_selector.button,
        cards_to_unlock=[self.classes_selector.card],
        widgets_to_disable=self.train_val_splits_selector.widgets_to_disable,
        callback=self.classes_selector_cb,
        validation_text=self.train_val_splits_selector.validator_text,
        validation_func=self.train_val_splits_selector.validate_step,
        collapse_card=(self.train_val_splits_selector.card, self.collapsable),
    )

    self.input_selector_cb = wrap_button_click(
        button=self.input_selector.button,
        cards_to_unlock=[self.train_val_splits_selector.card],
        widgets_to_disable=self.input_selector.widgets_to_disable,
        callback=self.train_val_splits_selector_cb,
        validation_text=self.input_selector.validator_text,
        validation_func=self.input_selector.validate_step,
        on_select_click=[update_classes_table],
        collapse_card=(self.input_selector.card, self.collapsable),
    )
    # ------------------------------------------------- #

    # Main Buttons
    # Click handlers for `self.training_process.start_button` and
    # `self.training_process.stop_button` are intentionally not registered
    # here: they are defined outside, by the application that uses this GUI.

    # ------------------------------------------------- #

    # Select Buttons
    @self.hyperparameters_selector.button.click
    def select_hyperparameters():
        self.hyperparameters_selector_cb()
        set_stepper_step(
            self.stepper,
            self.hyperparameters_selector.button,
            next_pos=6,
        )

    @self.model_selector.button.click
    def select_model():
        self.model_selector_cb()
        set_stepper_step(
            self.stepper,
            self.model_selector.button,
            next_pos=5,
        )

    @self.classes_selector.button.click
    def select_classes():
        self.classes_selector_cb()
        set_stepper_step(
            self.stepper,
            self.classes_selector.button,
            next_pos=4,
        )

    @self.train_val_splits_selector.button.click
    def select_train_val_splits():
        self.train_val_splits_selector_cb()
        set_stepper_step(
            self.stepper,
            self.train_val_splits_selector.button,
            next_pos=3,
        )

    @self.input_selector.button.click
    def select_input():
        self.input_selector_cb()
        set_stepper_step(
            self.stepper,
            self.input_selector.button,
            next_pos=2,
        )

    # ------------------------------------------------- #

    # Other Buttons
    # The logs button handler is registered only when the option is enabled.
    if self.app_options.get("show_logs_in_gui", False):

        @self.training_logs.logs_button.click
        def show_logs():
            self.training_logs.toggle_logs()

    # Other handlers
    @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
    def show_mb_speedtest(is_checked: bool):
        self.hyperparameters_selector.toggle_mb_speedtest(is_checked)

    # ------------------------------------------------- #

    self.layout: Widget = self.stepper
266
+
267
def enable_select_buttons(self):
    """
    Enable the "select" button of every step selector, in step order
    (input, train/val splits, classes, model, hyperparameters).
    """
    step_selectors = (
        self.input_selector,
        self.train_val_splits_selector,
        self.classes_selector,
        self.model_selector,
        self.hyperparameters_selector,
    )
    for step in step_selectors:
        step.button.enable()
276
+
277
def disable_select_buttons(self):
    """
    Disable the "select" button of every step selector, in step order
    (input, train/val splits, classes, model, hyperparameters).
    """
    step_selectors = (
        self.input_selector,
        self.train_val_splits_selector,
        self.classes_selector,
        self.model_selector,
        self.hyperparameters_selector,
    )
    for step in step_selectors:
        step.button.disable()
286
+
287
+ # Set GUI from config
288
def validate_app_state(self, app_state: dict) -> dict:
    """
    Validate the app state dictionary and fill in default options.

    The dictionary is modified in place: a missing ``options`` section and
    missing ``model_benchmark`` / ``cache_project`` entries inside it are
    populated with their defaults (all ``True``).

    :param app_state: The app state dictionary.
    :type app_state: dict
    :return: The same dictionary, with defaults filled in.
    :rtype: dict
    :raises ValueError: If a value has the wrong type or a split setting is invalid.
    :raises KeyError: If a required key is missing.
    """
    if not isinstance(app_state, dict):
        raise ValueError("app_state must be a dictionary")

    # A list value names the sub-keys a dict section must contain;
    # a type (or tuple of types) is the type the value itself must have.
    required_keys = {
        "input": ["project_id"],
        "train_val_split": ["method"],
        "classes": list,
        "model": ["source"],
        "hyperparameters": (dict, str),  # Allowing dict or str for hyperparameters
    }

    for key, subkeys_or_type in required_keys.items():
        if key not in app_state:
            raise KeyError(f"Missing required key in app_state: {key}")

        section = app_state[key]
        if isinstance(subkeys_or_type, list):
            # Without this check `subkey in section` would be a substring
            # test on a str, or a TypeError on a non-container value.
            if not isinstance(section, dict):
                raise ValueError(f"app_state['{key}'] must be a dictionary")
            for subkey in subkeys_or_type:
                if subkey not in section:
                    raise KeyError(f"Missing required key in app_state['{key}']: {subkey}")
        elif not isinstance(section, subkeys_or_type):
            valid_types = (
                " or ".join([t.__name__ for t in subkeys_or_type])
                if isinstance(subkeys_or_type, tuple)
                else subkeys_or_type.__name__
            )
            raise ValueError(f"app_state['{key}'] must be of type {valid_types}")

    model = app_state["model"]
    if model["source"] == "Pretrained models":
        if "model_name" not in model:
            raise KeyError("Missing required key in app_state['model']: model_name")
    elif model["source"] == "Custom models":
        for key in ("task_id", "checkpoint"):
            if key not in model:
                raise KeyError(f"Missing required key in app_state['model']: {key}")

    # Options: every entry is optional and defaults to True.
    options = app_state.setdefault("options", {})
    if not isinstance(options, dict):
        raise ValueError("app_state['options'] must be a dictionary")

    model_benchmark = options.setdefault("model_benchmark", {})
    if not isinstance(model_benchmark, dict):
        raise ValueError("app_state['options']['model_benchmark'] must be a dictionary")
    model_benchmark.setdefault("enable", True)
    model_benchmark.setdefault("speed_test", True)

    # Default a missing `cache_project` the same way as `model_benchmark`,
    # instead of rejecting a partially specified `options` section.
    options.setdefault("cache_project", True)
    if not isinstance(options["cache_project"], bool):
        raise ValueError("app_state['options']['cache_project'] must be a boolean")

    # Check train val splits
    train_val_splits_settings = app_state["train_val_split"]
    method = train_val_splits_settings.get("method")
    if method == "datasets":
        # Set of ids gives O(1) membership checks below.
        dataset_ids = {
            dataset.id for _parents, dataset in self._api.dataset.tree(self.project_id)
        }

        train_datasets = train_val_splits_settings.get("train_datasets", [])
        val_datasets = train_val_splits_settings.get("val_datasets", [])

        missing_datasets_ids = [
            ds_id for ds_id in train_datasets + val_datasets if ds_id not in dataset_ids
        ]
        if missing_datasets_ids:
            missing_datasets_text = ", ".join(str(ds_id) for ds_id in missing_datasets_ids)
            raise ValueError(
                f"Datasets with ids: {missing_datasets_text} not found in the project"
            )
    elif method == "tags":
        train_tag = train_val_splits_settings.get("train_tag")
        val_tag = train_val_splits_settings.get("val_tag")
        if not train_tag or not val_tag:
            raise ValueError("train_tag and val_tag must be specified in tags split method")
        # `_init_train_val_splits` indexes this key unconditionally, so report
        # its absence here, before any GUI state has been changed.
        if "untagged_action" not in train_val_splits_settings:
            raise KeyError(
                "Missing required key in app_state['train_val_split']: untagged_action"
            )
    elif method == "random":
        split = train_val_splits_settings.get("split")
        percent = train_val_splits_settings.get("percent")
        if split not in ["train", "val"]:
            raise ValueError("split must be 'train' or 'val'")
        # bool is a subclass of int, so exclude it explicitly.
        if (
            isinstance(percent, bool)
            or not isinstance(percent, int)
            or not 0 < percent < 100
        ):
            raise ValueError("percent must be an integer in range 1 to 99")
    return app_state
390
+
391
def load_from_app_state(self, app_state: dict) -> None:
    """
    Apply an app state dictionary to the GUI.

    The state is validated with `validate_app_state` first; every section is
    then read before any step is initialised. The steps are initialised in
    this order: input, classes, train/val splits, model, hyperparameters.

    :param app_state: The state dictionary.
    :type app_state: dict

    app_state example (note the key is "train_val_split", singular):

        app_state = {
            "input": {"project_id": 43192},
            "train_val_split": {
                "method": "random",
                "split": "train",
                "percent": 90
            },
            "classes": ["apple"],
            "model": {
                "source": "Pretrained models",
                "model_name": "rtdetr_r50vd_coco_objects365"
            },
            "hyperparameters": hyperparameters, # yaml string
            "options": {
                "model_benchmark": {
                    "enable": True,
                    "speed_test": True
                },
                "cache_project": True
            }
        }
    """
    state = self.validate_app_state(app_state)

    # Read every section up front so a missing key fails before any widget
    # has been touched.
    opts, input_cfg, splits_cfg, classes_cfg, model_cfg, hyperparams_cfg = (
        state["options"],
        state["input"],
        state["train_val_split"],
        state["classes"],
        state["model"],
        state["hyperparameters"],
    )

    self._init_input(input_cfg, opts)
    self._init_classes(classes_cfg)
    self._init_train_val_splits(splits_cfg)
    self._init_model(model_cfg)
    self._init_hyperparameters(hyperparams_cfg, opts)
436
+
437
def _init_input(self, input_settings: dict, options: dict) -> None:
    """
    Configure the input step and run its select callback.

    :param input_settings: The input settings (not read by this method).
    :type input_settings: dict
    :param options: The application options; ``options["cache_project"]``
        is passed to the input selector.
    :type options: dict
    """
    use_cache = options["cache_project"]
    self.input_selector.set_cache(use_cache)
    self.input_selector_cb()
449
+ # ----------------------------------------- #
450
+
451
def _init_train_val_splits(self, train_val_splits_settings: dict) -> None:
    """
    Configure the train/val splits step and run its select callback.

    Depending on ``train_val_splits_settings["method"]`` ("random", "tags"
    or "datasets") the matching setter of the splits widget is called with
    the values read from the settings. Any other method sets nothing; the
    select callback is run in every case.

    :param train_val_splits_settings: The train/val splits settings.
    :type train_val_splits_settings: dict
    """
    cfg = train_val_splits_settings
    method = cfg["method"]

    if method == "random":
        split, percent = cfg["split"], cfg["percent"]
        self.train_val_splits_selector.train_val_splits.set_random_splits(split, percent)
    elif method == "tags":
        tag_args = (cfg["train_tag"], cfg["val_tag"], cfg["untagged_action"])
        self.train_val_splits_selector.train_val_splits.set_tags_splits(*tag_args)
    elif method == "datasets":
        train_ids, val_ids = cfg["train_datasets"], cfg["val_datasets"]
        self.train_val_splits_selector.train_val_splits.set_datasets_splits(train_ids, val_ids)

    self.train_val_splits_selector_cb()
477
+
478
def _init_classes(self, classes_settings: list) -> None:
    """
    Configure the classes step and run its select callback.

    :param classes_settings: The classes settings, passed to the classes
        selector's ``set_classes``.
    :type classes_settings: list
    """
    selector = self.classes_selector
    selector.set_classes(classes_settings)
    self.classes_selector_cb()
488
+ # ----------------------------------------- #
489
+
490
def _init_model(self, model_settings: dict) -> None:
    """
    Configure the model step and run its select callback.

    For a pretrained source the model is selected by ``model_name``; for a
    custom source the experiment is selected by ``task_id`` and then the
    requested ``checkpoint`` is chosen on the selected row.

    :param model_settings: The model settings.
    :type model_settings: dict
    :raises ValueError: If the requested checkpoint is not among the
        selected row's checkpoint names.
    """
    if model_settings["source"] == ModelSource.PRETRAINED:
        # Pretrained
        self.model_selector.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
        model_name = model_settings["model_name"]
        self.model_selector.pretrained_models_table.set_by_model_name(model_name)
    elif model_settings["source"] == ModelSource.CUSTOM:
        # Custom
        self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
        self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
        selected_row = self.model_selector.experiment_selector.get_selected_row()

        checkpoint_known = model_settings["checkpoint"] in selected_row.checkpoints_names
        if not checkpoint_known:
            raise ValueError(
                f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
            )
        selected_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])

    # NOTE(review): the select callback is assumed to run for every source
    # value; the original nesting is not recoverable from this view — confirm.
    self.model_selector_cb()
517
+ # ----------------------------------------- #
518
+
519
def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict) -> None:
    """
    Configure the hyperparameters step and run its select callback.

    :param hyperparameters_settings: The hyperparameters settings, passed to
        the selector's ``set_hyperparameters``.
    :type hyperparameters_settings: dict
    :param options: The application options; ``options["model_benchmark"]``
        supplies the ``enable`` and ``speed_test`` checkbox values.
    :type options: dict
    """
    self.hyperparameters_selector.set_hyperparameters(hyperparameters_settings)

    benchmark_cfg = options["model_benchmark"]
    run_benchmark = benchmark_cfg["enable"]
    self.hyperparameters_selector.set_model_benchmark_checkbox_value(run_benchmark)
    run_speedtest = benchmark_cfg["speed_test"]
    self.hyperparameters_selector.set_speedtest_checkbox_value(run_speedtest)

    self.hyperparameters_selector_cb()
538
+
539
+ # ----------------------------------------- #
@@ -0,0 +1,117 @@
1
+ from typing import Union
2
+
3
+ from supervisely.app.widgets import (
4
+ Button,
5
+ Card,
6
+ Checkbox,
7
+ Container,
8
+ Editor,
9
+ Field,
10
+ Text,
11
+ )
12
+
13
+
14
+ class HyperparametersSelector:
15
+ title = "Hyperparameters"
16
+ description = "Set hyperparameters for training"
17
+ lock_message = "Select model to unlock"
18
+
19
+ def __init__(self, hyperparameters: dict, app_options: dict = {}):
20
+ self.app_options = app_options
21
+ self.editor = Editor(
22
+ hyperparameters, height_lines=50, language_mode="yaml", auto_format=True
23
+ )
24
+
25
+ # Model Benchmark
26
+ self.run_model_benchmark_checkbox = Checkbox(
27
+ content="Run Model Benchmark evaluation", checked=True
28
+ )
29
+ self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=True)
30
+
31
+ self.model_benchmark_field = Field(
32
+ Container(
33
+ widgets=[
34
+ self.run_model_benchmark_checkbox,
35
+ self.run_speedtest_checkbox,
36
+ ]
37
+ ),
38
+ title="Model Evaluation Benchmark",
39
+ description=f"Generate evalutaion dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
40
+ )
41
+ docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
42
+ self.model_benchmark_learn_more = Text(
43
+ f"Learn more about Model Benchmark in the {docs_link}.", status="info"
44
+ )
45
+
46
+ if app_options.get("model_benchmark", True):
47
+ self.model_benchmark_field.show()
48
+ self.model_benchmark_learn_more.show()
49
+ else:
50
+ self.model_benchmark_field.hide()
51
+ self.model_benchmark_learn_more.hide()
52
+
53
+ self.validator_text = Text("")
54
+ self.validator_text.hide()
55
+ self.button = Button("Select")
56
+ container = Container(
57
+ [
58
+ self.editor,
59
+ self.model_benchmark_field,
60
+ self.model_benchmark_learn_more,
61
+ self.validator_text,
62
+ self.button,
63
+ ]
64
+ )
65
+ self.card = Card(
66
+ title=self.title,
67
+ description=self.description,
68
+ content=container,
69
+ lock_message=self.lock_message,
70
+ collapsable=app_options.get("collapsable", False),
71
+ )
72
+ self.card.lock()
73
+
74
+ @property
75
+ def widgets_to_disable(self) -> list:
76
+ return [
77
+ self.editor,
78
+ self.run_model_benchmark_checkbox,
79
+ self.run_speedtest_checkbox,
80
+ ]
81
+
82
+ def set_hyperparameters(self, hyperparameters: Union[str, dict]) -> None:
83
+ self.editor.set_text(hyperparameters)
84
+
85
+ def get_hyperparameters(self) -> dict:
86
+ return self.editor.get_value()
87
+
88
+ def get_model_benchmark_checkbox_value(self) -> bool:
89
+ if self.app_options.get("model_benchmark", True):
90
+ return self.run_model_benchmark_checkbox.is_checked()
91
+ return False
92
+
93
+ def set_model_benchmark_checkbox_value(self, is_checked: bool) -> bool:
94
+ if is_checked:
95
+ self.run_model_benchmark_checkbox.check()
96
+ else:
97
+ self.run_model_benchmark_checkbox.uncheck()
98
+
99
+ def get_speedtest_checkbox_value(self) -> bool:
100
+ if self.app_options.get("model_benchmark", True):
101
+ return self.run_speedtest_checkbox.is_checked()
102
+ return False
103
+
104
+ def set_speedtest_checkbox_value(self, is_checked: bool) -> bool:
105
+ if is_checked:
106
+ self.run_speedtest_checkbox.check()
107
+ else:
108
+ self.run_speedtest_checkbox.uncheck()
109
+
110
+ def toggle_mb_speedtest(self, is_checked: bool) -> None:
111
+ if is_checked:
112
+ self.run_speedtest_checkbox.show()
113
+ else:
114
+ self.run_speedtest_checkbox.hide()
115
+
116
+ def validate_step(self) -> bool:
117
+ return True