supervisely 6.73.243__py3-none-any.whl → 6.73.245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (56) hide show
  1. supervisely/__init__.py +1 -1
  2. supervisely/_utils.py +18 -0
  3. supervisely/app/widgets/__init__.py +1 -0
  4. supervisely/app/widgets/card/card.py +3 -0
  5. supervisely/app/widgets/classes_table/classes_table.py +15 -1
  6. supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
  7. supervisely/app/widgets/custom_models_selector/template.html +1 -1
  8. supervisely/app/widgets/experiment_selector/__init__.py +0 -0
  9. supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
  10. supervisely/app/widgets/experiment_selector/style.css +27 -0
  11. supervisely/app/widgets/experiment_selector/template.html +82 -0
  12. supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
  13. supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
  14. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
  15. supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
  16. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  17. supervisely/nn/__init__.py +3 -1
  18. supervisely/nn/artifacts/artifacts.py +10 -0
  19. supervisely/nn/artifacts/detectron2.py +2 -0
  20. supervisely/nn/artifacts/hrda.py +3 -0
  21. supervisely/nn/artifacts/mmclassification.py +2 -0
  22. supervisely/nn/artifacts/mmdetection.py +6 -3
  23. supervisely/nn/artifacts/mmsegmentation.py +2 -0
  24. supervisely/nn/artifacts/ritm.py +3 -1
  25. supervisely/nn/artifacts/rtdetr.py +2 -0
  26. supervisely/nn/artifacts/unet.py +2 -0
  27. supervisely/nn/artifacts/yolov5.py +3 -0
  28. supervisely/nn/artifacts/yolov8.py +7 -1
  29. supervisely/nn/experiments.py +113 -0
  30. supervisely/nn/inference/gui/__init__.py +3 -1
  31. supervisely/nn/inference/gui/gui.py +31 -232
  32. supervisely/nn/inference/gui/serving_gui.py +223 -0
  33. supervisely/nn/inference/gui/serving_gui_template.py +240 -0
  34. supervisely/nn/inference/inference.py +225 -24
  35. supervisely/nn/training/__init__.py +0 -0
  36. supervisely/nn/training/gui/__init__.py +1 -0
  37. supervisely/nn/training/gui/classes_selector.py +100 -0
  38. supervisely/nn/training/gui/gui.py +539 -0
  39. supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
  40. supervisely/nn/training/gui/input_selector.py +70 -0
  41. supervisely/nn/training/gui/model_selector.py +95 -0
  42. supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
  43. supervisely/nn/training/gui/training_logs.py +93 -0
  44. supervisely/nn/training/gui/training_process.py +114 -0
  45. supervisely/nn/training/gui/utils.py +128 -0
  46. supervisely/nn/training/loggers/__init__.py +0 -0
  47. supervisely/nn/training/loggers/base_train_logger.py +58 -0
  48. supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
  49. supervisely/nn/training/train_app.py +2038 -0
  50. supervisely/nn/utils.py +5 -0
  51. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/METADATA +3 -1
  52. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/RECORD +56 -34
  53. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/LICENSE +0 -0
  54. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/WHEEL +0 -0
  55. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/entry_points.txt +0 -0
  56. {supervisely-6.73.243.dist-info → supervisely-6.73.245.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,70 @@
1
+ from supervisely.api.project_api import ProjectInfo
2
+ from supervisely.app.widgets import (
3
+ Button,
4
+ Card,
5
+ Checkbox,
6
+ Container,
7
+ Field,
8
+ ProjectThumbnail,
9
+ Text,
10
+ )
11
+ from supervisely.project.download import is_cached
12
+
13
+
14
class InputSelector:
    """GUI step that shows the input project and a "use cache" checkbox.

    The card contains the project thumbnail, the cache checkbox, a validator
    text (hidden initially) and a "Select" button.
    """

    title = "Input project"
    description = "Selected project from which images and annotations will be downloaded"
    lock_message = None

    def __init__(self, project_info: ProjectInfo, app_options: dict = None):
        """
        :param project_info: Info of the project to display; its ``id`` is stored as ``project_id``.
        :param app_options: Optional GUI options. Only the ``"collapsable"`` key is read here.
        """
        # Build a fresh dict per call instead of sharing one mutable default
        # object between every instance created without options.
        if app_options is None:
            app_options = {}

        self.project_id = project_info.id
        self.project_info = project_info

        self.project_thumbnail = ProjectThumbnail(self.project_info)

        # The checkbox label depends on whether the project is already cached.
        if is_cached(self.project_id):
            _text = "Use cached data stored on the agent to optimize project download"
        else:
            _text = "Cache data on the agent to optimize project download for future trainings"
        self.use_cache_text = Text(_text)
        self.use_cache_checkbox = Checkbox(self.use_cache_text, checked=True)

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        container = Container(
            widgets=[
                self.project_thumbnail,
                self.use_cache_checkbox,
                self.validator_text,
                self.button,
            ]
        )

        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            collapsable=app_options.get("collapsable", False),
        )

    @property
    def widgets_to_disable(self) -> list:
        """Widgets exposed for disabling by the surrounding GUI (only the cache checkbox)."""
        return [
            self.use_cache_checkbox,
        ]

    def get_project_id(self) -> int:
        """Return the id of the input project."""
        return self.project_id

    def set_cache(self, value: bool) -> None:
        """Check the cache checkbox when ``value`` is truthy, uncheck it otherwise."""
        if value:
            self.use_cache_checkbox.check()
        else:
            self.use_cache_checkbox.uncheck()

    def get_cache_value(self) -> bool:
        """Return whether the cache checkbox is currently checked."""
        return self.use_cache_checkbox.is_checked()

    def validate_step(self) -> bool:
        """This step has nothing to validate; always returns ``True``."""
        return True
