supervisely 6.73.374__py3-none-any.whl → 6.73.375__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/select_collection/__init__.py +0 -0
- supervisely/app/widgets/select_collection/select_collection.py +693 -0
- supervisely/app/widgets/select_collection/template.html +3 -0
- supervisely/app/widgets/train_val_splits/train_val_splits.py +111 -13
- supervisely/nn/training/gui/gui.py +28 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +133 -30
- supervisely/nn/training/train_app.py +34 -2
- supervisely/project/pointcloud_episode_project.py +16 -0
- supervisely/project/pointcloud_project.py +16 -0
- supervisely/project/project.py +57 -0
- supervisely/project/video_project.py +16 -0
- supervisely/project/volume_project.py +16 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.375.dist-info}/METADATA +1 -1
- {supervisely-6.73.374.dist-info → supervisely-6.73.375.dist-info}/RECORD +19 -16
- {supervisely-6.73.374.dist-info → supervisely-6.73.375.dist-info}/LICENSE +0 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.375.dist-info}/WHEEL +0 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.375.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.374.dist-info → supervisely-6.73.375.dist-info}/top_level.txt +0 -0

supervisely/app/widgets/train_val_splits/train_val_splits.py
CHANGED

@@ -1,4 +1,5 @@
 import os
+from collections import defaultdict
 from typing import Dict, List, Literal, Optional, Tuple, Union
 
 import supervisely as sly
@@ -17,6 +18,7 @@ from supervisely.app.widgets import (
 from supervisely.app.widgets.random_splits_table.random_splits_table import (
     RandomSplitsTable,
 )
+from supervisely.app.widgets.select_collection.select_collection import SelectCollection
 from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import (
     SelectDatasetTree,
 )
@@ -41,28 +43,27 @@ class TrainValSplits(Widget):
         tags_splits: Optional[bool] = True,
         datasets_splits: Optional[bool] = True,
         widget_id: Optional[int] = None,
+        collections_splits: Optional[bool] = False,
     ):
         self._project_id = project_id
         self._project_fs = project_fs
 
-        if project_fs is not None and project_id is not None:
-            raise ValueError(
-                "You can not provide both project_id and project_fs parameters to TrainValSplits widget."
-            )
-        if project_fs is None and project_id is None:
-            raise ValueError(
-                "You should provide at least one of: project_id or project_fs parameters to TrainValSplits widget."
-            )
-
         self._project_info = None
+        self._project_type = None
+        self._project_class = None
+        self._api = None
         if project_id is not None:
            self._api = Api()
            self._project_info = self._api.project.get_info_by_id(
                self._project_id, raise_error=True
            )
 
-
-
+        if project_fs is not None:
+            self._project_type = project_fs.type
+        elif self._project_info is not None:
+            self._project_type = self._project_info.type
+        if self._project_type is not None:
+            self._project_class = get_project_class(self._project_type)
 
         self._random_splits_table: RandomSplitsTable = None
         self._train_tag_select: SelectTagMeta = None
@@ -70,6 +71,8 @@ class TrainValSplits(Widget):
         self._untagged_select: SelectString = None
         self._train_ds_select: Union[SelectDatasetTree, SelectString] = None
         self._val_ds_select: Union[SelectDatasetTree, SelectString] = None
+        self._train_collections_select: SelectCollection = None
+        self._val_collections_select: SelectCollection = None
         self._split_methods = []
 
         contents = []
@@ -80,12 +83,18 @@ class TrainValSplits(Widget):
             contents.append(self._get_random_content())
         if tags_splits:
             self._split_methods.append("Based on item tags")
-            tabs_descriptions.append(
+            tabs_descriptions.append(
+                f"{self._project_type.capitalize()} should have assigned train or val tag"
+            )
             contents.append(self._get_tags_content())
         if datasets_splits:
             self._split_methods.append("Based on datasets")
             tabs_descriptions.append("Select one or several datasets for every split")
             contents.append(self._get_datasets_content())
+        if collections_splits:
+            self._split_methods.append("Based on collections")
+            tabs_descriptions.append("Select one or several collections for every split")
+            contents.append(self._get_collections_content())
         if not self._split_methods:
             raise ValueError(
                 "Any of split methods [random_splits, tags_splits, datasets_splits] must be specified in TrainValSplits."
@@ -216,6 +225,32 @@ class TrainValSplits(Widget):
             widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
         )
 
+    def _get_collections_content(self):
+        notification_box = NotificationBox(
+            title="Notice: How to make equal splits",
+            description="Choose the same collection(s) for train/validation to make splits equal. Can be used for debug and for tiny projects",
+            box_type="info",
+        )
+
+        self._train_collections_select = SelectCollection(multiselect=True, compact=True)
+        self._val_collections_select = SelectCollection(multiselect=True, compact=True)
+        if self._project_id is not None:
+            self._train_collections_select.set_project_id(self._project_id)
+            self._val_collections_select.set_project_id(self._project_id)
+        train_field = Field(
+            self._train_collections_select,
+            title="Train collection(s)",
+            description="all images in selected collection(s) are considered as training set",
+        )
+        val_field = Field(
+            self._val_collections_select,
+            title="Validation collection(s)",
+            description="all images in selected collection(s) are considered as validation set",
+        )
+        return Container(
+            widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
+        )
+
     def get_json_data(self):
         return {}
 
@@ -223,6 +258,8 @@ class TrainValSplits(Widget):
         return {}
 
     def get_splits(self) -> Tuple[List[ItemInfo], List[ItemInfo]]:
+        if self._project_id is None and self._project_fs is None:
+            raise ValueError("Both project_id and project_fs are None.")
         split_method = self._content.get_active_tab()
         tmp_project_dir = None
         train_set, val_set = [], []
@@ -276,18 +313,35 @@ class TrainValSplits(Widget):
             train_set, val_set = self._project_class.get_train_val_splits_by_dataset(
                 project_dir, train_ds_names, val_ds_names
             )
+        elif split_method == "Based on collections":
+            if self._project_id is None:
+                raise ValueError(
+                    "You can not use collections_splits parameter without project_id parameter."
+                )
+            train_collections = self._train_collections_select.get_selected_ids()
+            val_collections = self._val_collections_select.get_selected_ids()
+
+            train_set, val_set = self._project_class.get_train_val_splits_by_collections(
+                project_dir,
+                train_collections,
+                val_collections,
+                self._project_id,
+                self._api,
+            )
 
         if tmp_project_dir is not None:
             remove_dir(tmp_project_dir)
         return train_set, val_set
 
-    def set_split_method(self, split_method: Literal["random", "tags", "datasets"]):
+    def set_split_method(self, split_method: Literal["random", "tags", "datasets", "collections"]):
         if split_method == "random":
             split_method = "Random"
         elif split_method == "tags":
             split_method = "Based on item tags"
         elif split_method == "datasets":
             split_method = "Based on datasets"
+        elif split_method == "collections":
+            split_method = "Based on collections"
         self._content.set_active_tab(split_method)
         StateJson().send_changes()
         DataJson().send_changes()
@@ -337,6 +391,42 @@ class TrainValSplits(Widget):
     def get_val_dataset_ids(self) -> List[int]:
         return self._val_ds_select.get_selected_ids()
 
+    def set_project_id_for_collections(self, project_id: int):
+        if not isinstance(project_id, int):
+            raise ValueError("Project ID must be an integer.")
+        self._project_id = project_id
+        self._project_type = None
+        if self._api is None:
+            self._api = Api()
+        self._project_info = self._api.project.get_info_by_id(self._project_id, raise_error=True)
+        self._project_type = self._project_info.type
+        self._project_class = get_project_class(self._project_type)
+        if not self._train_collections_select or not self._val_collections_select:
+            raise ValueError("Collections select widgets are not initialized.")
+        self._train_collections_select.set_project_id(project_id)
+        self._val_collections_select.set_project_id(project_id)
+
+    def get_train_collections_ids(self) -> List[int]:
+        return self._train_collections_select.get_selected_ids() or []
+
+    def get_val_collections_ids(self) -> List[int]:
+        return self._val_collections_select.get_selected_ids() or []
+
+    def set_collections_splits(self, train_collections: List[int], val_collections: List[int]):
+        self._content.set_active_tab("Based on collections")
+        self.set_collections_splits_by_ids("train", train_collections)
+        self.set_collections_splits_by_ids("val", val_collections)
+
+    def set_collections_splits_by_ids(
+        self, split: Literal["train", "val"], collection_ids: List[int]
+    ):
+        if split == "train":
+            self._train_collections_select.set_collections(collection_ids)
+        elif split == "val":
+            self._val_collections_select.set_collections(collection_ids)
+        else:
+            raise ValueError("Split value must be 'train' or 'val'")
+
     def get_untagged_action(self) -> str:
         return self._untagged_select.get_value()
 
@@ -355,6 +445,10 @@ class TrainValSplits(Widget):
         if self._val_ds_select is not None:
             self._val_ds_select.disable()
         self._disabled = True
+        if self._train_collections_select is not None:
+            self._train_collections_select.disable()
+        if self._val_collections_select is not None:
+            self._val_collections_select.disable()
         DataJson()[self.widget_id]["disabled"] = self._disabled
         DataJson().send_changes()
 
@@ -373,5 +467,9 @@ class TrainValSplits(Widget):
         if self._val_ds_select is not None:
             self._val_ds_select.enable()
         self._disabled = False
+        if self._train_collections_select is not None:
+            self._train_collections_select.enable()
+        if self._val_collections_select is not None:
+            self._val_collections_select.enable()
         DataJson()[self.widget_id]["disabled"] = self._disabled
         DataJson().send_changes()
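
For reference, a minimal usage sketch of the new collections-based split mode, assuming TrainValSplits is exported from supervisely.app.widgets and an authenticated Supervisely environment; the project and collection IDs are placeholders:

from supervisely.app.widgets import TrainValSplits

# Positional arguments mirror the call used in TrainValSplitsSelector below:
# (project_id, project_fs, random_splits, tags_splits, datasets_splits).
splits = TrainValSplits(123, None, True, False, False, collections_splits=True)

# Pre-select entity collections for each split, then read back the item lists.
splits.set_collections_splits(train_collections=[111], val_collections=[222])
train_items, val_items = splits.get_splits()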

supervisely/nn/training/gui/gui.py
CHANGED

@@ -815,6 +815,20 @@ class TrainGUI:
                 raise ValueError("split must be 'train' or 'val'")
             if not isinstance(percent, int) or not 0 < percent < 100:
                 raise ValueError("percent must be an integer in range 1 to 99")
+        elif train_val_splits_settings.get("method") == "collections":
+            train_collections = train_val_splits_settings.get("train_collections", [])
+            val_collections = train_val_splits_settings.get("val_collections", [])
+            collection_ids = set()
+            for collection in self._api.entities_collection.get_list(self.project_id):
+                collection_ids.add(collection.id)
+            missing_collections_ids = set(train_collections + val_collections) - collection_ids
+            if missing_collections_ids:
+                missing_collections_text = ", ".join(
+                    [str(collection_id) for collection_id in missing_collections_ids]
+                )
+                raise ValueError(
+                    f"Collections with ids: {missing_collections_text} not found in the project"
+                )
         return app_state
 
     def load_from_app_state(self, app_state: Union[str, dict]) -> None:
@@ -849,7 +863,8 @@ class TrainGUI:
                    "ONNXRuntime": True,
                    "TensorRT": True
                },
-            }
+            },
+            "experiment_name": "my_experiment",
        }
        """
        if isinstance(app_state, str):
@@ -863,6 +878,7 @@ class TrainGUI:
        tags_settings = app_state.get("tags", [])
        model_settings = app_state["model"]
        hyperparameters_settings = app_state["hyperparameters"]
+       experiment_name = app_state.get("experiment_name", None)
 
        self._init_input(input_settings, options)
        self._init_train_val_splits(train_val_splits_settings, options)
@@ -870,6 +886,8 @@ class TrainGUI:
        self._init_tags(tags_settings, options)
        self._init_model(model_settings, options)
        self._init_hyperparameters(hyperparameters_settings, options)
+       if experiment_name is not None:
+           self.training_process.set_experiment_name(experiment_name)
 
    def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
        """
@@ -938,6 +956,15 @@ class TrainGUI:
            self.train_val_splits_selector.train_val_splits.set_datasets_splits(
                train_datasets, val_datasets
            )
+        elif split_method == "collections":
+            train_collections = train_val_splits_settings["train_collections"]
+            val_collections = train_val_splits_settings["val_collections"]
+            self.train_val_splits_selector.train_val_splits.set_project_id_for_collections(
+                self.project_id
+            )
+            self.train_val_splits_selector.train_val_splits.set_collections_splits(
+                train_collections, val_collections
+            )
        self.train_val_splits_selector_cb()
 
    def _init_classes(self, classes_settings: list, options: dict) -> None:
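
A sketch of the app_state fragment that load_from_app_state and the new validation branch now accept; the collection IDs and experiment name are placeholders:

app_state = {
    # ... "input", "classes", "model", "hyperparameters", "options" as before ...
    "train_val_splits": {
        "method": "collections",
        "train_collections": [111],
        "val_collections": [222],
    },
    "experiment_name": "my_experiment",
}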

supervisely/nn/training/gui/train_val_splits_selector.py
CHANGED

@@ -26,53 +26,80 @@ class TrainValSplitsSelector:
         # GUI Components
         split_methods = self.app_options.get("train_val_split_methods", [])
         if len(split_methods) == 0:
-            split_methods = ["Random", "Based on tags", "Based on datasets"]
+            split_methods = ["Random", "Based on tags", "Based on datasets", "Based on collections"]
         random_split = "Random" in split_methods
         tag_split = "Based on tags" in split_methods
         ds_split = "Based on datasets" in split_methods
+        coll_split = "Based on collections" in split_methods
 
-        self.train_val_splits = TrainValSplits(
+        self.train_val_splits = TrainValSplits(
+            project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split
+        )
+
+        # check for collections with "train" and "val" prefixes
+        all_collections = self.api.entities_collection.get_list(self.project_id)
+        train_collections = []
+        val_collections = []
+        collections_found = False
+        for collection in all_collections:
+            if collection.name.lower().startswith("train"):
+                train_collections.append(collection.id)
+            elif collection.name.lower().startswith("val"):
+                val_collections.append(collection.id)
+
+        if len(train_collections) > 0 and len(val_collections) > 0:
+            self.train_val_splits.set_collections_splits(train_collections, val_collections)
+            self.validator_text = Text(
+                "Train and val collections are detected", status="info"
+            )
+            self.validator_text.show()
+            collections_found = True
+        else:
+            self.validator_text = Text("")
+            self.validator_text.hide()
+
 
         def _extend_with_nested(root_ds):
             nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
             nested_ids = [ds.id for ds in nested]
             return [root_ds.id] + nested_ids
 
-
-
-
+        if not collections_found:
+            train_val_dataset_ids = {"train": set(), "val": set()}
+            for _, dataset in self.api.dataset.tree(self.project_id):
+                ds_name = dataset.name.lower()
 
-
-
-
+                if ds_name in {"train", "training"}:
+                    for _id in _extend_with_nested(dataset):
+                        train_val_dataset_ids["train"].add(_id)
 
-
-
-
+                elif ds_name in {"val", "validation", "test", "testing"}:
+                    for _id in _extend_with_nested(dataset):
+                        train_val_dataset_ids["val"].add(_id)
 
-
-
+            train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
+            train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
 
-
-
+            train_count = len(train_val_dataset_ids["train"])
+            val_count = len(train_val_dataset_ids["val"])
 
-
-
-
-
+            if train_count > 0 and val_count > 0:
+                self.train_val_splits.set_datasets_splits(
+                    train_val_dataset_ids["train"], train_val_dataset_ids["val"]
+                )
 
-
-
-
+            if train_count > 0 and val_count > 0:
+                if train_count == val_count == 1:
+                    self.validator_text = Text("train and val datasets are detected", status="info")
+                else:
+                    self.validator_text = Text(
+                        "Multiple train and val datasets are detected. Check manually if selection is correct",
+                        status="info",
+                    )
+                self.validator_text.show()
             else:
-                self.validator_text = Text(
-
-                    status="info",
-                )
-                self.validator_text.show()
-            else:
-                self.validator_text = Text("")
-                self.validator_text.hide()
+                self.validator_text = Text("")
+                self.validator_text.hide()
 
         self.button = Button("Select")
         self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
@@ -238,6 +265,68 @@ class TrainValSplitsSelector:
            self.validator_text.set("Train and val datasets are selected", status="success")
            return True
 
+        def validate_based_on_collections():
+            train_collection_id = self.train_val_splits.get_train_collections_ids()
+            val_collection_id = self.train_val_splits.get_val_collections_ids()
+            if train_collection_id is None and val_collection_id is None:
+                self.validator_text.set("No collections are selected", status="error")
+                return False
+            if len(train_collection_id) == 0 or len(val_collection_id) == 0:
+                self.validator_text.set("Collections are not selected", status="error")
+                return False
+            if set(train_collection_id) == set(val_collection_id):
+                self.validator_text.set(
+                    text=f"Same collections are selected for both train and val splits. {ensure_text} {warning_text}",
+                    status="warning",
+                )
+                return True
+            from supervisely.api.entities_collection_api import CollectionTypeFilter
+
+            train_items = set()
+            empty_train_collections = []
+            for collection_id in train_collection_id:
+                items = self.api.entities_collection.get_items(
+                    collection_id=collection_id,
+                    project_id=self.project_id,
+                    collection_type=CollectionTypeFilter.DEFAULT,
+                )
+                train_items.update([item.id for item in items])
+                if len(items) == 0:
+                    empty_train_collections.append(collection_id)
+            val_items = set()
+            empty_val_collections = []
+            for collection_id in val_collection_id:
+                items = self.api.entities_collection.get_items(
+                    collection_id=collection_id,
+                    project_id=self.project_id,
+                    collection_type=CollectionTypeFilter.DEFAULT,
+                )
+                val_items.update([item.id for item in items])
+                if len(items) == 0:
+                    empty_val_collections.append(collection_id)
+            if len(train_items) == 0 and len(val_items) == 0:
+                self.validator_text.set(
+                    text="All selected collections are empty. ",
+                    status="error",
+                )
+                return False
+            if len(empty_train_collections) > 0 or len(empty_val_collections) > 0:
+                empty_collections_text = "Selected collections are empty. "
+                if len(empty_train_collections) > 0:
+                    empty_collections_text += f"train: {', '.join(empty_train_collections)}. "
+                if len(empty_val_collections) > 0:
+                    empty_collections_text += f"val: {', '.join(empty_val_collections)}. "
+                empty_collections_text += f"{ensure_text}"
+                self.validator_text.set(
+                    text=empty_collections_text,
+                    status="error",
+                )
+                return True
+
+            else:
+                self.validator_text.set("Train and val collections are selected", status="success")
+                return True
+
        if split_method == "Random":
            is_valid = validate_random_split()
 
@@ -246,6 +335,8 @@ class TrainValSplitsSelector:
 
        elif split_method == "Based on datasets":
            is_valid = validate_based_on_datasets()
+        elif split_method == "Based on collections":
+            is_valid = validate_based_on_collections()
 
        # @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
        self.validator_text.show()
@@ -268,3 +359,15 @@ class TrainValSplitsSelector:
 
    def set_val_dataset_ids(self, dataset_ids: List[int]) -> None:
        self.train_val_splits._val_ds_select.set_selected_ids(dataset_ids)
+
+    def get_train_collection_ids(self) -> List[int]:
+        return self.train_val_splits._train_collections_select.get_selected_ids()
+
+    def set_train_collection_ids(self, collection_ids: List[int]) -> None:
+        self.train_val_splits._train_collections_select.set_selected_ids(collection_ids)
+
+    def get_val_collection_ids(self) -> List[int]:
+        return self.train_val_splits._val_collections_select.get_selected_ids()
+
+    def set_val_collection_ids(self, collection_ids: List[int]) -> None:
+        self.train_val_splits._val_collections_select.set_selected_ids(collection_ids)
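
The selector now pre-fills the collections split from entity collections whose names start with "train" or "val" (case-insensitive). A rough standalone equivalent of that detection loop, assuming an authenticated Api instance and a placeholder project ID:

import supervisely as sly

api = sly.Api.from_env()
project_id = 123  # placeholder

train_ids, val_ids = [], []
for collection in api.entities_collection.get_list(project_id):
    name = collection.name.lower()
    if name.startswith("train"):
        train_ids.append(collection.id)
    elif name.startswith("val"):
        val_ids.append(collection.id)

# The GUI pre-selects the "Based on collections" tab only when both lists are non-empty.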

supervisely/nn/training/train_app.py
CHANGED

@@ -208,16 +208,41 @@ class TrainApp:
        def _train_from_api(response: Response, request: Request):
            try:
                state = request.state.state
+                wait = state.get("wait", True)
                app_state = state["app_state"]
                self.gui.load_from_app_state(app_state)
 
-
+                if wait:
+                    self._wrapped_start_training()
+                else:
+                    import threading
+
+                    training_thread = threading.Thread(
+                        target=self._wrapped_start_training,
+                        daemon=True,
+                    )
+                    training_thread.start()
+                    return {"result": "model training started"}
 
                return {"result": "model was successfully trained"}
            except Exception as e:
                self.gui.training_process.start_button.loading = False
                raise e
 
+        # # Get training status
+        # @self._server.post("/train_status")
+        # def _train_status(response: Response, request: Request):
+        #     """Returns the current training status."""
+        #     status = self.gui.training_process.validator_text.get_value()
+        #     if status == "Training is in progress...":
+        #         try:
+        #             total_epochs = self.progress_bar_main.total
+        #             current_epoch = self.progress_bar_main.current
+        #             status += f" (Epoch {current_epoch}/{total_epochs})"
+        #         except Exception:
+        #             pass
+        #     return {"status": status}
+
    def _register_routes(self):
        """
        Registers API routes for TensorBoard and training endpoints.
@@ -1870,6 +1895,13 @@ class TrainApp:
                    "val_datasets": self.gui.train_val_splits_selector.train_val_splits.get_val_dataset_ids(),
                }
            )
+        elif split_method == "Based on collections":
+            train_val_splits.update(
+                {
+                    "train_collections": self.gui.train_val_splits_selector.get_train_collection_ids(),
+                    "val_collections": self.gui.train_val_splits_selector.get_val_collection_ids(),
+                }
+            )
        return train_val_splits
 
    def _get_model_config_for_app_state(self, experiment_info: Dict = None) -> Dict:
@@ -2110,7 +2142,7 @@ class TrainApp:
        ]
        task_type = experiment_info["task_type"]
        if task_type not in supported_task_types:
-            logger.
+            logger.warning(
                f"Task type: '{task_type}' is not supported for Model Benchmark. "
                f"Supported tasks: {', '.join(task_type)}"
            )
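
A rough sketch of how the new wait flag could be passed when starting training through the app's HTTP endpoint; the task URL, route path, and auth header are assumptions here, only the payload keys come from the handler above:

import requests

task_url = "https://<instance-address>/net/<task-address>"  # placeholder
payload = {
    "state": {
        "app_state": {},  # fill with the app_state structure sketched earlier
        "wait": False,    # return immediately and run training in a background thread
    }
}
resp = requests.post(f"{task_url}/train_from_api", json=payload, headers={"x-api-key": "<token>"})
print(resp.json())  # expected {"result": "model training started"} when wait is False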

supervisely/project/pointcloud_episode_project.py
CHANGED

@@ -495,6 +495,22 @@ class PointcloudEpisodeProject(PointcloudProject):
        _add_items_to_list(project, val_datasets, val_items)
        return train_items, val_items
 
+    @staticmethod
+    def get_train_val_splits_by_collections(
+        project_dir: str,
+        train_collections: List[int],
+        val_collections: List[int],
+        project_id: int,
+        api: Api,
+    ) -> None:
+        """
+        Not available for PointcloudEpisodeProject class.
+        :raises: :class:`NotImplementedError` in all cases.
+        """
+        raise NotImplementedError(
+            f"Static method 'get_train_val_splits_by_collections()' is not supported for PointcloudEpisodeProject class now."
+        )
+
    @staticmethod
    def download(
        api: Api,

supervisely/project/pointcloud_project.py
CHANGED

@@ -725,6 +725,22 @@ class PointcloudProject(VideoProject):
        _add_items_to_list(project, val_datasets, val_items)
        return train_items, val_items
 
+    @staticmethod
+    def get_train_val_splits_by_collections(
+        project_dir: str,
+        train_collections: List[int],
+        val_collections: List[int],
+        project_id: int,
+        api: Api,
+    ) -> None:
+        """
+        Not available for PointcloudProject class.
+        :raises: :class:`NotImplementedError` in all cases.
+        """
+        raise NotImplementedError(
+            f"Static method 'get_train_val_splits_by_collections()' is not supported for PointcloudProject class now."
+        )
+
    @staticmethod
    def download(
        api: Api,
supervisely/project/project.py
CHANGED

@@ -3277,6 +3277,63 @@ class Project:
        _add_items_to_list(project, val_datasets, val_items)
        return train_items, val_items
 
+    @staticmethod
+    def get_train_val_splits_by_collections(
+        project_dir: str,
+        train_collections: List[int],
+        val_collections: List[int],
+        project_id: int,
+        api: Api,
+    ) -> Tuple[List[ItemInfo], List[ItemInfo]]:
+        """
+        Get train and val items information from project by given train and val collections IDs.
+
+        :param project_dir: Path to project directory.
+        :type project_dir: :class:`str`
+        :param train_collections: List of train collections IDs.
+        :type train_collections: :class:`list` [ :class:`int` ]
+        :param val_collections: List of val collections IDs.
+        :type val_collections: :class:`list` [ :class:`int` ]
+        :param project_id: Project ID.
+        :type project_id: :class:`int`
+        :param api: Supervisely API address and token.
+        :type api: :class:`Api<supervisely.api.api.Api>`
+        :raises: :class:`KeyError` if collection ID not found in project
+        :return: Tuple with lists of train items information and val items information
+        :rtype: :class:`list` [ :class:`ItemInfo<ItemInfo>` ], :class:`list` [ :class:`ItemInfo<ItemInfo>` ]
+        """
+        from supervisely.api.entities_collection_api import CollectionTypeFilter
+
+        project = Project(project_dir, OpenMode.READ)
+
+        ds_id_to_name = {}
+        for parents, ds_info in api.dataset.tree(project_id):
+            full_name = "/".join(parents + [ds_info.name])
+            ds_id_to_name[ds_info.id] = full_name
+
+        train_items = []
+        val_items = []
+
+        for collection_ids, items_dict in [
+            (train_collections, train_items),
+            (val_collections, val_items),
+        ]:
+            for collection_id in collection_ids:
+                collection_items = api.entities_collection.get_items(
+                    collection_id=collection_id,
+                    project_id=project_id,
+                    collection_type=CollectionTypeFilter.DEFAULT,
+                )
+                for item in collection_items:
+                    ds_name = ds_id_to_name.get(item.dataset_id)
+                    ds = project.datasets.get(ds_name)
+                    img_path, ann_path = ds.get_item_paths(item.name)
+                    info = ItemInfo(ds_name, item.name, img_path, ann_path)
+                    items_dict.append(info)
+
+        return train_items, val_items
+
+
    @staticmethod
    def download(
        api: Api,
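
A sketch of calling the new Project helper directly, assuming the project is downloaded locally first (the method opens it in READ mode); IDs and the path are placeholders:

import supervisely as sly

api = sly.Api.from_env()
project_id = 123                 # placeholder
project_dir = "/tmp/my_project"  # placeholder local path

# Download a local copy of the project, then map collection items to local item paths.
sly.Project.download(api, project_id, project_dir)

train_items, val_items = sly.Project.get_train_val_splits_by_collections(
    project_dir,
    train_collections=[111],
    val_collections=[222],
    project_id=project_id,
    api=api,
)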