supervisely 6.73.242__py3-none-any.whl → 6.73.244__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +18 -0
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/card/card.py +3 -0
- supervisely/app/widgets/classes_table/classes_table.py +15 -1
- supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
- supervisely/app/widgets/custom_models_selector/template.html +1 -1
- supervisely/app/widgets/experiment_selector/__init__.py +0 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
- supervisely/app/widgets/experiment_selector/style.css +27 -0
- supervisely/app/widgets/experiment_selector/template.html +82 -0
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
- supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
- supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/__init__.py +3 -1
- supervisely/nn/artifacts/artifacts.py +10 -0
- supervisely/nn/artifacts/detectron2.py +2 -0
- supervisely/nn/artifacts/hrda.py +3 -0
- supervisely/nn/artifacts/mmclassification.py +2 -0
- supervisely/nn/artifacts/mmdetection.py +6 -3
- supervisely/nn/artifacts/mmsegmentation.py +2 -0
- supervisely/nn/artifacts/ritm.py +3 -1
- supervisely/nn/artifacts/rtdetr.py +2 -0
- supervisely/nn/artifacts/unet.py +2 -0
- supervisely/nn/artifacts/yolov5.py +3 -0
- supervisely/nn/artifacts/yolov8.py +7 -1
- supervisely/nn/experiments.py +113 -0
- supervisely/nn/inference/gui/__init__.py +3 -1
- supervisely/nn/inference/gui/gui.py +31 -232
- supervisely/nn/inference/gui/serving_gui.py +223 -0
- supervisely/nn/inference/gui/serving_gui_template.py +240 -0
- supervisely/nn/inference/inference.py +225 -24
- supervisely/nn/training/__init__.py +0 -0
- supervisely/nn/training/gui/__init__.py +1 -0
- supervisely/nn/training/gui/classes_selector.py +100 -0
- supervisely/nn/training/gui/gui.py +539 -0
- supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
- supervisely/nn/training/gui/input_selector.py +70 -0
- supervisely/nn/training/gui/model_selector.py +95 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
- supervisely/nn/training/gui/training_logs.py +93 -0
- supervisely/nn/training/gui/training_process.py +114 -0
- supervisely/nn/training/gui/utils.py +128 -0
- supervisely/nn/training/loggers/__init__.py +0 -0
- supervisely/nn/training/loggers/base_train_logger.py +58 -0
- supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
- supervisely/nn/training/train_app.py +2038 -0
- supervisely/nn/utils.py +5 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/RECORD +57 -35
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,539 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GUI module for training application.
|
|
3
|
+
|
|
4
|
+
This module provides the `TrainGUI` class that handles the graphical user interface (GUI) for managing
|
|
5
|
+
training workflows in Supervisely.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import supervisely.io.env as sly_env
|
|
9
|
+
from supervisely import Api
|
|
10
|
+
from supervisely._utils import is_production
|
|
11
|
+
from supervisely.app.widgets import Stepper, Widget
|
|
12
|
+
from supervisely.nn.training.gui.classes_selector import ClassesSelector
|
|
13
|
+
from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
|
|
14
|
+
from supervisely.nn.training.gui.input_selector import InputSelector
|
|
15
|
+
from supervisely.nn.training.gui.model_selector import ModelSelector
|
|
16
|
+
from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplitsSelector
|
|
17
|
+
from supervisely.nn.training.gui.training_logs import TrainingLogs
|
|
18
|
+
from supervisely.nn.training.gui.training_process import TrainingProcess
|
|
19
|
+
from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
|
|
20
|
+
from supervisely.nn.utils import ModelSource
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TrainGUI:
    """
    A class representing the GUI for training workflows.

    This class sets up and manages GUI components such as project selection,
    train/validation split selection, model selection, hyperparameters selection,
    and the training process.

    :param framework_name: Name of the ML framework being used.
    :type framework_name: str
    :param models: List of available models.
    :type models: list
    :param hyperparameters: Hyperparameters for training.
    :type hyperparameters: dict
    :param app_options: Application options for customization. ``None`` is treated
        as an empty dict.
    :type app_options: dict, optional
    """

    def __init__(
        self,
        framework_name: str,
        models: list,
        hyperparameters: dict,
        app_options: dict = None,
    ):
        self._api = Api.from_env()

        if is_production():
            self.task_id = sly_env.task_id()
        else:
            self.task_id = "debug-session"

        self.framework_name = framework_name
        self.models = models
        self.hyperparameters = hyperparameters
        # ``app_options`` defaults to None; normalize it so the ``.get`` calls below
        # (and the ones in the step selectors) do not raise AttributeError.
        self.app_options = app_options if app_options is not None else {}
        self.collapsable = self.app_options.get("collapsable", False)

        self.team_id = sly_env.team_id()
        self.workspace_id = sly_env.workspace_id()
        self.project_id = sly_env.project_id()  # from app options?
        self.project_info = self._api.project.get_info_by_id(self.project_id)

        # 1. Project selection + Train/val split
        self.input_selector = InputSelector(self.project_info, self.app_options)
        # 2. Select train val splits
        self.train_val_splits_selector = TrainValSplitsSelector(
            self._api, self.project_id, self.app_options
        )
        # 3. Select classes
        self.classes_selector = ClassesSelector(self.project_id, [], self.app_options)
        # 4. Model selection
        self.model_selector = ModelSelector(
            self._api, self.framework_name, self.models, self.app_options
        )
        # 5. Training parameters (yaml), scheduler preview
        self.hyperparameters_selector = HyperparametersSelector(
            self.hyperparameters, self.app_options
        )
        # 6. Start Train
        self.training_process = TrainingProcess(self.app_options)

        # 7. Training logs
        self.training_logs = TrainingLogs(self.app_options)

        # Stepper layout; the "select" handlers below advance it via ``next_pos``.
        self.stepper = Stepper(
            widgets=[
                self.input_selector.card,
                self.train_val_splits_selector.card,
                self.classes_selector.card,
                self.model_selector.card,
                self.hyperparameters_selector.card,
                self.training_process.card,
                self.training_logs.card,
            ],
        )
        # ------------------------------------------------- #

        # Button utils
        def update_classes_table():
            # Placeholder hook run when the input step is selected.
            pass

        def disable_hyperparams_editor():
            # Despite its name, this toggles the editor's read-only flag. It is
            # registered for both the "select" and "reselect" clicks below.
            editor = self.hyperparameters_selector.editor
            editor.readonly = not editor.readonly

        def set_experiment_name():
            # Suggest "<task_id>_<project name>_<model name>" once a model is chosen.
            model_name = self.model_selector.get_model_name()
            if model_name is None:
                experiment_name = "Enter experiment name"
            else:
                experiment_name = f"{self.task_id}_{self.project_info.name}_{model_name}"

            if experiment_name == self.training_process.get_experiment_name():
                return
            self.training_process.set_experiment_name(experiment_name)

        # ------------------------------------------------- #

        # Wrappers. They are built from the last step to the first because each
        # step's wrapper receives the next step's wrapper as ``callback``.
        # NOTE(review): this wrapper is bound to the hyperparameters "Select" button
        # (not a training-process button), so it shares that button's state with
        # ``hyperparameters_selector_cb`` — confirm this is intended.
        self.training_process_cb = wrap_button_click(
            button=self.hyperparameters_selector.button,
            cards_to_unlock=[self.training_logs.card],
            widgets_to_disable=self.training_process.widgets_to_disable,
            callback=None,
            validation_text=self.training_process.validator_text,
            validation_func=self.training_process.validate_step,
        )

        self.hyperparameters_selector_cb = wrap_button_click(
            button=self.hyperparameters_selector.button,
            cards_to_unlock=[self.training_process.card],
            widgets_to_disable=self.hyperparameters_selector.widgets_to_disable,
            callback=self.training_process_cb,
            validation_text=self.hyperparameters_selector.validator_text,
            validation_func=self.hyperparameters_selector.validate_step,
            on_select_click=[disable_hyperparams_editor],
            on_reselect_click=[disable_hyperparams_editor],
            collapse_card=(self.hyperparameters_selector.card, self.collapsable),
        )

        self.model_selector_cb = wrap_button_click(
            button=self.model_selector.button,
            cards_to_unlock=[self.hyperparameters_selector.card],
            widgets_to_disable=self.model_selector.widgets_to_disable,
            callback=self.hyperparameters_selector_cb,
            validation_text=self.model_selector.validator_text,
            validation_func=self.model_selector.validate_step,
            on_select_click=[set_experiment_name],
            collapse_card=(self.model_selector.card, self.collapsable),
        )

        self.classes_selector_cb = wrap_button_click(
            button=self.classes_selector.button,
            cards_to_unlock=[self.model_selector.card],
            widgets_to_disable=self.classes_selector.widgets_to_disable,
            callback=self.model_selector_cb,
            validation_text=self.classes_selector.validator_text,
            validation_func=self.classes_selector.validate_step,
            collapse_card=(self.classes_selector.card, self.collapsable),
        )

        self.train_val_splits_selector_cb = wrap_button_click(
            button=self.train_val_splits_selector.button,
            cards_to_unlock=[self.classes_selector.card],
            widgets_to_disable=self.train_val_splits_selector.widgets_to_disable,
            callback=self.classes_selector_cb,
            validation_text=self.train_val_splits_selector.validator_text,
            validation_func=self.train_val_splits_selector.validate_step,
            collapse_card=(self.train_val_splits_selector.card, self.collapsable),
        )

        self.input_selector_cb = wrap_button_click(
            button=self.input_selector.button,
            cards_to_unlock=[self.train_val_splits_selector.card],
            widgets_to_disable=self.input_selector.widgets_to_disable,
            callback=self.train_val_splits_selector_cb,
            validation_text=self.input_selector.validator_text,
            validation_func=self.input_selector.validate_step,
            on_select_click=[update_classes_table],
            collapse_card=(self.input_selector.card, self.collapsable),
        )
        # ------------------------------------------------- #

        # Main Buttons

        # Define outside. Used by user in app
        # @self.training_process.start_button.click
        # def start_training():
        #     pass

        # @self.training_process.stop_button.click
        # def stop_training():
        #     pass

        # ------------------------------------------------- #

        # Select Buttons
        @self.hyperparameters_selector.button.click
        def select_hyperparameters():
            self.hyperparameters_selector_cb()
            set_stepper_step(
                self.stepper,
                self.hyperparameters_selector.button,
                next_pos=6,
            )

        @self.model_selector.button.click
        def select_model():
            self.model_selector_cb()
            set_stepper_step(
                self.stepper,
                self.model_selector.button,
                next_pos=5,
            )

        @self.classes_selector.button.click
        def select_classes():
            self.classes_selector_cb()
            set_stepper_step(
                self.stepper,
                self.classes_selector.button,
                next_pos=4,
            )

        @self.train_val_splits_selector.button.click
        def select_train_val_splits():
            self.train_val_splits_selector_cb()
            set_stepper_step(
                self.stepper,
                self.train_val_splits_selector.button,
                next_pos=3,
            )

        @self.input_selector.button.click
        def select_input():
            self.input_selector_cb()
            set_stepper_step(
                self.stepper,
                self.input_selector.button,
                next_pos=2,
            )

        # ------------------------------------------------- #

        # Other Buttons
        if self.app_options.get("show_logs_in_gui", False):

            @self.training_logs.logs_button.click
            def show_logs():
                self.training_logs.toggle_logs()

        # Other handlers
        @self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
        def show_mb_speedtest(is_checked: bool):
            self.hyperparameters_selector.toggle_mb_speedtest(is_checked)

        # ------------------------------------------------- #

        self.layout: Widget = self.stepper

    def enable_select_buttons(self):
        """
        Makes all select buttons in the GUI available for interaction.
        """
        self.input_selector.button.enable()
        self.train_val_splits_selector.button.enable()
        self.classes_selector.button.enable()
        self.model_selector.button.enable()
        self.hyperparameters_selector.button.enable()

    def disable_select_buttons(self):
        """
        Makes all select buttons in the GUI unavailable for interaction.
        """
        self.input_selector.button.disable()
        self.train_val_splits_selector.button.disable()
        self.classes_selector.button.disable()
        self.model_selector.button.disable()
        self.hyperparameters_selector.button.disable()

    # Set GUI from config
    def validate_app_state(self, app_state: dict) -> dict:
        """
        Validate the app state dictionary and fill in default options in place.

        Any missing entries under ``options`` are filled with defaults:
        ``model_benchmark.enable=True``, ``model_benchmark.speed_test=True``
        and ``cache_project=True``.

        :param app_state: The app state dictionary.
        :type app_state: dict
        :raises ValueError: If a section has the wrong type or an invalid value.
        :raises KeyError: If a required key is missing.
        :return: The same ``app_state`` object, with defaults filled in.
        :rtype: dict
        """
        if not isinstance(app_state, dict):
            raise ValueError("app_state must be a dictionary")

        # A list value means "dict section with these required sub-keys";
        # a type / tuple of types means "section must be of this type".
        required_keys = {
            "input": ["project_id"],
            "train_val_split": ["method"],
            "classes": list,
            "model": ["source"],
            "hyperparameters": (dict, str),  # Allowing dict or str for hyperparameters
        }

        for key, subkeys_or_type in required_keys.items():
            if key not in app_state:
                raise KeyError(f"Missing required key in app_state: {key}")

            section = app_state[key]
            if isinstance(subkeys_or_type, list):
                # Check the section type first so a bad value gives a clear error
                # instead of a TypeError from the ``in`` test below.
                if not isinstance(section, dict):
                    raise ValueError(f"app_state['{key}'] must be a dictionary")
                for subkey in subkeys_or_type:
                    if subkey not in section:
                        raise KeyError(f"Missing required key in app_state['{key}']: {subkey}")
            elif not isinstance(section, subkeys_or_type):
                valid_types = (
                    " or ".join([t.__name__ for t in subkeys_or_type])
                    if isinstance(subkeys_or_type, tuple)
                    else subkeys_or_type.__name__
                )
                raise ValueError(f"app_state['{key}'] must be of type {valid_types}")

        # NOTE: these literals must stay equal to ModelSource.PRETRAINED / ModelSource.CUSTOM,
        # which ``_init_model`` compares against.
        model = app_state["model"]
        if model["source"] == "Pretrained models":
            if "model_name" not in model:
                raise KeyError("Missing required key in app_state['model']: model_name")
        elif model["source"] == "Custom models":
            custom_keys = ["task_id", "checkpoint"]
            for key in custom_keys:
                if key not in model:
                    raise KeyError(f"Missing required key in app_state['model']: {key}")

        options = app_state.setdefault("options", {})
        if not isinstance(options, dict):
            raise ValueError("app_state['options'] must be a dictionary")

        model_benchmark = options.setdefault(
            "model_benchmark", {"enable": True, "speed_test": True}
        )
        if not isinstance(model_benchmark, dict):
            raise ValueError("app_state['options']['model_benchmark'] must be a dictionary")
        model_benchmark.setdefault("enable", True)
        model_benchmark.setdefault("speed_test", True)

        # Default ``cache_project`` even when a partial ``options`` dict is given.
        # An explicitly provided non-bool value is still rejected.
        options.setdefault("cache_project", True)
        if not isinstance(options["cache_project"], bool):
            raise ValueError("app_state['options']['cache_project'] must be a boolean")

        # Check train val splits
        train_val_splits_settings = app_state["train_val_split"]
        split_method = train_val_splits_settings.get("method")
        if split_method == "datasets":
            project_dataset_ids = {
                dataset.id for _, dataset in self._api.dataset.tree(self.project_id)
            }

            train_datasets = train_val_splits_settings.get("train_datasets", [])
            val_datasets = train_val_splits_settings.get("val_datasets", [])

            missing_datasets_ids = [
                ds_id
                for ds_id in train_datasets + val_datasets
                if ds_id not in project_dataset_ids
            ]
            if missing_datasets_ids:
                missing_datasets_text = ", ".join([str(ds_id) for ds_id in missing_datasets_ids])
                raise ValueError(
                    f"Datasets with ids: {missing_datasets_text} not found in the project"
                )
        elif split_method == "tags":
            train_tag = train_val_splits_settings.get("train_tag")
            val_tag = train_val_splits_settings.get("val_tag")
            if not train_tag or not val_tag:
                raise ValueError("train_tag and val_tag must be specified in tags split method")
        elif split_method == "random":
            split = train_val_splits_settings.get("split")
            percent = train_val_splits_settings.get("percent")
            if split not in ["train", "val"]:
                raise ValueError("split must be 'train' or 'val'")
            if not isinstance(percent, int) or not 0 < percent < 100:
                raise ValueError("percent must be an integer in range 1 to 99")
        else:
            # ``_init_train_val_splits`` only understands these three methods;
            # anything else would be silently ignored there.
            raise ValueError(
                f"Unknown train/val split method: {split_method!r}. "
                "Expected one of: 'random', 'tags', 'datasets'"
            )
        return app_state

    def load_from_app_state(self, app_state: dict) -> None:
        """
        Load the GUI state from app state dictionary.

        The state is validated first (see ``validate_app_state``). Then each step is
        initialized in the same order as the stepper: input, train/val splits,
        classes, model, hyperparameters.

        :param app_state: The state dictionary.
        :type app_state: dict

        app_state example:

            app_state = {
                "input": {"project_id": 43192},
                "train_val_split": {
                    "method": "random",
                    "split": "train",
                    "percent": 90
                },
                "classes": ["apple"],
                "model": {
                    "source": "Pretrained models",
                    "model_name": "rtdetr_r50vd_coco_objects365"
                },
                "hyperparameters": hyperparameters, # yaml string
                "options": {
                    "model_benchmark": {
                        "enable": True,
                        "speed_test": True
                    },
                    "cache_project": True
                }
            }
        """
        app_state = self.validate_app_state(app_state)

        options = app_state["options"]
        input_settings = app_state["input"]
        train_val_splits_settings = app_state["train_val_split"]
        classes_settings = app_state["classes"]
        model_settings = app_state["model"]
        hyperparameters_settings = app_state["hyperparameters"]

        # Follow the stepper order so each step is confirmed after the one before it.
        self._init_input(input_settings, options)
        self._init_train_val_splits(train_val_splits_settings)
        self._init_classes(classes_settings)
        self._init_model(model_settings)
        self._init_hyperparameters(hyperparameters_settings, options)

    def _init_input(self, input_settings: dict, options: dict) -> None:
        """
        Initialize the input selector with the given settings.

        :param input_settings: The input settings.
        :type input_settings: dict
        :param options: The application options.
        :type options: dict
        """
        # Set Input
        self.input_selector.set_cache(options["cache_project"])
        self.input_selector_cb()
        # ----------------------------------------- #

    def _init_train_val_splits(self, train_val_splits_settings: dict) -> None:
        """
        Initialize the train/val splits selector with the given settings.

        :param train_val_splits_settings: The train/val splits settings.
        :type train_val_splits_settings: dict
        """
        split_method = train_val_splits_settings["method"]
        if split_method == "random":
            split = train_val_splits_settings["split"]
            percent = train_val_splits_settings["percent"]
            self.train_val_splits_selector.train_val_splits.set_random_splits(split, percent)
        elif split_method == "tags":
            train_tag = train_val_splits_settings["train_tag"]
            val_tag = train_val_splits_settings["val_tag"]
            untagged_action = train_val_splits_settings["untagged_action"]
            self.train_val_splits_selector.train_val_splits.set_tags_splits(
                train_tag, val_tag, untagged_action
            )
        elif split_method == "datasets":
            train_datasets = train_val_splits_settings["train_datasets"]
            val_datasets = train_val_splits_settings["val_datasets"]
            self.train_val_splits_selector.train_val_splits.set_datasets_splits(
                train_datasets, val_datasets
            )
        self.train_val_splits_selector_cb()

    def _init_classes(self, classes_settings: list) -> None:
        """
        Initialize the classes selector with the given settings.

        :param classes_settings: The classes settings.
        :type classes_settings: list
        """
        # Set Classes
        self.classes_selector.set_classes(classes_settings)
        self.classes_selector_cb()
        # ----------------------------------------- #

    def _init_model(self, model_settings: dict) -> None:
        """
        Initialize the model selector with the given settings.

        :param model_settings: The model settings.
        :type model_settings: dict
        :raises ValueError: If the requested checkpoint is not in the selected task.
        """

        # Pretrained
        if model_settings["source"] == ModelSource.PRETRAINED:
            self.model_selector.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
            self.model_selector.pretrained_models_table.set_by_model_name(
                model_settings["model_name"]
            )

        # Custom
        elif model_settings["source"] == ModelSource.CUSTOM:
            self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
            self.model_selector.experiment_selector.set_by_task_id(model_settings["task_id"])
            active_row = self.model_selector.experiment_selector.get_selected_row()
            if model_settings["checkpoint"] not in active_row.checkpoints_names:
                raise ValueError(
                    f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
                )

            active_row.set_selected_checkpoint_by_name(model_settings["checkpoint"])
        self.model_selector_cb()
        # ----------------------------------------- #

    def _init_hyperparameters(self, hyperparameters_settings: dict, options: dict) -> None:
        """
        Initialize the hyperparameters selector with the given settings.

        :param hyperparameters_settings: The hyperparameters settings.
        :type hyperparameters_settings: dict
        :param options: The application options.
        :type options: dict
        """
        self.hyperparameters_selector.set_hyperparameters(hyperparameters_settings)

        model_benchmark_settings = options["model_benchmark"]
        self.hyperparameters_selector.set_model_benchmark_checkbox_value(
            model_benchmark_settings["enable"]
        )
        self.hyperparameters_selector.set_speedtest_checkbox_value(
            model_benchmark_settings["speed_test"]
        )
        self.hyperparameters_selector_cb()

    # ----------------------------------------- #
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from supervisely.app.widgets import (
|
|
4
|
+
Button,
|
|
5
|
+
Card,
|
|
6
|
+
Checkbox,
|
|
7
|
+
Container,
|
|
8
|
+
Editor,
|
|
9
|
+
Field,
|
|
10
|
+
Text,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HyperparametersSelector:
    """
    GUI step for editing training hyperparameters and Model Benchmark options.

    The step is a :class:`Card` that is locked on creation. It contains a YAML
    editor, the Model Benchmark / speed-test checkboxes (shown only when enabled
    by ``app_options``), a validation text and a "Select" button.

    :param hyperparameters: Initial editor content (passed straight to ``Editor``;
        presumably YAML text — TODO confirm against callers).
    :param app_options: Application options. Keys read here: ``model_benchmark``
        (show the benchmark checkboxes, default True) and ``collapsable``
        (default False). ``None`` is treated as an empty dict.
    """

    title = "Hyperparameters"
    description = "Set hyperparameters for training"
    lock_message = "Select model to unlock"

    def __init__(self, hyperparameters: dict, app_options: dict = None):
        # Avoid a shared mutable default; a dict passed by the caller is kept as-is.
        self.app_options = app_options if app_options is not None else {}
        app_options = self.app_options

        self.editor = Editor(
            hyperparameters, height_lines=50, language_mode="yaml", auto_format=True
        )

        # Model Benchmark
        self.run_model_benchmark_checkbox = Checkbox(
            content="Run Model Benchmark evaluation", checked=True
        )
        self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=True)

        self.model_benchmark_field = Field(
            Container(
                widgets=[
                    self.run_model_benchmark_checkbox,
                    self.run_speedtest_checkbox,
                ]
            ),
            title="Model Evaluation Benchmark",
            description="Generate evaluation dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
        )
        docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
        self.model_benchmark_learn_more = Text(
            f"Learn more about Model Benchmark in the {docs_link}.", status="info"
        )

        if app_options.get("model_benchmark", True):
            self.model_benchmark_field.show()
            self.model_benchmark_learn_more.show()
        else:
            self.model_benchmark_field.hide()
            self.model_benchmark_learn_more.hide()

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        container = Container(
            [
                self.editor,
                self.model_benchmark_field,
                self.model_benchmark_learn_more,
                self.validator_text,
                self.button,
            ]
        )
        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
            collapsable=app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets that are disabled once this step is confirmed."""
        return [
            self.editor,
            self.run_model_benchmark_checkbox,
            self.run_speedtest_checkbox,
        ]

    def set_hyperparameters(self, hyperparameters: Union[str, dict]) -> None:
        """Replace the editor content with ``hyperparameters``."""
        self.editor.set_text(hyperparameters)

    def get_hyperparameters(self) -> str:
        """Return the current editor content (the raw text, not a parsed dict)."""
        return self.editor.get_value()

    def get_model_benchmark_checkbox_value(self) -> bool:
        """Return whether Model Benchmark should run; always False if disabled in app options."""
        if self.app_options.get("model_benchmark", True):
            return self.run_model_benchmark_checkbox.is_checked()
        return False

    def set_model_benchmark_checkbox_value(self, is_checked: bool) -> None:
        """
        Check or uncheck the Model Benchmark checkbox.

        The speed-test checkbox visibility is synced as well. A programmatic
        change would otherwise leave it visible while the benchmark is off.
        """
        if is_checked:
            self.run_model_benchmark_checkbox.check()
        else:
            self.run_model_benchmark_checkbox.uncheck()
        self.toggle_mb_speedtest(is_checked)

    def get_speedtest_checkbox_value(self) -> bool:
        """Return whether the speed test should run; always False if benchmark is disabled in app options."""
        if self.app_options.get("model_benchmark", True):
            return self.run_speedtest_checkbox.is_checked()
        return False

    def set_speedtest_checkbox_value(self, is_checked: bool) -> None:
        """Check or uncheck the speed-test checkbox."""
        if is_checked:
            self.run_speedtest_checkbox.check()
        else:
            self.run_speedtest_checkbox.uncheck()

    def toggle_mb_speedtest(self, is_checked: bool) -> None:
        """Show the speed-test checkbox when ``is_checked`` is truthy, hide it otherwise."""
        if is_checked:
            self.run_speedtest_checkbox.show()
        else:
            self.run_speedtest_checkbox.hide()

    def validate_step(self) -> bool:
        """This step has no validation; it always passes."""
        return True
|