@@ -0,0 +1,95 @@
1
+ from typing import Literal
2
+
3
+ import supervisely.io.env as sly_env
4
+ from supervisely.api.api import Api
5
+ from supervisely.app.widgets import (
6
+ Button,
7
+ Card,
8
+ Container,
9
+ ExperimentSelector,
10
+ PretrainedModelsSelector,
11
+ RadioTabs,
12
+ Text,
13
+ )
14
+ from supervisely.nn.experiments import get_experiment_infos
15
+ from supervisely.nn.utils import ModelSource
16
+
17
+
18
class ModelSelector:
    """GUI step for choosing the model to train.

    Offers two tabs: a table of pretrained models and a selector of experiments
    (custom models) fetched for the current team and framework. The card is
    locked on creation.
    """

    title = "Select Model"
    description = "Select a model for training"
    lock_message = "Select classes to unlock"

    def __init__(self, api: Api, framework: str, models: list, app_options: dict = None):
        """
        :param api: API client passed to ``get_experiment_infos``.
        :param framework: Framework name passed to ``get_experiment_infos``.
        :param models: Pretrained models handed to ``PretrainedModelsSelector``.
        :param app_options: Optional GUI options. Only the ``"collapsable"`` key is read here.
        """
        # Avoid sharing one mutable default dict between instances.
        if app_options is None:
            app_options = {}

        self.team_id = sly_env.team_id()  # get from project id
        self.models = models

        # Pretrained models
        self.pretrained_models_table = PretrainedModelsSelector(self.models)

        # Custom models (experiments)
        experiment_infos = get_experiment_infos(api, self.team_id, framework)
        self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)

        # Model source tabs
        self.model_source_tabs = RadioTabs(
            titles=[ModelSource.PRETRAINED, ModelSource.CUSTOM],
            descriptions=[
                "Publicly available models",
                "Models trained by you in Supervisely",
            ],
            contents=[self.pretrained_models_table, self.experiment_selector],
        )

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        container = Container([self.model_source_tabs, self.validator_text, self.button])
        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
            collapsable=app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets exposed for disabling by the surrounding GUI."""
        return [
            self.model_source_tabs,
            self.pretrained_models_table,
            self.experiment_selector,
        ]

    def get_model_source(self) -> str:
        """Return the title of the currently active source tab."""
        return self.model_source_tabs.get_active_tab()

    def set_model_source(self, source: Literal["Pretrained models", "Custom models"]) -> None:
        """Activate the source tab with the given title."""
        self.model_source_tabs.set_active_tab(source)

    def get_model_name(self) -> str:
        """Return the name of the selected model, or ``None`` if it is unavailable.

        For pretrained models the name is read from ``row["meta"]["model_name"]``;
        for custom models from ``experiment_info["model_name"]``.
        """
        model_info = self.get_model_info()
        # validate_step() treats None / {} as "nothing selected", so guard here
        # too instead of failing with AttributeError on None.get(...).
        if not model_info:
            return None
        if self.get_model_source() == ModelSource.PRETRAINED:
            model_meta = model_info.get("meta", {})
            return model_meta.get("model_name", None)
        return model_info.get("model_name", None)

    def get_model_info(self) -> dict:
        """Return the selected row (pretrained) or experiment info (custom)."""
        if self.get_model_source() == ModelSource.PRETRAINED:
            return self.pretrained_models_table.get_selected_row()
        return self.experiment_selector.get_selected_experiment_info()

    def validate_step(self) -> bool:
        """Show a success/error message and return whether a model is selected."""
        self.validator_text.hide()
        model_info = self.get_model_info()
        if model_info is None or model_info == {}:
            self.validator_text.set(text="Model is not selected", status="error")
            self.validator_text.show()
            return False
        self.validator_text.set(text="Model is selected", status="success")
        self.validator_text.show()
        return True
@@ -0,0 +1,200 @@
1
+ from typing import List
2
+
3
+ from supervisely import Api, Project
4
+ from supervisely.app.widgets import Button, Card, Container, Text, TrainValSplits
5
+
6
+
7
class TrainValSplitsSelector:
    """GUI step for choosing how the project is split into train and val sets.

    Wraps a ``TrainValSplits`` widget, pre-selects datasets named
    train/training and val/validation when both kinds are found, and shows
    validation messages for the chosen split method. The card is locked on
    creation.
    """

    title = "Train / Val Splits"
    description = "Select train and val splits for training"
    lock_message = "Select input options to unlock"

    # Shared fragments of the validation messages.
    _WARNING_TEXT = "Using the same data for training and validation leads to overfitting, poor generalization and biased model selection."
    _ENSURE_TEXT = "Ensure this is intentional."

    def __init__(self, api: Api, project_id: int, app_options: dict = None):
        """
        :param api: API client used to list datasets and fetch project stats.
        :param project_id: Id of the project being split.
        :param app_options: Optional GUI options. Only the ``"collapsable"`` key is read here.
        """
        # Avoid sharing one mutable default dict between instances.
        if app_options is None:
            app_options = {}

        self.api = api
        self.project_id = project_id
        self.train_val_splits = TrainValSplits(project_id)

        # Collect non-empty datasets whose names look like train / val splits
        # (the dataset tree is walked, so nested datasets are included).
        train_val_dataset_ids = {"train": [], "val": []}
        for _, dataset in api.dataset.tree(project_id):
            ds_name = dataset.name.lower()
            if ds_name in ("train", "training"):
                split = "train"
            elif ds_name in ("val", "validation"):
                split = "val"
            else:
                continue
            if dataset.items_count > 0:
                train_val_dataset_ids[split].append(dataset.id)

        # Pre-select the datasets only when both splits were detected.
        if train_val_dataset_ids["train"] and train_val_dataset_ids["val"]:
            self.train_val_splits.set_datasets_splits(
                train_val_dataset_ids["train"], train_val_dataset_ids["val"]
            )
            self.validator_text = Text("Train and val datasets are detected", status="info")
            self.validator_text.show()
        else:
            self.validator_text = Text("")
            self.validator_text.hide()

        self.button = Button("Select")
        container = Container(
            [
                self.train_val_splits,
                self.validator_text,
                self.button,
            ]
        )
        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
            collapsable=app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets exposed for disabling by the surrounding GUI."""
        return [self.train_val_splits]

    def validate_step(self) -> bool:
        """Update the validator text for the current split method and show it.

        Always returns ``True``, including when an error message is set.
        """
        split_method = self.train_val_splits.get_split_method()
        if split_method == "Random":
            self._validate_random_split()
        elif split_method == "Based on tags":
            self._validate_tags_split()
        elif split_method == "Based on datasets":
            self._validate_datasets_split()
        self.validator_text.show()
        # @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
        return True

    def _validate_random_split(self) -> None:
        """Set a warning/success message based on the random split percentages."""
        train_ratio = self.train_val_splits.get_train_split_percent()
        val_ratio = self.train_val_splits.get_val_split_percent()
        ensure_text_random_split = (
            "Consider reallocating to ensure efficient learning and validation."
        )

        if train_ratio == val_ratio:
            self.validator_text.set(
                text="Train and validation splits are equal (50:50). This is inefficient for standard training. "
                f"{self._ENSURE_TEXT}",
                status="warning",
            )
        elif train_ratio > 90:
            self.validator_text.set(
                text="Training split exceeds 90%. This may leave insufficient data for validation. Ensure you have enough data for validation.",
                status="warning",
            )
        elif val_ratio > train_ratio:
            self.validator_text.set(
                text=f"Validation split is larger than the training split. {ensure_text_random_split}",
                status="warning",
            )
        elif train_ratio < 70:
            self.validator_text.set(
                text="Training split is below 70%. This may limit the model's learning capability. "
                f"{ensure_text_random_split}",
                status="warning",
            )
        else:
            self.validator_text.set(
                text="Train and validation splits are selected.",
                status="success",
            )

    def _validate_tags_split(self) -> None:
        """Set an error/warning/success message for the tag-based split."""
        train_tag = self.train_val_splits.get_train_tag()
        val_tag = self.train_val_splits.get_val_tag()

        # Sum image-level and object-level usages of every tag in the project.
        stats = self.api.project.get_stats(self.project_id)
        tags_count = {}
        for section in ("imageTags", "objectTags"):
            for item in stats[section]["items"]:
                tag_name = item["tagMeta"]["name"]
                tags_count[tag_name] = tags_count.get(tag_name, 0) + item["total"]

        # A tag absent from the stats is treated as unused (count 0) rather
        # than raising KeyError.
        if tags_count.get(train_tag, 0) == 0:
            self.validator_text.set(
                text=f"Train tag '{train_tag}' is not present in any images. {self._ENSURE_TEXT}",
                status="error",
            )
        elif tags_count.get(val_tag, 0) == 0:
            self.validator_text.set(
                text=f"Val tag '{val_tag}' is not present in any images. {self._ENSURE_TEXT}",
                status="error",
            )
        elif train_tag == val_tag:
            self.validator_text.set(
                text=f"Train and val tags are the same. {self._ENSURE_TEXT} {self._WARNING_TEXT}",
                status="warning",
            )
        else:
            self.validator_text.set("Train and val tags are selected", status="success")

    def _validate_datasets_split(self) -> None:
        """Set an error/warning/success message for the dataset-based split."""
        train_dataset_id = self.get_train_dataset_ids() or []
        val_dataset_id = self.get_val_dataset_ids() or []

        # Map dataset id -> name and image count from the project stats.
        stats = self.api.project.get_stats(self.project_id)
        datasets_count = {}
        for dataset in stats["images"]["datasets"]:
            datasets_count[dataset["id"]] = {
                "name": dataset["name"],
                "total": dataset["imagesInDataset"],
            }

        # dict.fromkeys de-duplicates while keeping order, so a dataset chosen
        # for both splits is reported once. Ids missing from the stats are
        # skipped instead of raising KeyError.
        empty_dataset_names = []
        for dataset_id in dict.fromkeys(train_dataset_id + val_dataset_id):
            ds_stats = datasets_count.get(dataset_id)
            if ds_stats is not None and ds_stats["total"] == 0:
                empty_dataset_names.append(ds_stats["name"])

        if len(empty_dataset_names) > 0:
            if len(empty_dataset_names) == 1:
                empty_ds_text = f"Selected dataset: {', '.join(empty_dataset_names)} is empty. {self._ENSURE_TEXT}"
            else:
                empty_ds_text = f"Selected datasets: {', '.join(empty_dataset_names)} are empty. {self._ENSURE_TEXT}"
            self.validator_text.set(
                text=empty_ds_text,
                status="error",
            )
        # Compare as sets so the same datasets picked in a different order are
        # still recognised as identical.
        elif set(train_dataset_id) == set(val_dataset_id):
            self.validator_text.set(
                text=f"Same datasets are selected for both train and val splits. {self._ENSURE_TEXT} {self._WARNING_TEXT}",
                status="warning",
            )
        else:
            self.validator_text.set("Train and val datasets are selected", status="success")

    def set_sly_project(self, project: Project) -> None:
        """Assign ``project`` to the wrapped widget's ``_project_fs`` attribute."""
        self.train_val_splits._project_fs = project

    def get_split_method(self) -> str:
        """Return the split method currently selected in the widget."""
        return self.train_val_splits.get_split_method()

    def get_train_dataset_ids(self) -> List[int]:
        """Return the dataset ids selected for the train split."""
        return self.train_val_splits._train_ds_select.get_selected_ids()

    def set_train_dataset_ids(self, dataset_ids: List[int]) -> None:
        """Select the given dataset ids for the train split."""
        self.train_val_splits._train_ds_select.set_selected_ids(dataset_ids)

    def get_val_dataset_ids(self) -> List[int]:
        """Return the dataset ids selected for the val split."""
        return self.train_val_splits._val_ds_select.get_selected_ids()

    def set_val_dataset_ids(self, dataset_ids: List[int]) -> None:
        """Select the given dataset ids for the val split."""
        self.train_val_splits._val_ds_select.set_selected_ids(dataset_ids)
