supervisely 6.73.242__py3-none-any.whl → 6.73.244__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +18 -0
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/card/card.py +3 -0
- supervisely/app/widgets/classes_table/classes_table.py +15 -1
- supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
- supervisely/app/widgets/custom_models_selector/template.html +1 -1
- supervisely/app/widgets/experiment_selector/__init__.py +0 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
- supervisely/app/widgets/experiment_selector/style.css +27 -0
- supervisely/app/widgets/experiment_selector/template.html +82 -0
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
- supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
- supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/__init__.py +3 -1
- supervisely/nn/artifacts/artifacts.py +10 -0
- supervisely/nn/artifacts/detectron2.py +2 -0
- supervisely/nn/artifacts/hrda.py +3 -0
- supervisely/nn/artifacts/mmclassification.py +2 -0
- supervisely/nn/artifacts/mmdetection.py +6 -3
- supervisely/nn/artifacts/mmsegmentation.py +2 -0
- supervisely/nn/artifacts/ritm.py +3 -1
- supervisely/nn/artifacts/rtdetr.py +2 -0
- supervisely/nn/artifacts/unet.py +2 -0
- supervisely/nn/artifacts/yolov5.py +3 -0
- supervisely/nn/artifacts/yolov8.py +7 -1
- supervisely/nn/experiments.py +113 -0
- supervisely/nn/inference/gui/__init__.py +3 -1
- supervisely/nn/inference/gui/gui.py +31 -232
- supervisely/nn/inference/gui/serving_gui.py +223 -0
- supervisely/nn/inference/gui/serving_gui_template.py +240 -0
- supervisely/nn/inference/inference.py +225 -24
- supervisely/nn/training/__init__.py +0 -0
- supervisely/nn/training/gui/__init__.py +1 -0
- supervisely/nn/training/gui/classes_selector.py +100 -0
- supervisely/nn/training/gui/gui.py +539 -0
- supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
- supervisely/nn/training/gui/input_selector.py +70 -0
- supervisely/nn/training/gui/model_selector.py +95 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
- supervisely/nn/training/gui/training_logs.py +93 -0
- supervisely/nn/training/gui/training_process.py +114 -0
- supervisely/nn/training/gui/utils.py +128 -0
- supervisely/nn/training/loggers/__init__.py +0 -0
- supervisely/nn/training/loggers/base_train_logger.py +58 -0
- supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
- supervisely/nn/training/train_app.py +2038 -0
- supervisely/nn/utils.py +5 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/RECORD +57 -35
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from supervisely.api.project_api import ProjectInfo
|
|
2
|
+
from supervisely.app.widgets import (
|
|
3
|
+
Button,
|
|
4
|
+
Card,
|
|
5
|
+
Checkbox,
|
|
6
|
+
Container,
|
|
7
|
+
Field,
|
|
8
|
+
ProjectThumbnail,
|
|
9
|
+
Text,
|
|
10
|
+
)
|
|
11
|
+
from supervisely.project.download import is_cached
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class InputSelector:
    """GUI step of the training app that displays the input project.

    The step shows a thumbnail of the project, a checkbox controlling the use of
    data cached on the agent (its label depends on whether the project is already
    cached), a hidden validator text and a "Select" button, all wrapped in a card.
    """

    title = "Input project"
    description = "Selected project from which images and annotations will be downloaded"
    # The card of this step is not locked on creation, so no lock message is set.
    lock_message = None

    def __init__(self, project_info: ProjectInfo, app_options: dict = None):
        """Build the widgets of the step.

        :param project_info: Info of the project used as the training input.
        :type project_info: ProjectInfo
        :param app_options: Optional app options; only the ``"collapsable"`` key is read.
            ``None`` (the default) is treated as an empty dict.
        :type app_options: dict, optional
        """
        # ``None`` default instead of ``{}`` avoids a mutable default shared across calls.
        app_options = app_options or {}

        self.project_id = project_info.id
        self.project_info = project_info

        self.project_thumbnail = ProjectThumbnail(self.project_info)

        # The checkbox label depends on whether the project is already cached on the agent.
        if is_cached(self.project_id):
            _text = "Use cached data stored on the agent to optimize project download"
        else:
            _text = "Cache data on the agent to optimize project download for future trainings"
        self.use_cache_text = Text(_text)
        self.use_cache_checkbox = Checkbox(self.use_cache_text, checked=True)

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        container = Container(
            widgets=[
                self.project_thumbnail,
                self.use_cache_checkbox,
                self.validator_text,
                self.button,
            ]
        )

        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            collapsable=app_options.get("collapsable", False),
        )

    @property
    def widgets_to_disable(self) -> list:
        """Widgets of this step exposed for disabling by the caller (the cache checkbox)."""
        return [self.use_cache_checkbox]

    def get_project_id(self) -> int:
        """Return the id of the input project."""
        return self.project_id

    def set_cache(self, value: bool) -> None:
        """Check (``True``) or uncheck (``False``) the "use cache" checkbox."""
        if value:
            self.use_cache_checkbox.check()
        else:
            self.use_cache_checkbox.uncheck()

    def get_cache_value(self) -> bool:
        """Return whether the "use cache" checkbox is checked."""
        return self.use_cache_checkbox.is_checked()

    def validate_step(self) -> bool:
        """Validate the step; there is nothing to check here, so always return ``True``."""
        return True
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import supervisely.io.env as sly_env
|
|
4
|
+
from supervisely.api.api import Api
|
|
5
|
+
from supervisely.app.widgets import (
|
|
6
|
+
Button,
|
|
7
|
+
Card,
|
|
8
|
+
Container,
|
|
9
|
+
ExperimentSelector,
|
|
10
|
+
PretrainedModelsSelector,
|
|
11
|
+
RadioTabs,
|
|
12
|
+
Text,
|
|
13
|
+
)
|
|
14
|
+
from supervisely.nn.experiments import get_experiment_infos
|
|
15
|
+
from supervisely.nn.utils import ModelSource
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelSelector:
    """GUI step of the training app for choosing the model to train.

    Two radio tabs are offered: pretrained models (built from the ``models`` list)
    and custom models (experiment infos returned by ``get_experiment_infos`` for the
    current team and framework). The card is locked on creation.
    """

    title = "Select Model"
    description = "Select a model for training"
    lock_message = "Select classes to unlock"

    def __init__(self, api: Api, framework: str, models: list, app_options: dict = None):
        """Build the widgets of the step.

        :param api: Supervisely API instance used to look up experiments.
        :type api: Api
        :param framework: Framework name passed to ``get_experiment_infos``.
        :type framework: str
        :param models: Pretrained model descriptions passed to ``PretrainedModelsSelector``.
        :type models: list
        :param app_options: Optional app options; only the ``"collapsable"`` key is read.
            ``None`` (the default) is treated as an empty dict.
        :type app_options: dict, optional
        """
        # ``None`` default instead of ``{}`` avoids a mutable default shared across calls.
        app_options = app_options or {}

        # TODO: the team id is read from the environment; consider deriving it from the project id.
        self.team_id = sly_env.team_id()
        self.models = models

        # Pretrained models
        self.pretrained_models_table = PretrainedModelsSelector(self.models)

        # Custom models
        experiment_infos = get_experiment_infos(api, self.team_id, framework)
        self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)

        # Model source tabs
        self.model_source_tabs = RadioTabs(
            titles=[ModelSource.PRETRAINED, ModelSource.CUSTOM],
            descriptions=[
                "Publicly available models",
                "Models trained by you in Supervisely",
            ],
            contents=[self.pretrained_models_table, self.experiment_selector],
        )

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        container = Container([self.model_source_tabs, self.validator_text, self.button])
        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
            collapsable=app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets of this step exposed for disabling by the caller."""
        return [
            self.model_source_tabs,
            self.pretrained_models_table,
            self.experiment_selector,
        ]

    def get_model_source(self) -> str:
        """Return the title of the active source tab (pretrained or custom)."""
        return self.model_source_tabs.get_active_tab()

    def set_model_source(self, source: Literal["Pretrained models", "Custom models"]) -> None:
        """Activate the source tab with the given title."""
        self.model_source_tabs.set_active_tab(source)

    def get_model_name(self) -> str:
        """Return the name of the selected model, or ``None`` if nothing is selected.

        For pretrained models the name is read from ``row["meta"]["model_name"]``;
        for custom models from ``experiment_info["model_name"]``.
        """
        if self.get_model_source() == ModelSource.PRETRAINED:
            selected_row = self.pretrained_models_table.get_selected_row()
            if not selected_row:
                return None
            # ``or {}`` also covers a row whose "meta" value is explicitly None.
            model_meta = selected_row.get("meta") or {}
            return model_meta.get("model_name", None)

        selected_info = self.experiment_selector.get_selected_experiment_info()
        if not selected_info:
            return None
        return selected_info.get("model_name", None)

    def get_model_info(self) -> dict:
        """Return the selected pretrained-model row or custom experiment info."""
        if self.get_model_source() == ModelSource.PRETRAINED:
            return self.pretrained_models_table.get_selected_row()
        else:
            return self.experiment_selector.get_selected_experiment_info()

    def validate_step(self) -> bool:
        """Check that a model is selected and show the result in ``validator_text``.

        :return: ``False`` if no model is selected (``None`` or empty info), else ``True``.
        :rtype: bool
        """
        self.validator_text.hide()
        model_info = self.get_model_info()
        if not model_info:
            self.validator_text.set(text="Model is not selected", status="error")
            self.validator_text.show()
            return False
        else:
            self.validator_text.set(text="Model is selected", status="success")
            self.validator_text.show()
            return True
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from supervisely import Api, Project
|
|
4
|
+
from supervisely.app.widgets import Button, Card, Container, Text, TrainValSplits
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TrainValSplitsSelector:
    """GUI step of the training app for configuring train / val splits.

    Wraps a ``TrainValSplits`` widget for the project. On creation, non-empty datasets
    named ``train``/``training`` and ``val``/``validation`` (case-insensitive) are
    detected. When both kinds are found, they are passed to
    ``TrainValSplits.set_datasets_splits`` and an info message is shown.
    """

    title = "Train / Val Splits"
    description = "Select train and val splits for training"
    lock_message = "Select input options to unlock"

    # Dataset names (lower-cased) auto-detected as train / val splits.
    _TRAIN_DATASET_NAMES = ("train", "training")
    _VAL_DATASET_NAMES = ("val", "validation")

    # Shared parts of validation messages.
    _WARNING_TEXT = "Using the same data for training and validation leads to overfitting, poor generalization and biased model selection."
    _ENSURE_TEXT = "Ensure this is intentional."

    def __init__(self, api: Api, project_id: int, app_options: dict = None):
        """Build the widgets of the step and pre-fill dataset-based splits if possible.

        :param api: Supervisely API instance.
        :type api: Api
        :param project_id: Id of the project to split.
        :type project_id: int
        :param app_options: Optional app options; only the ``"collapsable"`` key is read.
            ``None`` (the default) is treated as an empty dict.
        :type app_options: dict, optional
        """
        # ``None`` default instead of ``{}`` avoids a mutable default shared across calls.
        app_options = app_options or {}
        self.api = api
        self.project_id = project_id
        self.train_val_splits = TrainValSplits(project_id)

        # Collect non-empty datasets whose names look like train / val splits.
        # ``api.dataset.tree`` presumably also yields nested datasets.
        train_ids, val_ids = [], []
        for _, dataset in api.dataset.tree(project_id):
            name = dataset.name.lower()
            if name in self._TRAIN_DATASET_NAMES:
                if dataset.items_count > 0:
                    train_ids.append(dataset.id)
            elif name in self._VAL_DATASET_NAMES:
                if dataset.items_count > 0:
                    val_ids.append(dataset.id)

        if train_ids and val_ids:
            self.train_val_splits.set_datasets_splits(train_ids, val_ids)
            self.validator_text = Text("Train and val datasets are detected", status="info")
            self.validator_text.show()
        else:
            self.validator_text = Text("")
            self.validator_text.hide()

        self.button = Button("Select")
        container = Container(
            [
                self.train_val_splits,
                self.validator_text,
                self.button,
            ]
        )
        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
            collapsable=app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets of this step exposed for disabling by the caller."""
        return [self.train_val_splits]

    def validate_step(self) -> bool:
        """Check the selected split configuration and display the verdict.

        Warnings and errors are only displayed in ``validator_text``; the method
        currently always returns ``True``.

        :rtype: bool
        """
        split_method = self.train_val_splits.get_split_method()
        if split_method == "Random":
            self._validate_random_split()
        elif split_method == "Based on tags":
            self._validate_tags_split()
        elif split_method == "Based on datasets":
            self._validate_datasets_split()
        self.validator_text.show()
        # TODO: handle the button correctly if validation fails, i.e. return False on errors
        # and do not unlock the next card until validation passes.
        return True

    def _validate_random_split(self) -> None:
        """Set a warning for unusual train/val percentages of a random split."""
        train_ratio = self.train_val_splits.get_train_split_percent()
        val_ratio = self.train_val_splits.get_val_split_percent()
        ensure_text_random_split = (
            "Consider reallocating to ensure efficient learning and validation."
        )

        if train_ratio == val_ratio:
            self.validator_text.set(
                text="Train and validation splits are equal (50:50). This is inefficient for standard training. "
                f"{self._ENSURE_TEXT}",
                status="warning",
            )
        elif train_ratio > 90:
            self.validator_text.set(
                text="Training split exceeds 90%. This may leave insufficient data for validation. Ensure you have enough data for validation.",
                status="warning",
            )
        elif val_ratio > train_ratio:
            self.validator_text.set(
                text=f"Validation split is larger than the training split. {ensure_text_random_split}",
                status="warning",
            )
        elif train_ratio < 70:
            self.validator_text.set(
                text="Training split is below 70%. This may limit the model's learning capability. "
                f"{ensure_text_random_split}",
                status="warning",
            )
        else:
            self.validator_text.set(
                text="Train and validation splits are selected.",
                status="success",
            )

    def _validate_tags_split(self) -> None:
        """Check that the train/val tags occur in the project and differ from each other."""
        train_tag = self.train_val_splits.get_train_tag()
        val_tag = self.train_val_splits.get_val_tag()

        tags_count = self._count_tags(self.api.project.get_stats(self.project_id))

        # A tag missing from the stats is treated as occurring zero times
        # (indexing would raise KeyError).
        if tags_count.get(train_tag, 0) == 0:
            self.validator_text.set(
                text=f"Train tag '{train_tag}' is not present in any images. {self._ENSURE_TEXT}",
                status="error",
            )
        elif tags_count.get(val_tag, 0) == 0:
            self.validator_text.set(
                text=f"Val tag '{val_tag}' is not present in any images. {self._ENSURE_TEXT}",
                status="error",
            )
        elif train_tag == val_tag:
            self.validator_text.set(
                text=f"Train and val tags are the same. {self._ENSURE_TEXT} {self._WARNING_TEXT}",
                status="warning",
            )
        else:
            self.validator_text.set("Train and val tags are selected", status="success")

    @staticmethod
    def _count_tags(stats: dict) -> dict:
        """Sum image-level and object-level occurrences of each tag name in project stats."""
        tags_count = {}
        for section in ("imageTags", "objectTags"):
            for item in stats[section]["items"]:
                tag_name = item["tagMeta"]["name"]
                tags_count[tag_name] = tags_count.get(tag_name, 0) + item["total"]
        return tags_count

    def _validate_datasets_split(self) -> None:
        """Check that the selected datasets are not empty and differ between splits."""
        train_dataset_ids = self.get_train_dataset_ids()
        val_dataset_ids = self.get_val_dataset_ids()

        stats = self.api.project.get_stats(self.project_id)
        datasets_count = {
            dataset["id"]: {"name": dataset["name"], "total": dataset["imagesInDataset"]}
            for dataset in stats["images"]["datasets"]
        }

        # ``dict.fromkeys`` drops ids selected for both splits (keeping order), so an empty
        # dataset used twice is reported once. Ids absent from the stats are skipped
        # instead of raising KeyError.
        empty_dataset_names = []
        for dataset_id in dict.fromkeys(train_dataset_ids + val_dataset_ids):
            dataset_stats = datasets_count.get(dataset_id)
            if dataset_stats is not None and dataset_stats["total"] == 0:
                empty_dataset_names.append(dataset_stats["name"])

        if empty_dataset_names:
            names = ", ".join(empty_dataset_names)
            if len(empty_dataset_names) == 1:
                empty_ds_text = f"Selected dataset: {names} is empty. {self._ENSURE_TEXT}"
            else:
                empty_ds_text = f"Selected datasets: {names} are empty. {self._ENSURE_TEXT}"
            self.validator_text.set(text=empty_ds_text, status="error")
        # Compare as sets: the same datasets selected in a different order are still the same.
        elif set(train_dataset_ids) == set(val_dataset_ids):
            self.validator_text.set(
                text=f"Same datasets are selected for both train and val splits. {self._ENSURE_TEXT} {self._WARNING_TEXT}",
                status="warning",
            )
        else:
            self.validator_text.set("Train and val datasets are selected", status="success")

    def set_sly_project(self, project: Project) -> None:
        """Attach a local Supervisely project to the underlying splits widget."""
        self.train_val_splits._project_fs = project

    def get_split_method(self) -> str:
        """Return the currently selected split method."""
        return self.train_val_splits.get_split_method()

    def get_train_dataset_ids(self) -> List[int]:
        """Return ids of datasets selected for training."""
        return self.train_val_splits._train_ds_select.get_selected_ids()

    def set_train_dataset_ids(self, dataset_ids: List[int]) -> None:
        """Select the given datasets for training."""
        self.train_val_splits._train_ds_select.set_selected_ids(dataset_ids)

    def get_val_dataset_ids(self) -> List[int]:
        """Return ids of datasets selected for validation."""
        return self.train_val_splits._val_ds_select.get_selected_ids()

    def set_val_dataset_ids(self, dataset_ids: List[int]) -> None:
        """Select the given datasets for validation."""
        self.train_val_splits._val_ds_select.set_selected_ids(dataset_ids)
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
from supervisely import Api
|
|
4
|
+
from supervisely._utils import is_production
|
|
5
|
+
from supervisely.app.widgets import Button, Card, Container, Progress, TaskLogs, Text
|
|
6
|
+
from supervisely.io.env import task_id as get_task_id
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TrainingLogs:
    """GUI step of the training app that tracks training progress.

    Contains a hidden validator text, an "Open Tensorboard" link button (disabled on
    creation), two hidden progress bars and, when ``app_options["show_logs_in_gui"]``
    is true, a task-logs panel with a show/hide button. The card is locked on creation.
    """

    title = "Training Logs"
    description = "Track training progress"
    lock_message = "Start training to unlock"

    # Tensorboard link used outside production or when no task info is available.
    _LOCAL_TENSORBOARD_LINK = "http://localhost:8000/tensorboard"

    def __init__(self, app_options: Dict[str, Any]):
        """Build the widgets of the step.

        :param app_options: App options; the ``"show_logs_in_gui"`` key is read here.
        :type app_options: Dict[str, Any]
        """
        api = Api.from_env()
        self.app_options = app_options

        self.progress_bar_main = Progress(hide_on_finish=False)
        self.progress_bar_main.hide()

        self.progress_bar_secondary = Progress(hide_on_finish=False)
        self.progress_bar_secondary.hide()

        # The task id is only looked up in production, and even there it may be None.
        task_id = get_task_id(raise_not_found=False) if is_production() else None

        # Tensorboard button
        self.tensorboard_link = self._get_tensorboard_link(api, task_id)
        self.tensorboard_button = Button(
            "Open Tensorboard",
            button_type="info",
            plain=True,
            icon="zmdi zmdi-chart",
            link=self.tensorboard_link,
        )
        self.tensorboard_button.disable()

        self.validator_text = Text("")
        self.validator_text.hide()

        container_widgets = [
            self.validator_text,
            self.tensorboard_button,
            self.progress_bar_main,
            self.progress_bar_secondary,
        ]

        if app_options.get("show_logs_in_gui", False):
            self.logs_button = Button(
                text="Show logs",
                plain=True,
                button_size="mini",
                icon="zmdi zmdi-caret-down-circle",
            )
            self.task_logs = TaskLogs(task_id)
            self.task_logs.hide()
            logs_container = Container([self.logs_button, self.task_logs])
            # Place the logs panel right after the Tensorboard button.
            container_widgets.insert(2, logs_container)

        container = Container(container_widgets)

        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
        )
        self.card.lock()

    @classmethod
    def _get_tensorboard_link(cls, api, task_id) -> str:
        """Return the Tensorboard URL for the given task.

        With a task id, the URL is ``{api.server_address}/net/{sessionToken}/tensorboard/``
        built from the task's session token. Without a task id, or when the task info
        is ``None``, the local development link is returned instead of failing.
        """
        if task_id is None:
            return cls._LOCAL_TENSORBOARD_LINK
        task_info = api.task.get_info_by_id(task_id)
        if task_info is None:
            return cls._LOCAL_TENSORBOARD_LINK
        session_token = task_info["meta"]["sessionToken"]
        sly_url_prefix = f"/net/{session_token}"
        return f"{api.server_address}{sly_url_prefix}/tensorboard/"

    @property
    def widgets_to_disable(self) -> list:
        """Widgets of this step exposed for disabling by the caller (none)."""
        return []

    def validate_step(self) -> bool:
        """Validate the step; there is nothing to check here, so always return ``True``."""
        return True

    def toggle_logs(self) -> None:
        """Show or hide the task-logs panel and update the toggle button.

        Does nothing when the panel was not created
        (``app_options["show_logs_in_gui"]`` was false).
        """
        task_logs = getattr(self, "task_logs", None)
        if task_logs is None:
            return
        if task_logs.is_hidden():
            task_logs.show()
            self.logs_button.text = "Hide logs"
            self.logs_button.icon = "zmdi zmdi-caret-up-circle"
        else:
            task_logs.hide()
            self.logs_button.text = "Show logs"
            self.logs_button.icon = "zmdi zmdi-caret-down-circle"
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
from supervisely import Api
|
|
4
|
+
from supervisely.app.widgets import (
|
|
5
|
+
Button,
|
|
6
|
+
Card,
|
|
7
|
+
Container,
|
|
8
|
+
DoneLabel,
|
|
9
|
+
Empty,
|
|
10
|
+
Field,
|
|
11
|
+
FolderThumbnail,
|
|
12
|
+
Input,
|
|
13
|
+
ReportThumbnail,
|
|
14
|
+
SelectCudaDevice,
|
|
15
|
+
Text,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TrainingProcess:
    """GUI step of the training app that starts training and shows its results.

    Contains the experiment name input, an optional CUDA device selector
    (``app_options["device_selector"]``), start/stop buttons, a hidden validator text
    and hidden result widgets (success message, artifacts folder thumbnail, model
    benchmark report thumbnail and status text). The card is locked on creation.
    """

    title = "Training Process"
    description = "Manage training process"
    lock_message = "Select hyperparameters to unlock"

    def __init__(self, app_options: Dict[str, Any]):
        """Build the widgets of the step.

        :param app_options: App options; the ``"device_selector"`` key is read here.
        :type app_options: Dict[str, Any]
        """
        self.app_options = app_options
        self.experiment_name_input = Input("Enter experiment name")
        self.experiment_name_field = Field(
            title="Experiment name",
            description="Experiment name will be saved to experiment_info.json",
            content=self.experiment_name_input,
        )

        self.success_message_text = (
            "Training completed. Training artifacts were uploaded to Team Files. "
            "You can find and open tensorboard logs in the artifacts folder via the "
            "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
        )
        self.success_message = DoneLabel(text=self.success_message_text)
        self.success_message.hide()

        self.artifacts_thumbnail = FolderThumbnail()
        self.artifacts_thumbnail.hide()

        self.model_benchmark_report_thumbnail = ReportThumbnail()
        self.model_benchmark_report_thumbnail.hide()

        self.model_benchmark_report_text = Text(status="info", text="Creating report on model...")
        self.model_benchmark_report_text.hide()

        self.validator_text = Text("")
        self.validator_text.hide()
        self.start_button = Button("Start")
        self.stop_button = Button("Stop", button_type="danger")
        self.stop_button.hide()  # @TODO: implement stop and hide stop button until training starts

        button_container = Container(
            [self.start_button, self.stop_button, Empty()],
            "horizontal",
            overflow="wrap",
            fractions=[1, 1, 10],
            gap=1,
        )

        container_widgets = [
            self.experiment_name_field,
            button_container,
            self.validator_text,
            self.artifacts_thumbnail,
            self.model_benchmark_report_thumbnail,
            self.model_benchmark_report_text,
        ]

        if self.app_options.get("device_selector", False):
            self.select_device = SelectCudaDevice()
            self.select_cuda_device_field = Field(
                title="Select CUDA device",
                description="The device on which the model will be trained",
                content=self.select_device,
            )
            # Show the device selector right after the experiment name field.
            container_widgets.insert(1, self.select_cuda_device_field)

        container = Container(container_widgets)

        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets of this step exposed for disabling by the caller.

        Always includes the experiment name input; the CUDA device selector is added
        when ``app_options["device_selector"]`` is true (it only exists in that case).
        """
        widgets = [self.experiment_name_input]
        if self.app_options.get("device_selector", False):
            widgets.append(self.select_device)
        return widgets

    def validate_step(self) -> bool:
        """Validate the step; there is nothing to check here, so always return ``True``."""
        return True

    def get_device(self) -> str:
        """Return the selected device, or ``"cuda:0"`` when the device selector is disabled."""
        if self.app_options.get("device_selector", False):
            return self.select_device.get_device()
        else:
            return "cuda:0"

    def get_experiment_name(self) -> str:
        """Return the experiment name entered by the user."""
        return self.experiment_name_input.get_value()

    def set_experiment_name(self, experiment_name: str) -> None:
        """Set the experiment name input value."""
        self.experiment_name_input.set_value(experiment_name)
|