supervisely 6.73.358__py3-none-any.whl → 6.73.360__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/app/widgets/project_thumbnail/project_thumbnail.py +3 -2
- supervisely/app/widgets/random_splits_table/random_splits_table.py +13 -2
- supervisely/app/widgets/random_splits_table/template.html +2 -2
- supervisely/app/widgets/select_app_session/select_app_session.py +3 -0
- supervisely/app/widgets/train_val_splits/train_val_splits.py +36 -24
- supervisely/nn/training/gui/gui.py +551 -186
- supervisely/nn/training/gui/input_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +26 -6
- supervisely/nn/training/gui/tags_selector.py +105 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +80 -18
- supervisely/nn/training/train_app.py +139 -43
- {supervisely-6.73.358.dist-info → supervisely-6.73.360.dist-info}/METADATA +80 -59
- {supervisely-6.73.358.dist-info → supervisely-6.73.360.dist-info}/RECORD +17 -16
- {supervisely-6.73.358.dist-info → supervisely-6.73.360.dist-info}/LICENSE +0 -0
- {supervisely-6.73.358.dist-info → supervisely-6.73.360.dist-info}/WHEEL +0 -0
- {supervisely-6.73.358.dist-info → supervisely-6.73.360.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.358.dist-info → supervisely-6.73.360.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@ from supervisely.project.download import is_cached
|
|
|
12
12
|
|
|
13
13
|
class InputSelector:
|
|
14
14
|
title = "Input project"
|
|
15
|
-
description = "Selected project from which
|
|
15
|
+
description = "Selected project from which items and annotations will be downloaded"
|
|
16
16
|
lock_message = None
|
|
17
17
|
|
|
18
18
|
def __init__(self, project_info: ProjectInfo, app_options: dict = {}):
|
|
@@ -36,6 +36,13 @@ class ModelSelector:
|
|
|
36
36
|
self.display_widgets = []
|
|
37
37
|
self.app_options = app_options
|
|
38
38
|
|
|
39
|
+
model_selector_opts = self.app_options.get("model_selector", {})
|
|
40
|
+
if not isinstance(model_selector_opts, dict):
|
|
41
|
+
model_selector_opts = {}
|
|
42
|
+
|
|
43
|
+
self.show_pretrained = True
|
|
44
|
+
self.show_custom = model_selector_opts.get("show_custom", True)
|
|
45
|
+
|
|
39
46
|
self.team_id = sly_env.team_id()
|
|
40
47
|
self.models = models
|
|
41
48
|
|
|
@@ -51,15 +58,28 @@ class ModelSelector:
|
|
|
51
58
|
logger.warning(f"Legacy checkpoints are not available for '{framework}'")
|
|
52
59
|
|
|
53
60
|
self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
|
|
61
|
+
|
|
62
|
+
tab_titles = []
|
|
63
|
+
tab_descriptions = []
|
|
64
|
+
tab_contents = []
|
|
65
|
+
if self.show_pretrained:
|
|
66
|
+
tab_titles.append(ModelSource.PRETRAINED)
|
|
67
|
+
tab_descriptions.append("Publicly available models")
|
|
68
|
+
tab_contents.append(self.pretrained_models_table)
|
|
69
|
+
if self.show_custom:
|
|
70
|
+
tab_titles.append(ModelSource.CUSTOM)
|
|
71
|
+
tab_descriptions.append("Models trained by you in Supervisely")
|
|
72
|
+
tab_contents.append(self.experiment_selector)
|
|
73
|
+
|
|
54
74
|
self.model_source_tabs = RadioTabs(
|
|
55
|
-
titles=
|
|
56
|
-
descriptions=
|
|
57
|
-
|
|
58
|
-
"Models trained by you in Supervisely",
|
|
59
|
-
],
|
|
60
|
-
contents=[self.pretrained_models_table, self.experiment_selector],
|
|
75
|
+
titles=tab_titles,
|
|
76
|
+
descriptions=tab_descriptions,
|
|
77
|
+
contents=tab_contents,
|
|
61
78
|
)
|
|
62
79
|
|
|
80
|
+
if len(tab_titles) > 0:
|
|
81
|
+
self.model_source_tabs.set_active_tab(tab_titles[0])
|
|
82
|
+
|
|
63
83
|
self.validator_text = Text("")
|
|
64
84
|
self.validator_text.hide()
|
|
65
85
|
self.button = Button("Select")
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from supervisely._utils import abs_url, is_debug_with_sly_net, is_development
|
|
2
|
+
from supervisely.app.widgets import Button, Card, Container, TagsTable, Text
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TagsSelector:
    """Training-GUI step that lets the user choose which project tags are used for training.

    The step shows a link to the project's "QA & Stats" page, a ``TagsTable``
    for the project, a validation message and a "Select" button. All of these
    are wrapped in a ``Card`` that is locked when the step is created.
    """

    title = "Tags Selector"
    description = "Select tags that will be used for training"
    lock_message = "Select training and validation splits to unlock"

    def __init__(self, project_id: int, tags: list, app_options: dict = None):
        """Build the widgets of the tags-selection step.

        :param project_id: ID of the project whose tags are listed. It is passed
            to ``TagsTable`` and used to build the "QA & Stats" link.
        :param tags: Tag names to preselect. If empty or ``None``, all tags
            are selected.
        :param app_options: Optional app configuration. Only the ``"collapsable"``
            key is read here, to make the card collapsable. ``None`` is treated
            as an empty dict.
        """
        # Init widgets
        self.qa_stats_text = None
        self.tags_table = None
        self.validator_text = None
        self.button = None
        self.container = None
        self.card = None
        # -------------------------------- #

        self.display_widgets = []
        # Avoid the shared-mutable-default pitfall: normalize to a fresh dict.
        self.app_options = app_options if app_options is not None else {}

        # GUI Components
        # In development / debug-with-sly-net mode the link is built with
        # abs_url; otherwise a site-relative path is used.
        if is_development() or is_debug_with_sly_net():
            qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
        else:
            qa_stats_link = f"/projects/{project_id}/stats/datasets"
        self.qa_stats_text = Text(
            text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
        )

        self.tags_table = TagsTable(project_id=project_id)
        # Preselect the requested tags; fall back to selecting everything
        # when nothing (or None) was requested.
        if tags:
            self.tags_table.select_tags(tags)
        else:
            self.tags_table.select_all()

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        self.display_widgets.extend(
            [self.qa_stats_text, self.tags_table, self.validator_text, self.button]
        )
        # -------------------------------- #

        self.container = Container(self.display_widgets)
        self.card = Card(
            title=self.title,
            description=self.description,
            content=self.container,
            lock_message=self.lock_message,
            collapsable=self.app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets that should be disabled while this step is locked/finished."""
        return [self.tags_table]

    def get_selected_tags(self) -> list:
        """Return the tags currently selected in the tags table."""
        return self.tags_table.get_selected_tags()

    def set_tags(self, tags) -> None:
        """Select the given tags in the tags table."""
        self.tags_table.select_tags(tags)

    def select_all_tags(self) -> None:
        """Select every tag in the tags table."""
        self.tags_table.select_all()

    @staticmethod
    def _find_empty_tags(table_data, selected_tags) -> list:
        """Return names of selected tags whose table row has zero in columns 2 and 3.

        Each row is expected to be a sequence of cells shaped like ``{"data": value}``.
        Cell 0 holds the tag name. Cells 2 and 3 presumably hold the tag's
        usage counts (NOTE(review): confirm the column layout against ``TagsTable``).
        """
        selected = set(selected_tags)
        return [
            row[0]["data"]
            for row in table_data
            if row[0]["data"] in selected and row[2]["data"] == 0 and row[3]["data"] == 0
        ]

    @staticmethod
    def _build_validation_message(selected_tags, empty_tags) -> tuple:
        """Build the ``(message, status)`` pair shown after validating the selection.

        The status is ``"error"`` when nothing is selected, ``"warning"`` when some
        selected tags are in *empty_tags*, and ``"success"`` otherwise.
        """
        n_tags = len(selected_tags)
        if n_tags == 0:
            return "Please select at least one tag", "error"

        tag_text = "tag" if n_tags == 1 else "tags"
        message = f"Selected {n_tags} {tag_text}"

        # Keep only selected tags, de-duplicated but in a stable (first-seen)
        # order, so the message is deterministic.
        selected = set(selected_tags)
        empty_selected = list(dict.fromkeys(t for t in empty_tags if t in selected))
        if not empty_selected:
            return message, "success"

        if len(empty_selected) == 1:
            message += f". Selected tag has no annotations: {empty_selected[0]}"
        else:
            message += f". Selected tags have no annotations: {', '.join(empty_selected)}"
        return message, "warning"

    def validate_step(self) -> bool:
        """Validate the tag selection and display the result in ``validator_text``.

        :returns: ``False`` if the project has no tag metas or no tag is selected.
            Otherwise returns ``True``, even when a warning about tags without
            annotations is shown.
        """
        self.validator_text.hide()

        project_tags = self.tags_table.project_meta.tag_metas
        if len(project_tags) == 0:
            self.validator_text.set(text="Project has no tags", status="error")
            self.validator_text.show()
            return False

        selected_tags = self.tags_table.get_selected_tags()
        empty_tags = self._find_empty_tags(self.tags_table._table_data, selected_tags)
        message, status = self._build_validation_message(selected_tags, empty_tags)

        self.validator_text.set(text=message, status=status)
        self.validator_text.show()
        return len(selected_tags) > 0
|
|
@@ -24,26 +24,51 @@ class TrainValSplitsSelector:
|
|
|
24
24
|
self.project_id = project_id
|
|
25
25
|
|
|
26
26
|
# GUI Components
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
27
|
+
split_methods = self.app_options.get("train_val_split_methods", [])
|
|
28
|
+
if len(split_methods) == 0:
|
|
29
|
+
split_methods = ["Random", "Based on tags", "Based on datasets"]
|
|
30
|
+
random_split = "Random" in split_methods
|
|
31
|
+
tag_split = "Based on tags" in split_methods
|
|
32
|
+
ds_split = "Based on datasets" in split_methods
|
|
33
|
+
|
|
34
|
+
self.train_val_splits = TrainValSplits(project_id, None, random_split, tag_split, ds_split)
|
|
35
|
+
|
|
36
|
+
def _extend_with_nested(root_ds):
|
|
37
|
+
nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
|
|
38
|
+
nested_ids = [ds.id for ds in nested]
|
|
39
|
+
return [root_ds.id] + nested_ids
|
|
40
|
+
|
|
41
|
+
train_val_dataset_ids = {"train": set(), "val": set()}
|
|
42
|
+
for _, dataset in self.api.dataset.tree(self.project_id):
|
|
43
|
+
ds_name = dataset.name.lower()
|
|
44
|
+
|
|
45
|
+
if ds_name in {"train", "training"}:
|
|
46
|
+
for _id in _extend_with_nested(dataset):
|
|
47
|
+
train_val_dataset_ids["train"].add(_id)
|
|
48
|
+
|
|
49
|
+
elif ds_name in {"val", "validation", "test", "testing"}:
|
|
50
|
+
for _id in _extend_with_nested(dataset):
|
|
51
|
+
train_val_dataset_ids["val"].add(_id)
|
|
52
|
+
|
|
53
|
+
train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
|
|
54
|
+
train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
|
|
55
|
+
|
|
38
56
|
train_count = len(train_val_dataset_ids["train"])
|
|
39
57
|
val_count = len(train_val_dataset_ids["val"])
|
|
58
|
+
|
|
40
59
|
if train_count > 0 and val_count > 0:
|
|
41
60
|
self.train_val_splits.set_datasets_splits(
|
|
42
61
|
train_val_dataset_ids["train"], train_val_dataset_ids["val"]
|
|
43
62
|
)
|
|
44
63
|
|
|
45
64
|
if train_count > 0 and val_count > 0:
|
|
46
|
-
|
|
65
|
+
if train_count == val_count == 1:
|
|
66
|
+
self.validator_text = Text("train and val datasets are detected", status="info")
|
|
67
|
+
else:
|
|
68
|
+
self.validator_text = Text(
|
|
69
|
+
"Multiple train and val datasets are detected. Check manually if selection is correct",
|
|
70
|
+
status="info",
|
|
71
|
+
)
|
|
47
72
|
self.validator_text.show()
|
|
48
73
|
else:
|
|
49
74
|
self.validator_text = Text("")
|
|
@@ -71,8 +96,9 @@ class TrainValSplitsSelector:
|
|
|
71
96
|
split_method = self.train_val_splits.get_split_method()
|
|
72
97
|
warning_text = "Using the same data for training and validation leads to overfitting, poor generalization and biased model selection."
|
|
73
98
|
ensure_text = "Ensure this is intentional."
|
|
99
|
+
is_valid = False
|
|
74
100
|
|
|
75
|
-
|
|
101
|
+
def validate_random_split():
|
|
76
102
|
train_ratio = self.train_val_splits.get_train_split_percent()
|
|
77
103
|
val_ratio = self.train_val_splits.get_val_split_percent()
|
|
78
104
|
|
|
@@ -109,8 +135,9 @@ class TrainValSplitsSelector:
|
|
|
109
135
|
text="Train and validation splits are selected.",
|
|
110
136
|
status="success",
|
|
111
137
|
)
|
|
138
|
+
return True
|
|
112
139
|
|
|
113
|
-
|
|
140
|
+
def validate_based_on_tags():
|
|
114
141
|
train_tag = self.train_val_splits.get_train_tag()
|
|
115
142
|
val_tag = self.train_val_splits.get_val_tag()
|
|
116
143
|
|
|
@@ -130,29 +157,42 @@ class TrainValSplitsSelector:
|
|
|
130
157
|
else:
|
|
131
158
|
tags_count[tag_name] = tag_total
|
|
132
159
|
|
|
133
|
-
# @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
|
|
134
160
|
if tags_count[train_tag] == 0:
|
|
135
161
|
self.validator_text.set(
|
|
136
162
|
text=f"Train tag '{train_tag}' is not present in any images. {ensure_text}",
|
|
137
163
|
status="error",
|
|
138
164
|
)
|
|
165
|
+
return False
|
|
139
166
|
elif tags_count[val_tag] == 0:
|
|
140
167
|
self.validator_text.set(
|
|
141
168
|
text=f"Val tag '{val_tag}' is not present in any images. {ensure_text}",
|
|
142
169
|
status="error",
|
|
143
170
|
)
|
|
144
|
-
|
|
171
|
+
return False
|
|
145
172
|
elif train_tag == val_tag:
|
|
146
173
|
self.validator_text.set(
|
|
147
174
|
text=f"Train and val tags are the same. {ensure_text} {warning_text}",
|
|
148
175
|
status="warning",
|
|
149
176
|
)
|
|
177
|
+
return True
|
|
150
178
|
else:
|
|
151
179
|
self.validator_text.set("Train and val tags are selected", status="success")
|
|
180
|
+
return True
|
|
152
181
|
|
|
153
|
-
|
|
182
|
+
def validate_based_on_datasets():
|
|
154
183
|
train_dataset_id = self.get_train_dataset_ids()
|
|
155
184
|
val_dataset_id = self.get_val_dataset_ids()
|
|
185
|
+
if train_dataset_id is None and val_dataset_id is None:
|
|
186
|
+
self.validator_text.set("No datasets are selected", status="error")
|
|
187
|
+
return False
|
|
188
|
+
|
|
189
|
+
if train_dataset_id is None:
|
|
190
|
+
self.validator_text.set("No train dataset is selected", status="error")
|
|
191
|
+
return False
|
|
192
|
+
|
|
193
|
+
if val_dataset_id is None:
|
|
194
|
+
self.validator_text.set("No val dataset is selected", status="error")
|
|
195
|
+
return False
|
|
156
196
|
|
|
157
197
|
# Check if datasets are not empty
|
|
158
198
|
stats = self.api.project.get_stats(self.project_id)
|
|
@@ -169,6 +209,14 @@ class TrainValSplitsSelector:
|
|
|
169
209
|
empty_dataset_names.append(datasets_count[dataset_id]["name"])
|
|
170
210
|
|
|
171
211
|
if len(empty_dataset_names) > 0:
|
|
212
|
+
if len(empty_dataset_names) == len(train_dataset_id + val_dataset_id):
|
|
213
|
+
empty_ds_text = f"All selected datasets are empty. {ensure_text}"
|
|
214
|
+
self.validator_text.set(
|
|
215
|
+
text=empty_ds_text,
|
|
216
|
+
status="error",
|
|
217
|
+
)
|
|
218
|
+
return False
|
|
219
|
+
|
|
172
220
|
if len(empty_dataset_names) == 1:
|
|
173
221
|
empty_ds_text = f"Selected dataset: {', '.join(empty_dataset_names)} is empty. {ensure_text}"
|
|
174
222
|
else:
|
|
@@ -178,16 +226,30 @@ class TrainValSplitsSelector:
|
|
|
178
226
|
text=empty_ds_text,
|
|
179
227
|
status="error",
|
|
180
228
|
)
|
|
229
|
+
return True
|
|
181
230
|
|
|
182
231
|
elif train_dataset_id == val_dataset_id:
|
|
183
232
|
self.validator_text.set(
|
|
184
233
|
text=f"Same datasets are selected for both train and val splits. {ensure_text} {warning_text}",
|
|
185
234
|
status="warning",
|
|
186
235
|
)
|
|
236
|
+
return True
|
|
187
237
|
else:
|
|
188
238
|
self.validator_text.set("Train and val datasets are selected", status="success")
|
|
239
|
+
return True
|
|
240
|
+
|
|
241
|
+
if split_method == "Random":
|
|
242
|
+
is_valid = validate_random_split()
|
|
243
|
+
|
|
244
|
+
elif split_method == "Based on tags":
|
|
245
|
+
is_valid = validate_based_on_tags()
|
|
246
|
+
|
|
247
|
+
elif split_method == "Based on datasets":
|
|
248
|
+
is_valid = validate_based_on_datasets()
|
|
249
|
+
|
|
250
|
+
# @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
|
|
189
251
|
self.validator_text.show()
|
|
190
|
-
return
|
|
252
|
+
return is_valid
|
|
191
253
|
|
|
192
254
|
def set_sly_project(self, project: Project) -> None:
|
|
193
255
|
self.train_val_splits._project_fs = project
|