@@ -0,0 +1,93 @@
1
+ from typing import Any, Dict
2
+
3
+ from supervisely import Api
4
+ from supervisely._utils import is_production
5
+ from supervisely.app.widgets import Button, Card, Container, Progress, TaskLogs, Text
6
+ from supervisely.io.env import task_id as get_task_id
7
+
8
+
9
class TrainingLogs:
    """GUI step that shows training progress bars, a Tensorboard link and optional task logs.

    The card is locked on creation.
    """

    title = "Training Logs"
    description = "Track training progress"
    lock_message = "Start training to unlock"

    def __init__(self, app_options: Dict[str, Any]):
        """
        :param app_options: GUI options. ``"show_logs_in_gui"`` enables the
            "Show logs" button and the task logs widget.
        """
        api = Api.from_env()
        self.app_options = app_options

        self.progress_bar_main = Progress(hide_on_finish=False)
        self.progress_bar_main.hide()

        self.progress_bar_secondary = Progress(hide_on_finish=False)
        self.progress_bar_secondary.hide()

        # Tensorboard button: in production the link is built from the task's
        # session token, otherwise a local address is used.
        if is_production():
            task_id = get_task_id(raise_not_found=False)
            task_info = api.task.get_info_by_id(task_id)
            session_token = task_info["meta"]["sessionToken"]
            sly_url_prefix = f"/net/{session_token}"
            self.tensorboard_link = f"{api.server_address}{sly_url_prefix}/tensorboard/"
        else:
            task_id = None
            self.tensorboard_link = "http://localhost:8000/tensorboard"
        self.tensorboard_button = Button(
            "Open Tensorboard",
            button_type="info",
            plain=True,
            icon="zmdi zmdi-chart",
            link=self.tensorboard_link,
        )
        self.tensorboard_button.disable()

        self.validator_text = Text("")
        self.validator_text.hide()

        container_widgets = [
            self.validator_text,
            self.tensorboard_button,
            self.progress_bar_main,
            self.progress_bar_secondary,
        ]

        # Always define the optional widgets so toggle_logs() can detect their
        # absence instead of failing with AttributeError.
        self.logs_button = None
        self.task_logs = None
        if app_options.get("show_logs_in_gui", False):
            self.logs_button = Button(
                text="Show logs",
                plain=True,
                button_size="mini",
                icon="zmdi zmdi-caret-down-circle",
            )
            self.task_logs = TaskLogs(task_id)
            self.task_logs.hide()
            logs_container = Container([self.logs_button, self.task_logs])
            # Place the logs between the Tensorboard button and the progress bars.
            container_widgets.insert(2, logs_container)

        container = Container(container_widgets)

        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """This step exposes no widgets for disabling."""
        return []

    def validate_step(self) -> bool:
        """This step has nothing to validate; always returns ``True``."""
        return True

    def toggle_logs(self):
        """Show or hide the task logs and update the button text and icon.

        Does nothing when the logs widgets were not created
        (``show_logs_in_gui`` disabled).
        """
        task_logs = getattr(self, "task_logs", None)
        logs_button = getattr(self, "logs_button", None)
        if task_logs is None or logs_button is None:
            return

        if task_logs.is_hidden():
            task_logs.show()
            logs_button.text = "Hide logs"
            logs_button.icon = "zmdi zmdi-caret-up-circle"
        else:
            task_logs.hide()
            logs_button.text = "Show logs"
            logs_button.icon = "zmdi zmdi-caret-down-circle"
