supervisely 6.73.374__py3-none-any.whl → 6.73.376__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/nn/neural_network_api.py +1 -1
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/agent_selector/agent_selector.py +6 -0
- supervisely/app/widgets/agent_selector/template.html +2 -0
- supervisely/app/widgets/button/button.py +28 -1
- supervisely/app/widgets/button/template.html +1 -1
- supervisely/app/widgets/card/card.py +4 -0
- supervisely/app/widgets/card/template.html +1 -1
- supervisely/app/widgets/classes_table/classes_table.py +3 -1
- supervisely/app/widgets/fast_table/fast_table.py +16 -0
- supervisely/app/widgets/fast_table/script.js +6 -2
- supervisely/app/widgets/fast_table/template.html +1 -0
- supervisely/app/widgets/random_splits_table/random_splits_table.py +2 -0
- supervisely/app/widgets/select_collection/__init__.py +0 -0
- supervisely/app/widgets/select_collection/select_collection.py +693 -0
- supervisely/app/widgets/select_collection/template.html +3 -0
- supervisely/app/widgets/train_val_splits/train_val_splits.py +111 -13
- supervisely/nn/training/gui/gui.py +28 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +133 -30
- supervisely/nn/training/train_app.py +44 -16
- supervisely/project/pointcloud_episode_project.py +16 -0
- supervisely/project/pointcloud_project.py +16 -0
- supervisely/project/project.py +57 -0
- supervisely/project/video_project.py +16 -0
- supervisely/project/volume_project.py +16 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.376.dist-info}/METADATA +1 -1
- {supervisely-6.73.374.dist-info → supervisely-6.73.376.dist-info}/RECORD +31 -28
- {supervisely-6.73.374.dist-info → supervisely-6.73.376.dist-info}/LICENSE +0 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.376.dist-info}/WHEEL +0 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.376.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.376.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from collections import defaultdict
|
|
2
3
|
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
3
4
|
|
|
4
5
|
import supervisely as sly
|
|
@@ -17,6 +18,7 @@ from supervisely.app.widgets import (
|
|
|
17
18
|
from supervisely.app.widgets.random_splits_table.random_splits_table import (
|
|
18
19
|
RandomSplitsTable,
|
|
19
20
|
)
|
|
21
|
+
from supervisely.app.widgets.select_collection.select_collection import SelectCollection
|
|
20
22
|
from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import (
|
|
21
23
|
SelectDatasetTree,
|
|
22
24
|
)
|
|
@@ -41,28 +43,27 @@ class TrainValSplits(Widget):
|
|
|
41
43
|
tags_splits: Optional[bool] = True,
|
|
42
44
|
datasets_splits: Optional[bool] = True,
|
|
43
45
|
widget_id: Optional[int] = None,
|
|
46
|
+
collections_splits: Optional[bool] = False,
|
|
44
47
|
):
|
|
45
48
|
self._project_id = project_id
|
|
46
49
|
self._project_fs = project_fs
|
|
47
50
|
|
|
48
|
-
if project_fs is not None and project_id is not None:
|
|
49
|
-
raise ValueError(
|
|
50
|
-
"You can not provide both project_id and project_fs parameters to TrainValSplits widget."
|
|
51
|
-
)
|
|
52
|
-
if project_fs is None and project_id is None:
|
|
53
|
-
raise ValueError(
|
|
54
|
-
"You should provide at least one of: project_id or project_fs parameters to TrainValSplits widget."
|
|
55
|
-
)
|
|
56
|
-
|
|
57
51
|
self._project_info = None
|
|
52
|
+
self._project_type = None
|
|
53
|
+
self._project_class = None
|
|
54
|
+
self._api = None
|
|
58
55
|
if project_id is not None:
|
|
59
56
|
self._api = Api()
|
|
60
57
|
self._project_info = self._api.project.get_info_by_id(
|
|
61
58
|
self._project_id, raise_error=True
|
|
62
59
|
)
|
|
63
60
|
|
|
64
|
-
|
|
65
|
-
|
|
61
|
+
if project_fs is not None:
|
|
62
|
+
self._project_type = project_fs.type
|
|
63
|
+
elif self._project_info is not None:
|
|
64
|
+
self._project_type = self._project_info.type
|
|
65
|
+
if self._project_type is not None:
|
|
66
|
+
self._project_class = get_project_class(self._project_type)
|
|
66
67
|
|
|
67
68
|
self._random_splits_table: RandomSplitsTable = None
|
|
68
69
|
self._train_tag_select: SelectTagMeta = None
|
|
@@ -70,6 +71,8 @@ class TrainValSplits(Widget):
|
|
|
70
71
|
self._untagged_select: SelectString = None
|
|
71
72
|
self._train_ds_select: Union[SelectDatasetTree, SelectString] = None
|
|
72
73
|
self._val_ds_select: Union[SelectDatasetTree, SelectString] = None
|
|
74
|
+
self._train_collections_select: SelectCollection = None
|
|
75
|
+
self._val_collections_select: SelectCollection = None
|
|
73
76
|
self._split_methods = []
|
|
74
77
|
|
|
75
78
|
contents = []
|
|
@@ -80,12 +83,18 @@ class TrainValSplits(Widget):
|
|
|
80
83
|
contents.append(self._get_random_content())
|
|
81
84
|
if tags_splits:
|
|
82
85
|
self._split_methods.append("Based on item tags")
|
|
83
|
-
tabs_descriptions.append(
|
|
86
|
+
tabs_descriptions.append(
|
|
87
|
+
f"{self._project_type.capitalize()} should have assigned train or val tag"
|
|
88
|
+
)
|
|
84
89
|
contents.append(self._get_tags_content())
|
|
85
90
|
if datasets_splits:
|
|
86
91
|
self._split_methods.append("Based on datasets")
|
|
87
92
|
tabs_descriptions.append("Select one or several datasets for every split")
|
|
88
93
|
contents.append(self._get_datasets_content())
|
|
94
|
+
if collections_splits:
|
|
95
|
+
self._split_methods.append("Based on collections")
|
|
96
|
+
tabs_descriptions.append("Select one or several collections for every split")
|
|
97
|
+
contents.append(self._get_collections_content())
|
|
89
98
|
if not self._split_methods:
|
|
90
99
|
raise ValueError(
|
|
91
100
|
"Any of split methods [random_splits, tags_splits, datasets_splits] must be specified in TrainValSplits."
|
|
@@ -216,6 +225,32 @@ class TrainValSplits(Widget):
|
|
|
216
225
|
widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
|
|
217
226
|
)
|
|
218
227
|
|
|
228
|
+
def _get_collections_content(self):
|
|
229
|
+
notification_box = NotificationBox(
|
|
230
|
+
title="Notice: How to make equal splits",
|
|
231
|
+
description="Choose the same collection(s) for train/validation to make splits equal. Can be used for debug and for tiny projects",
|
|
232
|
+
box_type="info",
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
self._train_collections_select = SelectCollection(multiselect=True, compact=True)
|
|
236
|
+
self._val_collections_select = SelectCollection(multiselect=True, compact=True)
|
|
237
|
+
if self._project_id is not None:
|
|
238
|
+
self._train_collections_select.set_project_id(self._project_id)
|
|
239
|
+
self._val_collections_select.set_project_id(self._project_id)
|
|
240
|
+
train_field = Field(
|
|
241
|
+
self._train_collections_select,
|
|
242
|
+
title="Train collection(s)",
|
|
243
|
+
description="all images in selected collection(s) are considered as training set",
|
|
244
|
+
)
|
|
245
|
+
val_field = Field(
|
|
246
|
+
self._val_collections_select,
|
|
247
|
+
title="Validation collection(s)",
|
|
248
|
+
description="all images in selected collection(s) are considered as validation set",
|
|
249
|
+
)
|
|
250
|
+
return Container(
|
|
251
|
+
widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
|
|
252
|
+
)
|
|
253
|
+
|
|
219
254
|
def get_json_data(self):
|
|
220
255
|
return {}
|
|
221
256
|
|
|
@@ -223,6 +258,8 @@ class TrainValSplits(Widget):
|
|
|
223
258
|
return {}
|
|
224
259
|
|
|
225
260
|
def get_splits(self) -> Tuple[List[ItemInfo], List[ItemInfo]]:
|
|
261
|
+
if self._project_id is None and self._project_fs is None:
|
|
262
|
+
raise ValueError("Both project_id and project_fs are None.")
|
|
226
263
|
split_method = self._content.get_active_tab()
|
|
227
264
|
tmp_project_dir = None
|
|
228
265
|
train_set, val_set = [], []
|
|
@@ -276,18 +313,35 @@ class TrainValSplits(Widget):
|
|
|
276
313
|
train_set, val_set = self._project_class.get_train_val_splits_by_dataset(
|
|
277
314
|
project_dir, train_ds_names, val_ds_names
|
|
278
315
|
)
|
|
316
|
+
elif split_method == "Based on collections":
|
|
317
|
+
if self._project_id is None:
|
|
318
|
+
raise ValueError(
|
|
319
|
+
"You can not use collections_splits parameter without project_id parameter."
|
|
320
|
+
)
|
|
321
|
+
train_collections = self._train_collections_select.get_selected_ids()
|
|
322
|
+
val_collections = self._val_collections_select.get_selected_ids()
|
|
323
|
+
|
|
324
|
+
train_set, val_set = self._project_class.get_train_val_splits_by_collections(
|
|
325
|
+
project_dir,
|
|
326
|
+
train_collections,
|
|
327
|
+
val_collections,
|
|
328
|
+
self._project_id,
|
|
329
|
+
self._api,
|
|
330
|
+
)
|
|
279
331
|
|
|
280
332
|
if tmp_project_dir is not None:
|
|
281
333
|
remove_dir(tmp_project_dir)
|
|
282
334
|
return train_set, val_set
|
|
283
335
|
|
|
284
|
-
def set_split_method(self, split_method: Literal["random", "tags", "datasets"]):
|
|
336
|
+
def set_split_method(self, split_method: Literal["random", "tags", "datasets", "collections"]):
|
|
285
337
|
if split_method == "random":
|
|
286
338
|
split_method = "Random"
|
|
287
339
|
elif split_method == "tags":
|
|
288
340
|
split_method = "Based on item tags"
|
|
289
341
|
elif split_method == "datasets":
|
|
290
342
|
split_method = "Based on datasets"
|
|
343
|
+
elif split_method == "collections":
|
|
344
|
+
split_method = "Based on collections"
|
|
291
345
|
self._content.set_active_tab(split_method)
|
|
292
346
|
StateJson().send_changes()
|
|
293
347
|
DataJson().send_changes()
|
|
@@ -337,6 +391,42 @@ class TrainValSplits(Widget):
|
|
|
337
391
|
def get_val_dataset_ids(self) -> List[int]:
|
|
338
392
|
return self._val_ds_select.get_selected_ids()
|
|
339
393
|
|
|
394
|
+
def set_project_id_for_collections(self, project_id: int):
|
|
395
|
+
if not isinstance(project_id, int):
|
|
396
|
+
raise ValueError("Project ID must be an integer.")
|
|
397
|
+
self._project_id = project_id
|
|
398
|
+
self._project_type = None
|
|
399
|
+
if self._api is None:
|
|
400
|
+
self._api = Api()
|
|
401
|
+
self._project_info = self._api.project.get_info_by_id(self._project_id, raise_error=True)
|
|
402
|
+
self._project_type = self._project_info.type
|
|
403
|
+
self._project_class = get_project_class(self._project_type)
|
|
404
|
+
if not self._train_collections_select or not self._val_collections_select:
|
|
405
|
+
raise ValueError("Collections select widgets are not initialized.")
|
|
406
|
+
self._train_collections_select.set_project_id(project_id)
|
|
407
|
+
self._val_collections_select.set_project_id(project_id)
|
|
408
|
+
|
|
409
|
+
def get_train_collections_ids(self) -> List[int]:
|
|
410
|
+
return self._train_collections_select.get_selected_ids() or []
|
|
411
|
+
|
|
412
|
+
def get_val_collections_ids(self) -> List[int]:
|
|
413
|
+
return self._val_collections_select.get_selected_ids() or []
|
|
414
|
+
|
|
415
|
+
def set_collections_splits(self, train_collections: List[int], val_collections: List[int]):
|
|
416
|
+
self._content.set_active_tab("Based on collections")
|
|
417
|
+
self.set_collections_splits_by_ids("train", train_collections)
|
|
418
|
+
self.set_collections_splits_by_ids("val", val_collections)
|
|
419
|
+
|
|
420
|
+
def set_collections_splits_by_ids(
|
|
421
|
+
self, split: Literal["train", "val"], collection_ids: List[int]
|
|
422
|
+
):
|
|
423
|
+
if split == "train":
|
|
424
|
+
self._train_collections_select.set_collections(collection_ids)
|
|
425
|
+
elif split == "val":
|
|
426
|
+
self._val_collections_select.set_collections(collection_ids)
|
|
427
|
+
else:
|
|
428
|
+
raise ValueError("Split value must be 'train' or 'val'")
|
|
429
|
+
|
|
340
430
|
def get_untagged_action(self) -> str:
|
|
341
431
|
return self._untagged_select.get_value()
|
|
342
432
|
|
|
@@ -355,6 +445,10 @@ class TrainValSplits(Widget):
|
|
|
355
445
|
if self._val_ds_select is not None:
|
|
356
446
|
self._val_ds_select.disable()
|
|
357
447
|
self._disabled = True
|
|
448
|
+
if self._train_collections_select is not None:
|
|
449
|
+
self._train_collections_select.disable()
|
|
450
|
+
if self._val_collections_select is not None:
|
|
451
|
+
self._val_collections_select.disable()
|
|
358
452
|
DataJson()[self.widget_id]["disabled"] = self._disabled
|
|
359
453
|
DataJson().send_changes()
|
|
360
454
|
|
|
@@ -373,5 +467,9 @@ class TrainValSplits(Widget):
|
|
|
373
467
|
if self._val_ds_select is not None:
|
|
374
468
|
self._val_ds_select.enable()
|
|
375
469
|
self._disabled = False
|
|
470
|
+
if self._train_collections_select is not None:
|
|
471
|
+
self._train_collections_select.enable()
|
|
472
|
+
if self._val_collections_select is not None:
|
|
473
|
+
self._val_collections_select.enable()
|
|
376
474
|
DataJson()[self.widget_id]["disabled"] = self._disabled
|
|
377
475
|
DataJson().send_changes()
|
|
@@ -815,6 +815,20 @@ class TrainGUI:
|
|
|
815
815
|
raise ValueError("split must be 'train' or 'val'")
|
|
816
816
|
if not isinstance(percent, int) or not 0 < percent < 100:
|
|
817
817
|
raise ValueError("percent must be an integer in range 1 to 99")
|
|
818
|
+
elif train_val_splits_settings.get("method") == "collections":
|
|
819
|
+
train_collections = train_val_splits_settings.get("train_collections", [])
|
|
820
|
+
val_collections = train_val_splits_settings.get("val_collections", [])
|
|
821
|
+
collection_ids = set()
|
|
822
|
+
for collection in self._api.entities_collection.get_list(self.project_id):
|
|
823
|
+
collection_ids.add(collection.id)
|
|
824
|
+
missing_collections_ids = set(train_collections + val_collections) - collection_ids
|
|
825
|
+
if missing_collections_ids:
|
|
826
|
+
missing_collections_text = ", ".join(
|
|
827
|
+
[str(collection_id) for collection_id in missing_collections_ids]
|
|
828
|
+
)
|
|
829
|
+
raise ValueError(
|
|
830
|
+
f"Collections with ids: {missing_collections_text} not found in the project"
|
|
831
|
+
)
|
|
818
832
|
return app_state
|
|
819
833
|
|
|
820
834
|
def load_from_app_state(self, app_state: Union[str, dict]) -> None:
|
|
@@ -849,7 +863,8 @@ class TrainGUI:
|
|
|
849
863
|
"ONNXRuntime": True,
|
|
850
864
|
"TensorRT": True
|
|
851
865
|
},
|
|
852
|
-
}
|
|
866
|
+
},
|
|
867
|
+
"experiment_name": "my_experiment",
|
|
853
868
|
}
|
|
854
869
|
"""
|
|
855
870
|
if isinstance(app_state, str):
|
|
@@ -863,6 +878,7 @@ class TrainGUI:
|
|
|
863
878
|
tags_settings = app_state.get("tags", [])
|
|
864
879
|
model_settings = app_state["model"]
|
|
865
880
|
hyperparameters_settings = app_state["hyperparameters"]
|
|
881
|
+
experiment_name = app_state.get("experiment_name", None)
|
|
866
882
|
|
|
867
883
|
self._init_input(input_settings, options)
|
|
868
884
|
self._init_train_val_splits(train_val_splits_settings, options)
|
|
@@ -870,6 +886,8 @@ class TrainGUI:
|
|
|
870
886
|
self._init_tags(tags_settings, options)
|
|
871
887
|
self._init_model(model_settings, options)
|
|
872
888
|
self._init_hyperparameters(hyperparameters_settings, options)
|
|
889
|
+
if experiment_name is not None:
|
|
890
|
+
self.training_process.set_experiment_name(experiment_name)
|
|
873
891
|
|
|
874
892
|
def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
|
|
875
893
|
"""
|
|
@@ -938,6 +956,15 @@ class TrainGUI:
|
|
|
938
956
|
self.train_val_splits_selector.train_val_splits.set_datasets_splits(
|
|
939
957
|
train_datasets, val_datasets
|
|
940
958
|
)
|
|
959
|
+
elif split_method == "collections":
|
|
960
|
+
train_collections = train_val_splits_settings["train_collections"]
|
|
961
|
+
val_collections = train_val_splits_settings["val_collections"]
|
|
962
|
+
self.train_val_splits_selector.train_val_splits.set_project_id_for_collections(
|
|
963
|
+
self.project_id
|
|
964
|
+
)
|
|
965
|
+
self.train_val_splits_selector.train_val_splits.set_collections_splits(
|
|
966
|
+
train_collections, val_collections
|
|
967
|
+
)
|
|
941
968
|
self.train_val_splits_selector_cb()
|
|
942
969
|
|
|
943
970
|
def _init_classes(self, classes_settings: list, options: dict) -> None:
|
|
@@ -26,53 +26,80 @@ class TrainValSplitsSelector:
|
|
|
26
26
|
# GUI Components
|
|
27
27
|
split_methods = self.app_options.get("train_val_split_methods", [])
|
|
28
28
|
if len(split_methods) == 0:
|
|
29
|
-
split_methods = ["Random", "Based on tags", "Based on datasets"]
|
|
29
|
+
split_methods = ["Random", "Based on tags", "Based on datasets", "Based on collections"]
|
|
30
30
|
random_split = "Random" in split_methods
|
|
31
31
|
tag_split = "Based on tags" in split_methods
|
|
32
32
|
ds_split = "Based on datasets" in split_methods
|
|
33
|
+
coll_split = "Based on collections" in split_methods
|
|
33
34
|
|
|
34
|
-
self.train_val_splits = TrainValSplits(
|
|
35
|
+
self.train_val_splits = TrainValSplits(
|
|
36
|
+
project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# check for collections with "train" and "val" prefixes
|
|
40
|
+
all_collections = self.api.entities_collection.get_list(self.project_id)
|
|
41
|
+
train_collections = []
|
|
42
|
+
val_collections = []
|
|
43
|
+
collections_found = False
|
|
44
|
+
for collection in all_collections:
|
|
45
|
+
if collection.name.lower().startswith("train"):
|
|
46
|
+
train_collections.append(collection.id)
|
|
47
|
+
elif collection.name.lower().startswith("val"):
|
|
48
|
+
val_collections.append(collection.id)
|
|
49
|
+
|
|
50
|
+
if len(train_collections) > 0 and len(val_collections) > 0:
|
|
51
|
+
self.train_val_splits.set_collections_splits(train_collections, val_collections)
|
|
52
|
+
self.validator_text = Text(
|
|
53
|
+
"Train and val collections are detected", status="info"
|
|
54
|
+
)
|
|
55
|
+
self.validator_text.show()
|
|
56
|
+
collections_found = True
|
|
57
|
+
else:
|
|
58
|
+
self.validator_text = Text("")
|
|
59
|
+
self.validator_text.hide()
|
|
60
|
+
|
|
35
61
|
|
|
36
62
|
def _extend_with_nested(root_ds):
|
|
37
63
|
nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
|
|
38
64
|
nested_ids = [ds.id for ds in nested]
|
|
39
65
|
return [root_ds.id] + nested_ids
|
|
40
66
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
67
|
+
if not collections_found:
|
|
68
|
+
train_val_dataset_ids = {"train": set(), "val": set()}
|
|
69
|
+
for _, dataset in self.api.dataset.tree(self.project_id):
|
|
70
|
+
ds_name = dataset.name.lower()
|
|
44
71
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
72
|
+
if ds_name in {"train", "training"}:
|
|
73
|
+
for _id in _extend_with_nested(dataset):
|
|
74
|
+
train_val_dataset_ids["train"].add(_id)
|
|
48
75
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
76
|
+
elif ds_name in {"val", "validation", "test", "testing"}:
|
|
77
|
+
for _id in _extend_with_nested(dataset):
|
|
78
|
+
train_val_dataset_ids["val"].add(_id)
|
|
52
79
|
|
|
53
|
-
|
|
54
|
-
|
|
80
|
+
train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
|
|
81
|
+
train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
|
|
55
82
|
|
|
56
|
-
|
|
57
|
-
|
|
83
|
+
train_count = len(train_val_dataset_ids["train"])
|
|
84
|
+
val_count = len(train_val_dataset_ids["val"])
|
|
58
85
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
86
|
+
if train_count > 0 and val_count > 0:
|
|
87
|
+
self.train_val_splits.set_datasets_splits(
|
|
88
|
+
train_val_dataset_ids["train"], train_val_dataset_ids["val"]
|
|
89
|
+
)
|
|
63
90
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
91
|
+
if train_count > 0 and val_count > 0:
|
|
92
|
+
if train_count == val_count == 1:
|
|
93
|
+
self.validator_text = Text("train and val datasets are detected", status="info")
|
|
94
|
+
else:
|
|
95
|
+
self.validator_text = Text(
|
|
96
|
+
"Multiple train and val datasets are detected. Check manually if selection is correct",
|
|
97
|
+
status="info",
|
|
98
|
+
)
|
|
99
|
+
self.validator_text.show()
|
|
67
100
|
else:
|
|
68
|
-
self.validator_text = Text(
|
|
69
|
-
|
|
70
|
-
status="info",
|
|
71
|
-
)
|
|
72
|
-
self.validator_text.show()
|
|
73
|
-
else:
|
|
74
|
-
self.validator_text = Text("")
|
|
75
|
-
self.validator_text.hide()
|
|
101
|
+
self.validator_text = Text("")
|
|
102
|
+
self.validator_text.hide()
|
|
76
103
|
|
|
77
104
|
self.button = Button("Select")
|
|
78
105
|
self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
|
|
@@ -238,6 +265,68 @@ class TrainValSplitsSelector:
|
|
|
238
265
|
self.validator_text.set("Train and val datasets are selected", status="success")
|
|
239
266
|
return True
|
|
240
267
|
|
|
268
|
+
def validate_based_on_collections():
|
|
269
|
+
train_collection_id = self.train_val_splits.get_train_collections_ids()
|
|
270
|
+
val_collection_id = self.train_val_splits.get_val_collections_ids()
|
|
271
|
+
if train_collection_id is None and val_collection_id is None:
|
|
272
|
+
self.validator_text.set("No collections are selected", status="error")
|
|
273
|
+
return False
|
|
274
|
+
if len(train_collection_id) == 0 or len(val_collection_id) == 0:
|
|
275
|
+
self.validator_text.set("Collections are not selected", status="error")
|
|
276
|
+
return False
|
|
277
|
+
if set(train_collection_id) == set(val_collection_id):
|
|
278
|
+
self.validator_text.set(
|
|
279
|
+
text=f"Same collections are selected for both train and val splits. {ensure_text} {warning_text}",
|
|
280
|
+
status="warning",
|
|
281
|
+
)
|
|
282
|
+
return True
|
|
283
|
+
from supervisely.api.entities_collection_api import CollectionTypeFilter
|
|
284
|
+
|
|
285
|
+
train_items = set()
|
|
286
|
+
empty_train_collections = []
|
|
287
|
+
for collection_id in train_collection_id:
|
|
288
|
+
items = self.api.entities_collection.get_items(
|
|
289
|
+
collection_id=collection_id,
|
|
290
|
+
project_id=self.project_id,
|
|
291
|
+
collection_type=CollectionTypeFilter.DEFAULT,
|
|
292
|
+
)
|
|
293
|
+
train_items.update([item.id for item in items])
|
|
294
|
+
if len(items) == 0:
|
|
295
|
+
empty_train_collections.append(collection_id)
|
|
296
|
+
val_items = set()
|
|
297
|
+
empty_val_collections = []
|
|
298
|
+
for collection_id in val_collection_id:
|
|
299
|
+
items = self.api.entities_collection.get_items(
|
|
300
|
+
collection_id=collection_id,
|
|
301
|
+
project_id=self.project_id,
|
|
302
|
+
collection_type=CollectionTypeFilter.DEFAULT,
|
|
303
|
+
)
|
|
304
|
+
val_items.update([item.id for item in items])
|
|
305
|
+
if len(items) == 0:
|
|
306
|
+
empty_val_collections.append(collection_id)
|
|
307
|
+
if len(train_items) == 0 and len(val_items) == 0:
|
|
308
|
+
self.validator_text.set(
|
|
309
|
+
text="All selected collections are empty. ",
|
|
310
|
+
status="error",
|
|
311
|
+
)
|
|
312
|
+
return False
|
|
313
|
+
if len(empty_train_collections) > 0 or len(empty_val_collections) > 0:
|
|
314
|
+
empty_collections_text = "Selected collections are empty. "
|
|
315
|
+
if len(empty_train_collections) > 0:
|
|
316
|
+
empty_collections_text += f"train: {', '.join(empty_train_collections)}. "
|
|
317
|
+
if len(empty_val_collections) > 0:
|
|
318
|
+
empty_collections_text += f"val: {', '.join(empty_val_collections)}. "
|
|
319
|
+
empty_collections_text += f"{ensure_text}"
|
|
320
|
+
self.validator_text.set(
|
|
321
|
+
text=empty_collections_text,
|
|
322
|
+
status="error",
|
|
323
|
+
)
|
|
324
|
+
return True
|
|
325
|
+
|
|
326
|
+
else:
|
|
327
|
+
self.validator_text.set("Train and val collections are selected", status="success")
|
|
328
|
+
return True
|
|
329
|
+
|
|
241
330
|
if split_method == "Random":
|
|
242
331
|
is_valid = validate_random_split()
|
|
243
332
|
|
|
@@ -246,6 +335,8 @@ class TrainValSplitsSelector:
|
|
|
246
335
|
|
|
247
336
|
elif split_method == "Based on datasets":
|
|
248
337
|
is_valid = validate_based_on_datasets()
|
|
338
|
+
elif split_method == "Based on collections":
|
|
339
|
+
is_valid = validate_based_on_collections()
|
|
249
340
|
|
|
250
341
|
# @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
|
|
251
342
|
self.validator_text.show()
|
|
@@ -268,3 +359,15 @@ class TrainValSplitsSelector:
|
|
|
268
359
|
|
|
269
360
|
def set_val_dataset_ids(self, dataset_ids: List[int]) -> None:
|
|
270
361
|
self.train_val_splits._val_ds_select.set_selected_ids(dataset_ids)
|
|
362
|
+
|
|
363
|
+
def get_train_collection_ids(self) -> List[int]:
|
|
364
|
+
return self.train_val_splits._train_collections_select.get_selected_ids()
|
|
365
|
+
|
|
366
|
+
def set_train_collection_ids(self, collection_ids: List[int]) -> None:
|
|
367
|
+
self.train_val_splits._train_collections_select.set_selected_ids(collection_ids)
|
|
368
|
+
|
|
369
|
+
def get_val_collection_ids(self) -> List[int]:
|
|
370
|
+
return self.train_val_splits._val_collections_select.get_selected_ids()
|
|
371
|
+
|
|
372
|
+
def set_val_collection_ids(self, collection_ids: List[int]) -> None:
|
|
373
|
+
self.train_val_splits._val_collections_select.set_selected_ids(collection_ids)
|
|
@@ -44,7 +44,7 @@ from supervisely import (
|
|
|
44
44
|
)
|
|
45
45
|
from supervisely._utils import abs_url, get_filename_from_headers
|
|
46
46
|
from supervisely.api.file_api import FileInfo
|
|
47
|
-
from supervisely.app import get_synced_data_dir
|
|
47
|
+
from supervisely.app import get_synced_data_dir, show_dialog
|
|
48
48
|
from supervisely.app.widgets import Progress
|
|
49
49
|
from supervisely.nn.benchmark import (
|
|
50
50
|
InstanceSegmentationBenchmark,
|
|
@@ -208,16 +208,41 @@ class TrainApp:
|
|
|
208
208
|
def _train_from_api(response: Response, request: Request):
|
|
209
209
|
try:
|
|
210
210
|
state = request.state.state
|
|
211
|
+
wait = state.get("wait", True)
|
|
211
212
|
app_state = state["app_state"]
|
|
212
213
|
self.gui.load_from_app_state(app_state)
|
|
213
214
|
|
|
214
|
-
|
|
215
|
+
if wait:
|
|
216
|
+
self._wrapped_start_training()
|
|
217
|
+
else:
|
|
218
|
+
import threading
|
|
219
|
+
|
|
220
|
+
training_thread = threading.Thread(
|
|
221
|
+
target=self._wrapped_start_training,
|
|
222
|
+
daemon=True,
|
|
223
|
+
)
|
|
224
|
+
training_thread.start()
|
|
225
|
+
return {"result": "model training started"}
|
|
215
226
|
|
|
216
227
|
return {"result": "model was successfully trained"}
|
|
217
228
|
except Exception as e:
|
|
218
229
|
self.gui.training_process.start_button.loading = False
|
|
219
230
|
raise e
|
|
220
231
|
|
|
232
|
+
# # Get training status
|
|
233
|
+
# @self._server.post("/train_status")
|
|
234
|
+
# def _train_status(response: Response, request: Request):
|
|
235
|
+
# """Returns the current training status."""
|
|
236
|
+
# status = self.gui.training_process.validator_text.get_value()
|
|
237
|
+
# if status == "Training is in progress...":
|
|
238
|
+
# try:
|
|
239
|
+
# total_epochs = self.progress_bar_main.total
|
|
240
|
+
# current_epoch = self.progress_bar_main.current
|
|
241
|
+
# status += f" (Epoch {current_epoch}/{total_epochs})"
|
|
242
|
+
# except Exception:
|
|
243
|
+
# pass
|
|
244
|
+
# return {"status": status}
|
|
245
|
+
|
|
221
246
|
def _register_routes(self):
|
|
222
247
|
"""
|
|
223
248
|
Registers API routes for TensorBoard and training endpoints.
|
|
@@ -1870,6 +1895,13 @@ class TrainApp:
|
|
|
1870
1895
|
"val_datasets": self.gui.train_val_splits_selector.train_val_splits.get_val_dataset_ids(),
|
|
1871
1896
|
}
|
|
1872
1897
|
)
|
|
1898
|
+
elif split_method == "Based on collections":
|
|
1899
|
+
train_val_splits.update(
|
|
1900
|
+
{
|
|
1901
|
+
"train_collections": self.gui.train_val_splits_selector.get_train_collection_ids(),
|
|
1902
|
+
"val_collections": self.gui.train_val_splits_selector.get_val_collection_ids(),
|
|
1903
|
+
}
|
|
1904
|
+
)
|
|
1873
1905
|
return train_val_splits
|
|
1874
1906
|
|
|
1875
1907
|
def _get_model_config_for_app_state(self, experiment_info: Dict = None) -> Dict:
|
|
@@ -2110,7 +2142,7 @@ class TrainApp:
|
|
|
2110
2142
|
]
|
|
2111
2143
|
task_type = experiment_info["task_type"]
|
|
2112
2144
|
if task_type not in supported_task_types:
|
|
2113
|
-
logger.
|
|
2145
|
+
logger.warning(
|
|
2114
2146
|
f"Task type: '{task_type}' is not supported for Model Benchmark. "
|
|
2115
2147
|
f"Supported tasks: {', '.join(task_type)}"
|
|
2116
2148
|
)
|
|
@@ -2570,9 +2602,8 @@ class TrainApp:
|
|
|
2570
2602
|
except Exception as e:
|
|
2571
2603
|
message = f"Error occurred during training initialization. {check_logs_text}"
|
|
2572
2604
|
self._show_error(message, e)
|
|
2573
|
-
self._restore_train_widgets_state_on_error()
|
|
2574
2605
|
self._set_ws_progress_status("reset")
|
|
2575
|
-
|
|
2606
|
+
raise e
|
|
2576
2607
|
|
|
2577
2608
|
try:
|
|
2578
2609
|
self._set_text_status("preparing")
|
|
@@ -2581,9 +2612,8 @@ class TrainApp:
|
|
|
2581
2612
|
except Exception as e:
|
|
2582
2613
|
message = f"Error occurred during data preparation. {check_logs_text}"
|
|
2583
2614
|
self._show_error(message, e)
|
|
2584
|
-
self._restore_train_widgets_state_on_error()
|
|
2585
2615
|
self._set_ws_progress_status("reset")
|
|
2586
|
-
|
|
2616
|
+
raise e
|
|
2587
2617
|
|
|
2588
2618
|
try:
|
|
2589
2619
|
self._set_text_status("training")
|
|
@@ -2597,15 +2627,13 @@ class TrainApp:
|
|
|
2597
2627
|
"Please check input data and hyperparameters."
|
|
2598
2628
|
)
|
|
2599
2629
|
self._show_error(message, e)
|
|
2600
|
-
self._restore_train_widgets_state_on_error()
|
|
2601
2630
|
self._set_ws_progress_status("reset")
|
|
2602
2631
|
return
|
|
2603
2632
|
except Exception as e:
|
|
2604
2633
|
message = f"Error occurred during training. {check_logs_text}"
|
|
2605
2634
|
self._show_error(message, e)
|
|
2606
|
-
self._restore_train_widgets_state_on_error()
|
|
2607
2635
|
self._set_ws_progress_status("reset")
|
|
2608
|
-
|
|
2636
|
+
raise e
|
|
2609
2637
|
|
|
2610
2638
|
try:
|
|
2611
2639
|
self._set_text_status("finalizing")
|
|
@@ -2613,18 +2641,17 @@ class TrainApp:
|
|
|
2613
2641
|
self._finalize(experiment_info)
|
|
2614
2642
|
self.gui.training_process.start_button.loading = False
|
|
2615
2643
|
|
|
2616
|
-
|
|
2644
|
+
if is_production() and self.gui.training_logs.tensorboard_offline_button is not None:
|
|
2645
|
+
self.gui.training_logs.tensorboard_button.hide()
|
|
2646
|
+
self.gui.training_logs.tensorboard_offline_button.show()
|
|
2617
2647
|
|
|
2618
|
-
|
|
2619
|
-
self.gui.training_logs.tensorboard_offline_button.show()
|
|
2620
|
-
sleep(1) # wait for the button to be shown
|
|
2648
|
+
sleep(1)
|
|
2621
2649
|
self.app.shutdown()
|
|
2622
2650
|
except Exception as e:
|
|
2623
2651
|
message = f"Error occurred during finalizing and uploading training artifacts. {check_logs_text}"
|
|
2624
2652
|
self._show_error(message, e)
|
|
2625
|
-
self._restore_train_widgets_state_on_error()
|
|
2626
2653
|
self._set_ws_progress_status("reset")
|
|
2627
|
-
|
|
2654
|
+
raise e
|
|
2628
2655
|
|
|
2629
2656
|
def _show_error(self, message: str, e=None):
|
|
2630
2657
|
if e is not None:
|
|
@@ -2635,6 +2662,7 @@ class TrainApp:
|
|
|
2635
2662
|
self.gui.training_process.validator_text.show()
|
|
2636
2663
|
self.gui.training_process.start_button.loading = False
|
|
2637
2664
|
self._restore_train_widgets_state_on_error()
|
|
2665
|
+
show_dialog(title="Error", description=message, status="error")
|
|
2638
2666
|
|
|
2639
2667
|
def _set_train_widgets_state_on_start(self):
|
|
2640
2668
|
self.gui.disable_select_buttons()
|
|
@@ -495,6 +495,22 @@ class PointcloudEpisodeProject(PointcloudProject):
|
|
|
495
495
|
_add_items_to_list(project, val_datasets, val_items)
|
|
496
496
|
return train_items, val_items
|
|
497
497
|
|
|
498
|
+
@staticmethod
|
|
499
|
+
def get_train_val_splits_by_collections(
|
|
500
|
+
project_dir: str,
|
|
501
|
+
train_collections: List[int],
|
|
502
|
+
val_collections: List[int],
|
|
503
|
+
project_id: int,
|
|
504
|
+
api: Api,
|
|
505
|
+
) -> None:
|
|
506
|
+
"""
|
|
507
|
+
Not available for PointcloudEpisodeProject class.
|
|
508
|
+
:raises: :class:`NotImplementedError` in all cases.
|
|
509
|
+
"""
|
|
510
|
+
raise NotImplementedError(
|
|
511
|
+
f"Static method 'get_train_val_splits_by_collections()' is not supported for PointcloudEpisodeProject class now."
|
|
512
|
+
)
|
|
513
|
+
|
|
498
514
|
@staticmethod
|
|
499
515
|
def download(
|
|
500
516
|
api: Api,
|