supervisely-6.73.249-py3-none-any.whl → supervisely-6.73.251-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of supervisely has been flagged as possibly problematic.
- supervisely/api/api.py +19 -3
- supervisely/app/widgets/experiment_selector/experiment_selector.py +16 -8
- supervisely/nn/benchmark/base_benchmark.py +17 -2
- supervisely/nn/benchmark/base_evaluator.py +28 -6
- supervisely/nn/benchmark/instance_segmentation/benchmark.py +1 -1
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +14 -0
- supervisely/nn/benchmark/object_detection/benchmark.py +1 -1
- supervisely/nn/benchmark/object_detection/evaluator.py +43 -13
- supervisely/nn/benchmark/object_detection/metric_provider.py +7 -0
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +33 -7
- supervisely/nn/benchmark/utils/detection/utlis.py +6 -4
- supervisely/nn/experiments.py +23 -16
- supervisely/nn/inference/gui/serving_gui_template.py +2 -35
- supervisely/nn/inference/inference.py +71 -8
- supervisely/nn/training/__init__.py +2 -0
- supervisely/nn/training/gui/classes_selector.py +14 -14
- supervisely/nn/training/gui/gui.py +28 -13
- supervisely/nn/training/gui/hyperparameters_selector.py +90 -41
- supervisely/nn/training/gui/input_selector.py +8 -6
- supervisely/nn/training/gui/model_selector.py +7 -5
- supervisely/nn/training/gui/train_val_splits_selector.py +8 -9
- supervisely/nn/training/gui/training_logs.py +17 -17
- supervisely/nn/training/gui/training_process.py +41 -36
- supervisely/nn/training/loggers/__init__.py +22 -0
- supervisely/nn/training/loggers/base_train_logger.py +8 -5
- supervisely/nn/training/loggers/tensorboard_logger.py +4 -11
- supervisely/nn/training/train_app.py +276 -90
- supervisely/project/project.py +6 -0
- {supervisely-6.73.249.dist-info → supervisely-6.73.251.dist-info}/METADATA +8 -3
- {supervisely-6.73.249.dist-info → supervisely-6.73.251.dist-info}/RECORD +34 -34
- {supervisely-6.73.249.dist-info → supervisely-6.73.251.dist-info}/LICENSE +0 -0
- {supervisely-6.73.249.dist-info → supervisely-6.73.251.dist-info}/WHEEL +0 -0
- {supervisely-6.73.249.dist-info → supervisely-6.73.251.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.249.dist-info → supervisely-6.73.251.dist-info}/top_level.txt +0 -0

supervisely/nn/inference/inference.py:

@@ -90,7 +90,10 @@ class Inference:
     """Path to file with list of models"""
     APP_OPTIONS: str = None
     """Path to file with app options"""
-    DEFAULT_BATCH_SIZE = 16
+    DEFAULT_BATCH_SIZE: str = 16
+    """Default batch size for inference"""
+    INFERENCE_SETTINGS: str = None
+    """Path to file with custom inference settings"""

     def __init__(
         self,

@@ -125,8 +128,12 @@ class Inference:
         self._autostart_delay_time = 5 * 60  # 5 min
         self._tracker = None
         self._hardware: str = None
+        self.pretrained_models = self._load_models_json(self.MODELS) if self.MODELS else None
         if custom_inference_settings is None:
-
+            if self.INFERENCE_SETTINGS is not None:
+                custom_inference_settings = self.INFERENCE_SETTINGS
+            else:
+                custom_inference_settings = {}
         if isinstance(custom_inference_settings, str):
             if fs.file_exists(custom_inference_settings):
                 with open(custom_inference_settings, "r") as f:
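
Together, these two hunks let a serving app be configured through class attributes alone, with `INFERENCE_SETTINGS` picked up automatically when no `custom_inference_settings` argument is passed to `__init__`. A minimal sketch of such a subclass, assuming the `ObjectDetection` base class from `supervisely.nn.inference`; the `MyDetector` name and all file paths are illustrative, not from the package:

```python
from supervisely.nn.inference import ObjectDetection  # base class name assumed


class MyDetector(ObjectDetection):
    FRAMEWORK_NAME = "MyFramework"                       # required, see the next hunk
    MODELS = "src/models.json"                           # parsed by _load_models_json()
    APP_OPTIONS = "src/app_options.yaml"
    INFERENCE_SETTINGS = "src/inference_settings.yaml"   # default custom inference settings
    DEFAULT_BATCH_SIZE = 16
```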

@@ -153,7 +160,7 @@ class Inference:
             if self.FRAMEWORK_NAME is None:
                 raise ValueError("FRAMEWORK_NAME is not defined")
             self._gui = GUI.ServingGUITemplate(
-                self.FRAMEWORK_NAME, self.
+                self.FRAMEWORK_NAME, self.pretrained_models, self.APP_OPTIONS
             )
             self._user_layout = self._gui.widgets
             self._user_layout_card = self._gui.card

@@ -239,6 +246,38 @@ class Inference:
         )
         device = "cpu"

+    def _load_models_json(self, models: str) -> List[Dict[str, Any]]:
+        """
+        Loads models from the provided file or list of model configurations.
+        """
+        if isinstance(models, str):
+            if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
+                models = sly_json.load_json_file(models)
+            else:
+                raise ValueError("File not found or invalid file format.")
+        else:
+            raise ValueError(
+                "Invalid models file. Please provide a valid '.json' file with list of model configurations."
+            )
+
+        if not isinstance(models, list):
+            raise ValueError("models parameters must be a list of dicts")
+        for item in models:
+            if not isinstance(item, dict):
+                raise ValueError(f"Each item in models must be a dict.")
+            model_meta = item.get("meta")
+            if model_meta is None:
+                raise ValueError(
+                    "Model metadata not found. Please update provided models parameter to include key 'meta'."
+                )
+            model_files = model_meta.get("model_files")
+            if model_files is None:
+                raise ValueError(
+                    "Model files not found in model metadata. "
+                    "Please update provided models oarameter to include key 'model_files' in 'meta' key."
+                )
+        return models
+
     def get_ui(self) -> Widget:
         if not self._use_gui:
             return None
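
`_load_models_json` only enforces that the input is a `.json` file containing a list of dicts, each carrying a `meta` key that in turn contains `model_files`. A hypothetical entry that would pass this validation; all field names other than `meta` and `model_files` are illustrative:

```python
# Hypothetical models.json content that satisfies _load_models_json:
# a list of dicts, each with "meta" -> "model_files".
models = [
    {
        "Model": "example-model-s",           # illustrative display field
        "meta": {
            "task_type": "object detection",  # illustrative
            "model_files": {
                "checkpoint": "https://example.com/checkpoints/example-model-s.pt",
                "config": "configs/example-model-s.yaml",
            },
        },
    }
]
```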

@@ -487,6 +526,9 @@ class Inference:
     def load_model_meta(self, model_tab: str, local_weights_path: str):
         raise NotImplementedError("Have to be implemented in child class after inheritance")

+    def _checkpoints_cache_dir(self):
+        return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
+
     def _download_model_files(self, model_source: str, model_files: List[str]) -> dict:
         if model_source == ModelSource.PRETRAINED:
             return self._download_pretrained_model(model_files)

@@ -498,17 +540,28 @@ class Inference:
         Downloads the pretrained model data.
         """
         local_model_files = {}
+        cache_dir = self._checkpoints_cache_dir()

         for file in model_files:
             file_url = model_files[file]
-
+            file_name = sly_fs.get_file_name_with_ext(file_url)
             if file_url.startswith("http"):
                 with urlopen(file_url) as f:
                     file_size = f.length
                 file_name = get_filename_from_headers(file_url)
-                file_path = os.path.join(self.model_dir, file_name)
                 if file_name is None:
                     file_name = file
+                file_path = os.path.join(self.model_dir, file_name)
+                cached_path = os.path.join(cache_dir, file_name)
+                if os.path.exists(cached_path):
+                    local_model_files[file] = cached_path
+                    logger.debug(f"Model: '{file_name}' was found in checkpoint cache")
+                    continue
+                if os.path.exists(file_path):
+                    local_model_files[file] = file_path
+                    logger.debug(f"Model: '{file_name}' was found in model dir")
+                    continue
+
                 with self.gui.download_progress(
                     message=f"Downloading: '{file_name}'",
                     total=file_size,
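
The download path now checks two locations before downloading: the shared checkpoint cache under `~/.cache/supervisely/checkpoints` and the app's model directory. A standalone sketch of that lookup order, with a hypothetical `download` callback standing in for the progress-tracked download in the diff:

```python
import os


def resolve_model_file(file_name: str, model_dir: str, download) -> str:
    """Return a local path for file_name, downloading only on a cache miss (sketch)."""
    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
    cached_path = os.path.join(cache_dir, file_name)
    if os.path.exists(cached_path):
        return cached_path                 # found in the shared checkpoint cache
    file_path = os.path.join(model_dir, file_name)
    if os.path.exists(file_path):
        return file_path                   # already present in the app's model dir
    return download(file_name, file_path)  # fall back to downloading
```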

@@ -614,13 +667,23 @@ class Inference:
         model_files = deploy_params.get("model_files", {})
         if model_info:
             checkpoint_name = os.path.basename(model_files.get("checkpoint"))
+            checkpoint_file_path = os.path.join(
+                model_info.get("artifacts_dir"), "checkpoints", checkpoint_name
+            )
+            checkpoint_file_info = self.api.file.get_info_by_path(
+                env.team_id(), checkpoint_file_path
+            )
+            if checkpoint_file_info is None:
+                checkpoint_url = None
+            else:
+                checkpoint_url = self.api.file.get_url(checkpoint_file_info.id)
+
             self.checkpoint_info = CheckpointInfo(
                 checkpoint_name=checkpoint_name,
                 model_name=model_info.get("model_name"),
                 architecture=model_info.get("framework_name"),
-
-
-                ),
+                checkpoint_url=checkpoint_url,
+                custom_checkpoint_path=checkpoint_file_path,
                 model_source=ModelSource.CUSTOM,
             )

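
For custom models, the checkpoint URL is now resolved from Team Files instead of being assumed: if the file exists under `<artifacts_dir>/checkpoints/`, its URL is recorded in `CheckpointInfo`; otherwise `checkpoint_url` stays `None`. A standalone sketch of the same lookup; the team id and paths are placeholders:

```python
import os
import supervisely as sly


def resolve_checkpoint_url(api: sly.Api, team_id: int, artifacts_dir: str, checkpoint_name: str):
    """Return (path in Team Files, URL or None) for a custom checkpoint (sketch)."""
    checkpoint_file_path = os.path.join(artifacts_dir, "checkpoints", checkpoint_name)
    file_info = api.file.get_info_by_path(team_id, checkpoint_file_path)
    if file_info is None:
        return checkpoint_file_path, None  # checkpoint not found in Team Files
    return checkpoint_file_path, api.file.get_url(file_info.id)
```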

supervisely/nn/training/gui/classes_selector.py:

@@ -11,36 +11,36 @@ class ClassesSelector:
     lock_message = "Select training and validation splits to unlock"

     def __init__(self, project_id: int, classes: list, app_options: dict = {}):
-        self.
-        if len(classes) > 0:
-            self.classes_table.select_classes(classes)  # from app options
-        else:
-            self.classes_table.select_all()
+        self.display_widgets = []

+        # GUI Components
         if is_development() or is_debug_with_sly_net():
             qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
         else:
             qa_stats_link = f"/projects/{project_id}/stats/datasets"
-
         qa_stats_text = Text(
             text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
         )

+        self.classes_table = ClassesTable(project_id=project_id)
+        if len(classes) > 0:
+            self.classes_table.select_classes(classes)
+        else:
+            self.classes_table.select_all()
+
         self.validator_text = Text("")
         self.validator_text.hide()
         self.button = Button("Select")
-
-        [
-            qa_stats_text,
-            self.classes_table,
-            self.validator_text,
-            self.button,
-        ]
+        self.display_widgets.extend(
+            [qa_stats_text, self.classes_table, self.validator_text, self.button]
         )
+        # -------------------------------- #
+
+        self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,
             description=self.description,
-            content=container,
+            content=self.container,
             lock_message=self.lock_message,
             collapsable=app_options.get("collapsable", False),
         )
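
This is the layout refactor that repeats across the training GUI steps below: each selector collects its widgets in `self.display_widgets`, wraps them in a `Container`, and hands that container to its `Card`. A minimal sketch of the pattern, using a hypothetical `ExampleSelector` step; the widget classes are the ones these files import from `supervisely.app.widgets`:

```python
from supervisely.app.widgets import Button, Card, Container, Text


class ExampleSelector:
    title = "Example step"
    description = "Illustrative only"
    lock_message = "Select previous step to unlock"

    def __init__(self, app_options: dict = {}):
        self.display_widgets = []  # every widget shown in this step

        # GUI Components
        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        self.display_widgets.extend([self.validator_text, self.button])

        # One container holds the collected widgets and is placed on the card
        self.container = Container(self.display_widgets)
        self.card = Card(
            title=self.title,
            description=self.description,
            content=self.container,
            lock_message=self.lock_message,
            collapsable=app_options.get("collapsable", False),
        )
```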

supervisely/nn/training/gui/gui.py:

@@ -6,7 +6,7 @@ training workflows in Supervisely.
 """

 import supervisely.io.env as sly_env
-from supervisely import Api
+from supervisely import Api, ProjectMeta
 from supervisely._utils import is_production
 from supervisely.app.widgets import Stepper, Widget
 from supervisely.nn.training.gui.classes_selector import ClassesSelector

@@ -17,7 +17,7 @@ from supervisely.nn.training.gui.train_val_splits_selector import TrainValSplits
 from supervisely.nn.training.gui.training_logs import TrainingLogs
 from supervisely.nn.training.gui.training_process import TrainingProcess
 from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
-from supervisely.nn.utils import ModelSource
+from supervisely.nn.utils import ModelSource, RuntimeType


 class TrainGUI:

@@ -62,6 +62,7 @@ class TrainGUI:
         self.workspace_id = sly_env.workspace_id()
         self.project_id = sly_env.project_id()  # from app options?
         self.project_info = self._api.project.get_info_by_id(self.project_id)
+        self.project_meta = ProjectMeta.from_json(self._api.project.get_meta(self.project_id))

         # 1. Project selection + Train/val split
         self.input_selector = InputSelector(self.project_info, self.app_options)

@@ -399,7 +400,7 @@

         app_state = {
             "input": {"project_id": 43192},
-            "
+            "train_val_split": {
                 "method": "random",
                 "split": "train",
                 "percent": 90

@@ -415,13 +416,18 @@
                     "enable": True,
                     "speed_test": True
                 },
-                "cache_project": True
+                "cache_project": True,
+                "export": {
+                    "enable": True,
+                    "ONNXRuntime": True,
+                    "TensorRT": True
+                },
             }
         }
         """
         app_state = self.validate_app_state(app_state)

-        options = app_state
+        options = app_state.get("options", {})
         input_settings = app_state["input"]
         train_val_splits_settings = app_state["train_val_split"]
         classes_settings = app_state["classes"]
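
For reference, the `options` block of `app_state` that this docstring illustrates would look like the following; every value is a placeholder, and the block is read back via `app_state.get("options", {})` as shown above:

```python
# Illustrative "options" block for app_state (placeholder values).
options = {
    "model_benchmark": {
        "enable": True,       # run evaluation on the best checkpoint
        "speed_test": True,   # also run the inference speed test
    },
    "cache_project": True,    # cache the project locally before training
    "export": {
        "enable": True,
        "ONNXRuntime": True,  # export the best checkpoint to ONNX
        "TensorRT": False,    # skip the TensorRT engine
    },
}
```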

@@ -444,7 +450,7 @@
         :type options: dict
         """
         # Set Input
-        self.input_selector.set_cache(options
+        self.input_selector.set_cache(options.get("cache_project", True))
         self.input_selector_cb()
         # ----------------------------------------- #


@@ -527,13 +533,22 @@
         """
         self.hyperparameters_selector.set_hyperparameters(hyperparameters_settings)

-        model_benchmark_settings = options
-
-
-
-
-
-
+        model_benchmark_settings = options.get("model_benchmark", None)
+        if model_benchmark_settings is not None:
+            self.hyperparameters_selector.set_model_benchmark_checkbox_value(
+                model_benchmark_settings["enable"]
+            )
+            self.hyperparameters_selector.set_speedtest_checkbox_value(
+                model_benchmark_settings["speed_test"]
+            )
+        export_weights_settings = options.get("export", None)
+        if export_weights_settings is not None:
+            self.hyperparameters_selector.set_export_onnx_checkbox_value(
+                export_weights_settings.get(RuntimeType.ONNXRUNTIME, False)
+            )
+            self.hyperparameters_selector.set_export_tensorrt_checkbox_value(
+                export_weights_settings.get(RuntimeType.TENSORRT, False)
+            )
         self.hyperparameters_selector_cb()

         # ----------------------------------------- #

supervisely/nn/training/gui/hyperparameters_selector.py:

@@ -9,6 +9,7 @@ from supervisely.app.widgets import (
     Field,
     Text,
 )
+from supervisely.nn.utils import RuntimeType


 class HyperparametersSelector:

@@ -17,55 +18,75 @@ class HyperparametersSelector:
     lock_message = "Select model to unlock"

     def __init__(self, hyperparameters: dict, app_options: dict = {}):
+        self.display_widgets = []
         self.app_options = app_options
+
+        # GUI Components
         self.editor = Editor(
             hyperparameters, height_lines=50, language_mode="yaml", auto_format=True
         )
+        self.display_widgets.extend([self.editor])

-        # Model Benchmark
-        self.run_model_benchmark_checkbox = Checkbox(
-            content="Run Model Benchmark evaluation", checked=True
-        )
-        self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=True)
-
-        self.model_benchmark_field = Field(
-            Container(
-                widgets=[
-                    self.run_model_benchmark_checkbox,
-                    self.run_speedtest_checkbox,
-                ]
-            ),
-            title="Model Evaluation Benchmark",
-            description=f"Generate evalutaion dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
-        )
-        docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
-        self.model_benchmark_learn_more = Text(
-            f"Learn more about Model Benchmark in the {docs_link}.", status="info"
-        )
-
+        # Optional Model Benchmark
         if app_options.get("model_benchmark", True):
-
-            self.
-
-
-            self.
+            # Model Benchmark
+            self.run_model_benchmark_checkbox = Checkbox(
+                content="Run Model Benchmark evaluation", checked=True
+            )
+            self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=True)
+
+            self.model_benchmark_field = Field(
+                title="Model Evaluation Benchmark",
+                description=f"Generate evalutaion dashboard with visualizations and detailed analysis of the model performance after training. The best checkpoint will be used for evaluation. You can also run speed test to evaluate model inference speed.",
+                content=Container([self.run_model_benchmark_checkbox, self.run_speedtest_checkbox]),
+            )
+            docs_link = '<a href="https://docs.supervisely.com/neural-networks/model-evaluation-benchmark/" target="_blank">documentation</a>'
+            self.model_benchmark_learn_more = Text(
+                f"Learn more about Model Benchmark in the {docs_link}.", status="info"
+            )
+            self.display_widgets.extend(
+                [self.model_benchmark_field, self.model_benchmark_learn_more]
+            )
+            # -------------------------------- #
+
+        # Optional Export Weights
+        export_onnx_supported = app_options.get("export_onnx_supported", False)
+        export_tensorrt_supported = app_options.get("export_tensorrt_supported", False)
+
+        onnx_name = "ONNX"
+        tensorrt_name = "TensorRT engine"
+        export_runtimes = []
+        export_runtime_names = []
+        if export_onnx_supported:
+            self.export_onnx_checkbox = Checkbox(content=f"Export to {onnx_name}")
+            export_runtimes.append(self.export_onnx_checkbox)
+            export_runtime_names.append(onnx_name)
+        if export_tensorrt_supported:
+            self.export_tensorrt_checkbox = Checkbox(content=f"Export to {tensorrt_name}")
+            export_runtimes.append(self.export_tensorrt_checkbox)
+            export_runtime_names.append(tensorrt_name)
+        if export_onnx_supported or export_tensorrt_supported:
+            export_field_description = ", ".join(export_runtime_names)
+            runtime_container = Container(export_runtimes)
+            self.export_field = Field(
+                title="Export model",
+                description=f"Export best checkpoint to the following formats: {export_field_description}.",
+                content=runtime_container,
+            )
+            self.display_widgets.extend([self.export_field])
+            # -------------------------------- #

         self.validator_text = Text("")
         self.validator_text.hide()
         self.button = Button("Select")
-
-
-
-
-            self.model_benchmark_learn_more,
-            self.validator_text,
-            self.button,
-        ]
-        )
+        self.display_widgets.extend([self.validator_text, self.button])
+        # -------------------------------- #
+
+        self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,
             description=self.description,
-            content=container,
+            content=self.container,
             lock_message=self.lock_message,
             collapsable=app_options.get("collapsable", False),
         )
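
The export checkboxes only appear when the app declares support for a runtime in `app_options`. An illustrative `app_options` dict for this step; the keys come from the diff, the values are placeholders:

```python
# Illustrative app_options for HyperparametersSelector (placeholder values).
app_options = {
    "model_benchmark": True,            # show the benchmark / speed-test checkboxes
    "export_onnx_supported": True,      # adds the "Export to ONNX" checkbox
    "export_tensorrt_supported": False,
    "collapsable": False,
}
```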

@@ -73,11 +94,14 @@

     @property
     def widgets_to_disable(self) -> list:
-
-
-            self.run_model_benchmark_checkbox,
-
-
+        widgets = [self.editor]
+        if self.app_options.get("model_benchmark", True):
+            widgets.extend([self.run_model_benchmark_checkbox, self.run_speedtest_checkbox])
+        if self.app_options.get("export_onnx_supported", False):
+            widgets.append(self.export_onnx_checkbox)
+        if self.app_options.get("export_tensorrt_supported", False):
+            widgets.append(self.export_tensorrt_checkbox)
+        return widgets

     def set_hyperparameters(self, hyperparameters: Union[str, dict]) -> None:
         self.editor.set_text(hyperparameters)

@@ -113,5 +137,30 @@
         else:
             self.run_speedtest_checkbox.hide()

+    def get_export_onnx_checkbox_value(self) -> bool:
+        if self.app_options.get("export_onnx_supported", False):
+            return self.export_onnx_checkbox.is_checked()
+        return False
+
+    def set_export_onnx_checkbox_value(self, value: bool) -> None:
+        if value:
+            self.export_onnx_checkbox.check()
+        else:
+            self.export_onnx_checkbox.uncheck()
+
+    def get_export_tensorrt_checkbox_value(self) -> bool:
+        if self.app_options.get("export_tensorrt_supported", False):
+            return self.export_tensorrt_checkbox.is_checked()
+        return False
+
+    def set_export_tensorrt_checkbox_value(self, value: bool) -> None:
+        if value:
+            self.export_tensorrt_checkbox.check()
+        else:
+            self.export_tensorrt_checkbox.uncheck()
+
+    def is_export_required(self) -> bool:
+        return self.get_export_onnx_checkbox_value() or self.get_export_tensorrt_checkbox_value()
+
     def validate_step(self) -> bool:
         return True
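
A training app can branch on these getters (or on `is_export_required()`) to decide which conversions to run after training. A hypothetical helper, assuming `RuntimeType.ONNXRUNTIME` and `RuntimeType.TENSORRT` are the same runtime-name constants used in the `app_state` docstring above:

```python
from supervisely.nn.utils import RuntimeType


def collect_export_runtimes(selector) -> list:
    """Hypothetical helper: `selector` is a configured HyperparametersSelector."""
    runtimes = []
    if selector.get_export_onnx_checkbox_value():
        runtimes.append(RuntimeType.ONNXRUNTIME)
    if selector.get_export_tensorrt_checkbox_value():
        runtimes.append(RuntimeType.TENSORRT)
    return runtimes
```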

supervisely/nn/training/gui/input_selector.py:

@@ -17,9 +17,11 @@ class InputSelector:
     lock_message = None

     def __init__(self, project_info: ProjectInfo, app_options: dict = {}):
+        self.display_widgets = []
         self.project_id = project_info.id
         self.project_info = project_info

+        # GUI Components
         self.project_thumbnail = ProjectThumbnail(self.project_info)

         if is_cached(self.project_id):

@@ -32,27 +34,27 @@
         self.validator_text = Text("")
         self.validator_text.hide()
         self.button = Button("Select")
-
-
+        self.display_widgets.extend(
+            [
                 self.project_thumbnail,
                 self.use_cache_checkbox,
                 self.validator_text,
                 self.button,
             ]
         )
+        # -------------------------------- #

+        self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,
             description=self.description,
-            content=container,
+            content=self.container,
             collapsable=app_options.get("collapsable", False),
         )

     @property
     def widgets_to_disable(self) -> list:
-        return [
-            self.use_cache_checkbox,
-        ]
+        return [self.use_cache_checkbox]

     def get_project_id(self) -> int:
         return self.project_id

supervisely/nn/training/gui/model_selector.py:

@@ -21,15 +21,14 @@ class ModelSelector:
     lock_message = "Select classes to unlock"

     def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
+        self.display_widgets = []
         self.team_id = sly_env.team_id()  # get from project id
         self.models = models

-        #
+        # GUI Components
         self.pretrained_models_table = PretrainedModelsSelector(self.models)
-
         experiment_infos = get_experiment_infos(api, self.team_id, framework)
         self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
-        # Model source tabs
         self.model_source_tabs = RadioTabs(
             titles=[ModelSource.PRETRAINED, ModelSource.CUSTOM],
             descriptions=[

@@ -42,11 +41,14 @@
         self.validator_text = Text("")
         self.validator_text.hide()
         self.button = Button("Select")
-
+        self.display_widgets.extend([self.model_source_tabs, self.validator_text, self.button])
+        # -------------------------------- #
+
+        self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,
             description=self.description,
-            content=container,
+            content=self.container,
             lock_message=self.lock_message,
             collapsable=app_options.get("collapsable", False),
         )

supervisely/nn/training/gui/train_val_splits_selector.py:

@@ -10,10 +10,12 @@ class TrainValSplitsSelector:
     lock_message = "Select input options to unlock"

     def __init__(self, api: Api, project_id: int, app_options: dict = {}):
+        self.display_widgets = []
         self.api = api
         self.project_id = project_id
-        self.train_val_splits = TrainValSplits(project_id)

+        # GUI Components
+        self.train_val_splits = TrainValSplits(project_id)
         train_val_dataset_ids = {"train": [], "val": []}
         for _, dataset in api.dataset.tree(project_id):
             if dataset.name.lower() == "train" or dataset.name.lower() == "training":

@@ -39,17 +41,14 @@
         self.validator_text.hide()

         self.button = Button("Select")
-
-
-
-
-            self.button,
-        ]
-        )
+        self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
+        # -------------------------------- #
+
+        self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,
             description=self.description,
-            content=container,
+            content=self.container,
             lock_message=self.lock_message,
             collapsable=app_options.get("collapsable", False),
         )

supervisely/nn/training/gui/training_logs.py:

@@ -12,14 +12,13 @@ class TrainingLogs:
     lock_message = "Start training to unlock"

     def __init__(self, app_options: Dict[str, Any]):
+        self.display_widgets = []
         api = Api.from_env()
         self.app_options = app_options

-
-        self.
-
-        self.progress_bar_secondary = Progress(hide_on_finish=False)
-        self.progress_bar_secondary.hide()
+        # GUI Components
+        self.validator_text = Text("")
+        self.validator_text.hide()

         if is_production():
             task_id = get_task_id(raise_not_found=False)

@@ -43,16 +42,9 @@
         )
         self.tensorboard_button.disable()

-        self.validator_text
-        self.validator_text.hide()
-
-        container_widgets = [
-            self.validator_text,
-            self.tensorboard_button,
-            self.progress_bar_main,
-            self.progress_bar_secondary,
-        ]
+        self.display_widgets.extend([self.validator_text, self.tensorboard_button])

+        # Optional Show logs button
         if app_options.get("show_logs_in_gui", False):
             self.logs_button = Button(
                 text="Show logs",

@@ -63,14 +55,22 @@
             self.task_logs = TaskLogs(task_id)
             self.task_logs.hide()
             logs_container = Container([self.logs_button, self.task_logs])
-
+            self.display_widgets.extend([logs_container])
+            # -------------------------------- #

-
+        # Progress bars
+        self.progress_bar_main = Progress(hide_on_finish=False)
+        self.progress_bar_main.hide()
+        self.progress_bar_secondary = Progress(hide_on_finish=False)
+        self.progress_bar_secondary.hide()
+        self.display_widgets.extend([self.progress_bar_main, self.progress_bar_secondary])
+        # -------------------------------- #

+        self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,
             description=self.description,
-            content=container,
+            content=self.container,
             lock_message=self.lock_message,
         )
         self.card.lock()