@@ -0,0 +1,114 @@
1
+ from typing import Any, Dict
2
+
3
+ from supervisely import Api
4
+ from supervisely.app.widgets import (
5
+ Button,
6
+ Card,
7
+ Container,
8
+ DoneLabel,
9
+ Empty,
10
+ Field,
11
+ FolderThumbnail,
12
+ Input,
13
+ ReportThumbnail,
14
+ SelectCudaDevice,
15
+ Text,
16
+ )
17
+
18
+
19
class TrainingProcess:
    """GUI step with the experiment name input, start/stop buttons and result widgets.

    When the ``"device_selector"`` option is enabled, a CUDA device selector is
    added to the card. The card is locked on creation.
    """

    title = "Training Process"
    description = "Manage training process"
    lock_message = "Select hyperparameters to unlock"

    def __init__(self, app_options: Dict[str, Any]):
        """
        :param app_options: GUI options. ``"device_selector"`` enables the CUDA
            device selector.
        """
        self.app_options = app_options
        self.experiment_name_input = Input("Enter experiment name")
        self.experiment_name_field = Field(
            title="Experiment name",
            description="Experiment name will be saved to experiment_info.json",
            content=self.experiment_name_input,
        )

        self.success_message_text = (
            "Training completed. Training artifacts were uploaded to Team Files. "
            "You can find and open tensorboard logs in the artifacts folder via the "
            "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
        )
        self.success_message = DoneLabel(text=self.success_message_text)
        self.success_message.hide()

        self.artifacts_thumbnail = FolderThumbnail()
        self.artifacts_thumbnail.hide()

        self.model_benchmark_report_thumbnail = ReportThumbnail()
        self.model_benchmark_report_thumbnail.hide()

        self.model_benchmark_report_text = Text(status="info", text="Creating report on model...")
        self.model_benchmark_report_text.hide()

        self.validator_text = Text("")
        self.validator_text.hide()
        self.start_button = Button("Start")
        self.stop_button = Button("Stop", button_type="danger")
        self.stop_button.hide()  # @TODO: implement stop and hide stop button until training starts

        button_container = Container(
            [self.start_button, self.stop_button, Empty()],
            "horizontal",
            overflow="wrap",
            fractions=[1, 1, 10],
            gap=1,
        )

        container_widgets = [
            self.experiment_name_field,
            button_container,
            self.validator_text,
            self.artifacts_thumbnail,
            self.model_benchmark_report_thumbnail,
            self.model_benchmark_report_text,
        ]

        if self.app_options.get("device_selector", False):
            self.select_device = SelectCudaDevice()
            self.select_cuda_device_field = Field(
                title="Select CUDA device",
                description="The device on which the model will be trained",
                content=self.select_device,
            )
            # Place the device selector right after the experiment name field.
            container_widgets.insert(1, self.select_cuda_device_field)

        container = Container(container_widgets)

        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets exposed for disabling by the surrounding GUI.

        Contains the experiment name input and, when the device selector is
        enabled, the device selector widget.
        """
        widgets = [self.experiment_name_input]
        if self.app_options.get("device_selector", False):
            # Previously the name input was appended a second time here, so the
            # device selector was never included.
            widgets.append(self.select_device)
        return widgets

    def validate_step(self) -> bool:
        """This step has nothing to validate; always returns ``True``."""
        return True

    def get_device(self) -> str:
        """Return the selected device, or ``"cuda:0"`` when the selector is disabled."""
        if self.app_options.get("device_selector", False):
            return self.select_device.get_device()
        return "cuda:0"

    def get_experiment_name(self) -> str:
        """Return the current value of the experiment name input."""
        return self.experiment_name_input.get_value()

    def set_experiment_name(self, experiment_name) -> None:
        """Set the value of the experiment name input."""
        self.experiment_name_input.set_value(experiment_name)