supervisely 6.73.373__py3-none-any.whl → 6.73.375__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/run_app_button/run_app_button.py +22 -2
- supervisely/app/widgets/run_app_button/script.js +105 -45
- supervisely/app/widgets/run_app_button/template.html +5 -10
- supervisely/app/widgets/select_collection/__init__.py +0 -0
- supervisely/app/widgets/select_collection/select_collection.py +693 -0
- supervisely/app/widgets/select_collection/template.html +3 -0
- supervisely/app/widgets/train_val_splits/train_val_splits.py +111 -13
- supervisely/nn/training/gui/gui.py +28 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +133 -30
- supervisely/nn/training/gui/training_logs.py +4 -1
- supervisely/nn/training/gui/utils.py +23 -0
- supervisely/nn/training/train_app.py +47 -5
- supervisely/project/pointcloud_episode_project.py +16 -0
- supervisely/project/pointcloud_project.py +16 -0
- supervisely/project/project.py +57 -0
- supervisely/project/video_project.py +16 -0
- supervisely/project/volume_project.py +16 -0
- {supervisely-6.73.373.dist-info → supervisely-6.73.375.dist-info}/METADATA +1 -1
- {supervisely-6.73.373.dist-info → supervisely-6.73.375.dist-info}/RECORD +24 -21
- {supervisely-6.73.373.dist-info → supervisely-6.73.375.dist-info}/LICENSE +0 -0
- {supervisely-6.73.373.dist-info → supervisely-6.73.375.dist-info}/WHEEL +0 -0
- {supervisely-6.73.373.dist-info → supervisely-6.73.375.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.373.dist-info → supervisely-6.73.375.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from collections import defaultdict
|
|
2
3
|
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
3
4
|
|
|
4
5
|
import supervisely as sly
|
|
@@ -17,6 +18,7 @@ from supervisely.app.widgets import (
|
|
|
17
18
|
from supervisely.app.widgets.random_splits_table.random_splits_table import (
|
|
18
19
|
RandomSplitsTable,
|
|
19
20
|
)
|
|
21
|
+
from supervisely.app.widgets.select_collection.select_collection import SelectCollection
|
|
20
22
|
from supervisely.app.widgets.select_dataset_tree.select_dataset_tree import (
|
|
21
23
|
SelectDatasetTree,
|
|
22
24
|
)
|
|
@@ -41,28 +43,27 @@ class TrainValSplits(Widget):
|
|
|
41
43
|
tags_splits: Optional[bool] = True,
|
|
42
44
|
datasets_splits: Optional[bool] = True,
|
|
43
45
|
widget_id: Optional[int] = None,
|
|
46
|
+
collections_splits: Optional[bool] = False,
|
|
44
47
|
):
|
|
45
48
|
self._project_id = project_id
|
|
46
49
|
self._project_fs = project_fs
|
|
47
50
|
|
|
48
|
-
if project_fs is not None and project_id is not None:
|
|
49
|
-
raise ValueError(
|
|
50
|
-
"You can not provide both project_id and project_fs parameters to TrainValSplits widget."
|
|
51
|
-
)
|
|
52
|
-
if project_fs is None and project_id is None:
|
|
53
|
-
raise ValueError(
|
|
54
|
-
"You should provide at least one of: project_id or project_fs parameters to TrainValSplits widget."
|
|
55
|
-
)
|
|
56
|
-
|
|
57
51
|
self._project_info = None
|
|
52
|
+
self._project_type = None
|
|
53
|
+
self._project_class = None
|
|
54
|
+
self._api = None
|
|
58
55
|
if project_id is not None:
|
|
59
56
|
self._api = Api()
|
|
60
57
|
self._project_info = self._api.project.get_info_by_id(
|
|
61
58
|
self._project_id, raise_error=True
|
|
62
59
|
)
|
|
63
60
|
|
|
64
|
-
|
|
65
|
-
|
|
61
|
+
if project_fs is not None:
|
|
62
|
+
self._project_type = project_fs.type
|
|
63
|
+
elif self._project_info is not None:
|
|
64
|
+
self._project_type = self._project_info.type
|
|
65
|
+
if self._project_type is not None:
|
|
66
|
+
self._project_class = get_project_class(self._project_type)
|
|
66
67
|
|
|
67
68
|
self._random_splits_table: RandomSplitsTable = None
|
|
68
69
|
self._train_tag_select: SelectTagMeta = None
|
|
@@ -70,6 +71,8 @@ class TrainValSplits(Widget):
|
|
|
70
71
|
self._untagged_select: SelectString = None
|
|
71
72
|
self._train_ds_select: Union[SelectDatasetTree, SelectString] = None
|
|
72
73
|
self._val_ds_select: Union[SelectDatasetTree, SelectString] = None
|
|
74
|
+
self._train_collections_select: SelectCollection = None
|
|
75
|
+
self._val_collections_select: SelectCollection = None
|
|
73
76
|
self._split_methods = []
|
|
74
77
|
|
|
75
78
|
contents = []
|
|
@@ -80,12 +83,18 @@ class TrainValSplits(Widget):
|
|
|
80
83
|
contents.append(self._get_random_content())
|
|
81
84
|
if tags_splits:
|
|
82
85
|
self._split_methods.append("Based on item tags")
|
|
83
|
-
tabs_descriptions.append(
|
|
86
|
+
tabs_descriptions.append(
|
|
87
|
+
f"{self._project_type.capitalize()} should have assigned train or val tag"
|
|
88
|
+
)
|
|
84
89
|
contents.append(self._get_tags_content())
|
|
85
90
|
if datasets_splits:
|
|
86
91
|
self._split_methods.append("Based on datasets")
|
|
87
92
|
tabs_descriptions.append("Select one or several datasets for every split")
|
|
88
93
|
contents.append(self._get_datasets_content())
|
|
94
|
+
if collections_splits:
|
|
95
|
+
self._split_methods.append("Based on collections")
|
|
96
|
+
tabs_descriptions.append("Select one or several collections for every split")
|
|
97
|
+
contents.append(self._get_collections_content())
|
|
89
98
|
if not self._split_methods:
|
|
90
99
|
raise ValueError(
|
|
91
100
|
"Any of split methods [random_splits, tags_splits, datasets_splits] must be specified in TrainValSplits."
|
|
@@ -216,6 +225,32 @@ class TrainValSplits(Widget):
|
|
|
216
225
|
widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
|
|
217
226
|
)
|
|
218
227
|
|
|
228
|
+
def _get_collections_content(self):
|
|
229
|
+
notification_box = NotificationBox(
|
|
230
|
+
title="Notice: How to make equal splits",
|
|
231
|
+
description="Choose the same collection(s) for train/validation to make splits equal. Can be used for debug and for tiny projects",
|
|
232
|
+
box_type="info",
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
self._train_collections_select = SelectCollection(multiselect=True, compact=True)
|
|
236
|
+
self._val_collections_select = SelectCollection(multiselect=True, compact=True)
|
|
237
|
+
if self._project_id is not None:
|
|
238
|
+
self._train_collections_select.set_project_id(self._project_id)
|
|
239
|
+
self._val_collections_select.set_project_id(self._project_id)
|
|
240
|
+
train_field = Field(
|
|
241
|
+
self._train_collections_select,
|
|
242
|
+
title="Train collection(s)",
|
|
243
|
+
description="all images in selected collection(s) are considered as training set",
|
|
244
|
+
)
|
|
245
|
+
val_field = Field(
|
|
246
|
+
self._val_collections_select,
|
|
247
|
+
title="Validation collection(s)",
|
|
248
|
+
description="all images in selected collection(s) are considered as validation set",
|
|
249
|
+
)
|
|
250
|
+
return Container(
|
|
251
|
+
widgets=[notification_box, train_field, val_field], direction="vertical", gap=5
|
|
252
|
+
)
|
|
253
|
+
|
|
219
254
|
def get_json_data(self):
|
|
220
255
|
return {}
|
|
221
256
|
|
|
@@ -223,6 +258,8 @@ class TrainValSplits(Widget):
|
|
|
223
258
|
return {}
|
|
224
259
|
|
|
225
260
|
def get_splits(self) -> Tuple[List[ItemInfo], List[ItemInfo]]:
|
|
261
|
+
if self._project_id is None and self._project_fs is None:
|
|
262
|
+
raise ValueError("Both project_id and project_fs are None.")
|
|
226
263
|
split_method = self._content.get_active_tab()
|
|
227
264
|
tmp_project_dir = None
|
|
228
265
|
train_set, val_set = [], []
|
|
@@ -276,18 +313,35 @@ class TrainValSplits(Widget):
|
|
|
276
313
|
train_set, val_set = self._project_class.get_train_val_splits_by_dataset(
|
|
277
314
|
project_dir, train_ds_names, val_ds_names
|
|
278
315
|
)
|
|
316
|
+
elif split_method == "Based on collections":
|
|
317
|
+
if self._project_id is None:
|
|
318
|
+
raise ValueError(
|
|
319
|
+
"You can not use collections_splits parameter without project_id parameter."
|
|
320
|
+
)
|
|
321
|
+
train_collections = self._train_collections_select.get_selected_ids()
|
|
322
|
+
val_collections = self._val_collections_select.get_selected_ids()
|
|
323
|
+
|
|
324
|
+
train_set, val_set = self._project_class.get_train_val_splits_by_collections(
|
|
325
|
+
project_dir,
|
|
326
|
+
train_collections,
|
|
327
|
+
val_collections,
|
|
328
|
+
self._project_id,
|
|
329
|
+
self._api,
|
|
330
|
+
)
|
|
279
331
|
|
|
280
332
|
if tmp_project_dir is not None:
|
|
281
333
|
remove_dir(tmp_project_dir)
|
|
282
334
|
return train_set, val_set
|
|
283
335
|
|
|
284
|
-
def set_split_method(self, split_method: Literal["random", "tags", "datasets"]):
|
|
336
|
+
def set_split_method(self, split_method: Literal["random", "tags", "datasets", "collections"]):
|
|
285
337
|
if split_method == "random":
|
|
286
338
|
split_method = "Random"
|
|
287
339
|
elif split_method == "tags":
|
|
288
340
|
split_method = "Based on item tags"
|
|
289
341
|
elif split_method == "datasets":
|
|
290
342
|
split_method = "Based on datasets"
|
|
343
|
+
elif split_method == "collections":
|
|
344
|
+
split_method = "Based on collections"
|
|
291
345
|
self._content.set_active_tab(split_method)
|
|
292
346
|
StateJson().send_changes()
|
|
293
347
|
DataJson().send_changes()
|
|
@@ -337,6 +391,42 @@ class TrainValSplits(Widget):
|
|
|
337
391
|
def get_val_dataset_ids(self) -> List[int]:
|
|
338
392
|
return self._val_ds_select.get_selected_ids()
|
|
339
393
|
|
|
394
|
+
def set_project_id_for_collections(self, project_id: int):
|
|
395
|
+
if not isinstance(project_id, int):
|
|
396
|
+
raise ValueError("Project ID must be an integer.")
|
|
397
|
+
self._project_id = project_id
|
|
398
|
+
self._project_type = None
|
|
399
|
+
if self._api is None:
|
|
400
|
+
self._api = Api()
|
|
401
|
+
self._project_info = self._api.project.get_info_by_id(self._project_id, raise_error=True)
|
|
402
|
+
self._project_type = self._project_info.type
|
|
403
|
+
self._project_class = get_project_class(self._project_type)
|
|
404
|
+
if not self._train_collections_select or not self._val_collections_select:
|
|
405
|
+
raise ValueError("Collections select widgets are not initialized.")
|
|
406
|
+
self._train_collections_select.set_project_id(project_id)
|
|
407
|
+
self._val_collections_select.set_project_id(project_id)
|
|
408
|
+
|
|
409
|
+
def get_train_collections_ids(self) -> List[int]:
|
|
410
|
+
return self._train_collections_select.get_selected_ids() or []
|
|
411
|
+
|
|
412
|
+
def get_val_collections_ids(self) -> List[int]:
|
|
413
|
+
return self._val_collections_select.get_selected_ids() or []
|
|
414
|
+
|
|
415
|
+
def set_collections_splits(self, train_collections: List[int], val_collections: List[int]):
|
|
416
|
+
self._content.set_active_tab("Based on collections")
|
|
417
|
+
self.set_collections_splits_by_ids("train", train_collections)
|
|
418
|
+
self.set_collections_splits_by_ids("val", val_collections)
|
|
419
|
+
|
|
420
|
+
def set_collections_splits_by_ids(
|
|
421
|
+
self, split: Literal["train", "val"], collection_ids: List[int]
|
|
422
|
+
):
|
|
423
|
+
if split == "train":
|
|
424
|
+
self._train_collections_select.set_collections(collection_ids)
|
|
425
|
+
elif split == "val":
|
|
426
|
+
self._val_collections_select.set_collections(collection_ids)
|
|
427
|
+
else:
|
|
428
|
+
raise ValueError("Split value must be 'train' or 'val'")
|
|
429
|
+
|
|
340
430
|
def get_untagged_action(self) -> str:
|
|
341
431
|
return self._untagged_select.get_value()
|
|
342
432
|
|
|
@@ -355,6 +445,10 @@ class TrainValSplits(Widget):
|
|
|
355
445
|
if self._val_ds_select is not None:
|
|
356
446
|
self._val_ds_select.disable()
|
|
357
447
|
self._disabled = True
|
|
448
|
+
if self._train_collections_select is not None:
|
|
449
|
+
self._train_collections_select.disable()
|
|
450
|
+
if self._val_collections_select is not None:
|
|
451
|
+
self._val_collections_select.disable()
|
|
358
452
|
DataJson()[self.widget_id]["disabled"] = self._disabled
|
|
359
453
|
DataJson().send_changes()
|
|
360
454
|
|
|
@@ -373,5 +467,9 @@ class TrainValSplits(Widget):
|
|
|
373
467
|
if self._val_ds_select is not None:
|
|
374
468
|
self._val_ds_select.enable()
|
|
375
469
|
self._disabled = False
|
|
470
|
+
if self._train_collections_select is not None:
|
|
471
|
+
self._train_collections_select.enable()
|
|
472
|
+
if self._val_collections_select is not None:
|
|
473
|
+
self._val_collections_select.enable()
|
|
376
474
|
DataJson()[self.widget_id]["disabled"] = self._disabled
|
|
377
475
|
DataJson().send_changes()
|
|
@@ -815,6 +815,20 @@ class TrainGUI:
|
|
|
815
815
|
raise ValueError("split must be 'train' or 'val'")
|
|
816
816
|
if not isinstance(percent, int) or not 0 < percent < 100:
|
|
817
817
|
raise ValueError("percent must be an integer in range 1 to 99")
|
|
818
|
+
elif train_val_splits_settings.get("method") == "collections":
|
|
819
|
+
train_collections = train_val_splits_settings.get("train_collections", [])
|
|
820
|
+
val_collections = train_val_splits_settings.get("val_collections", [])
|
|
821
|
+
collection_ids = set()
|
|
822
|
+
for collection in self._api.entities_collection.get_list(self.project_id):
|
|
823
|
+
collection_ids.add(collection.id)
|
|
824
|
+
missing_collections_ids = set(train_collections + val_collections) - collection_ids
|
|
825
|
+
if missing_collections_ids:
|
|
826
|
+
missing_collections_text = ", ".join(
|
|
827
|
+
[str(collection_id) for collection_id in missing_collections_ids]
|
|
828
|
+
)
|
|
829
|
+
raise ValueError(
|
|
830
|
+
f"Collections with ids: {missing_collections_text} not found in the project"
|
|
831
|
+
)
|
|
818
832
|
return app_state
|
|
819
833
|
|
|
820
834
|
def load_from_app_state(self, app_state: Union[str, dict]) -> None:
|
|
@@ -849,7 +863,8 @@ class TrainGUI:
|
|
|
849
863
|
"ONNXRuntime": True,
|
|
850
864
|
"TensorRT": True
|
|
851
865
|
},
|
|
852
|
-
}
|
|
866
|
+
},
|
|
867
|
+
"experiment_name": "my_experiment",
|
|
853
868
|
}
|
|
854
869
|
"""
|
|
855
870
|
if isinstance(app_state, str):
|
|
@@ -863,6 +878,7 @@ class TrainGUI:
|
|
|
863
878
|
tags_settings = app_state.get("tags", [])
|
|
864
879
|
model_settings = app_state["model"]
|
|
865
880
|
hyperparameters_settings = app_state["hyperparameters"]
|
|
881
|
+
experiment_name = app_state.get("experiment_name", None)
|
|
866
882
|
|
|
867
883
|
self._init_input(input_settings, options)
|
|
868
884
|
self._init_train_val_splits(train_val_splits_settings, options)
|
|
@@ -870,6 +886,8 @@ class TrainGUI:
|
|
|
870
886
|
self._init_tags(tags_settings, options)
|
|
871
887
|
self._init_model(model_settings, options)
|
|
872
888
|
self._init_hyperparameters(hyperparameters_settings, options)
|
|
889
|
+
if experiment_name is not None:
|
|
890
|
+
self.training_process.set_experiment_name(experiment_name)
|
|
873
891
|
|
|
874
892
|
def _init_input(self, input_settings: Union[dict, None], options: dict) -> None:
|
|
875
893
|
"""
|
|
@@ -938,6 +956,15 @@ class TrainGUI:
|
|
|
938
956
|
self.train_val_splits_selector.train_val_splits.set_datasets_splits(
|
|
939
957
|
train_datasets, val_datasets
|
|
940
958
|
)
|
|
959
|
+
elif split_method == "collections":
|
|
960
|
+
train_collections = train_val_splits_settings["train_collections"]
|
|
961
|
+
val_collections = train_val_splits_settings["val_collections"]
|
|
962
|
+
self.train_val_splits_selector.train_val_splits.set_project_id_for_collections(
|
|
963
|
+
self.project_id
|
|
964
|
+
)
|
|
965
|
+
self.train_val_splits_selector.train_val_splits.set_collections_splits(
|
|
966
|
+
train_collections, val_collections
|
|
967
|
+
)
|
|
941
968
|
self.train_val_splits_selector_cb()
|
|
942
969
|
|
|
943
970
|
def _init_classes(self, classes_settings: list, options: dict) -> None:
|
|
@@ -26,53 +26,80 @@ class TrainValSplitsSelector:
|
|
|
26
26
|
# GUI Components
|
|
27
27
|
split_methods = self.app_options.get("train_val_split_methods", [])
|
|
28
28
|
if len(split_methods) == 0:
|
|
29
|
-
split_methods = ["Random", "Based on tags", "Based on datasets"]
|
|
29
|
+
split_methods = ["Random", "Based on tags", "Based on datasets", "Based on collections"]
|
|
30
30
|
random_split = "Random" in split_methods
|
|
31
31
|
tag_split = "Based on tags" in split_methods
|
|
32
32
|
ds_split = "Based on datasets" in split_methods
|
|
33
|
+
coll_split = "Based on collections" in split_methods
|
|
33
34
|
|
|
34
|
-
self.train_val_splits = TrainValSplits(
|
|
35
|
+
self.train_val_splits = TrainValSplits(
|
|
36
|
+
project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# check for collections with "train" and "val" prefixes
|
|
40
|
+
all_collections = self.api.entities_collection.get_list(self.project_id)
|
|
41
|
+
train_collections = []
|
|
42
|
+
val_collections = []
|
|
43
|
+
collections_found = False
|
|
44
|
+
for collection in all_collections:
|
|
45
|
+
if collection.name.lower().startswith("train"):
|
|
46
|
+
train_collections.append(collection.id)
|
|
47
|
+
elif collection.name.lower().startswith("val"):
|
|
48
|
+
val_collections.append(collection.id)
|
|
49
|
+
|
|
50
|
+
if len(train_collections) > 0 and len(val_collections) > 0:
|
|
51
|
+
self.train_val_splits.set_collections_splits(train_collections, val_collections)
|
|
52
|
+
self.validator_text = Text(
|
|
53
|
+
"Train and val collections are detected", status="info"
|
|
54
|
+
)
|
|
55
|
+
self.validator_text.show()
|
|
56
|
+
collections_found = True
|
|
57
|
+
else:
|
|
58
|
+
self.validator_text = Text("")
|
|
59
|
+
self.validator_text.hide()
|
|
60
|
+
|
|
35
61
|
|
|
36
62
|
def _extend_with_nested(root_ds):
|
|
37
63
|
nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
|
|
38
64
|
nested_ids = [ds.id for ds in nested]
|
|
39
65
|
return [root_ds.id] + nested_ids
|
|
40
66
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
67
|
+
if not collections_found:
|
|
68
|
+
train_val_dataset_ids = {"train": set(), "val": set()}
|
|
69
|
+
for _, dataset in self.api.dataset.tree(self.project_id):
|
|
70
|
+
ds_name = dataset.name.lower()
|
|
44
71
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
72
|
+
if ds_name in {"train", "training"}:
|
|
73
|
+
for _id in _extend_with_nested(dataset):
|
|
74
|
+
train_val_dataset_ids["train"].add(_id)
|
|
48
75
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
76
|
+
elif ds_name in {"val", "validation", "test", "testing"}:
|
|
77
|
+
for _id in _extend_with_nested(dataset):
|
|
78
|
+
train_val_dataset_ids["val"].add(_id)
|
|
52
79
|
|
|
53
|
-
|
|
54
|
-
|
|
80
|
+
train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
|
|
81
|
+
train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
|
|
55
82
|
|
|
56
|
-
|
|
57
|
-
|
|
83
|
+
train_count = len(train_val_dataset_ids["train"])
|
|
84
|
+
val_count = len(train_val_dataset_ids["val"])
|
|
58
85
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
86
|
+
if train_count > 0 and val_count > 0:
|
|
87
|
+
self.train_val_splits.set_datasets_splits(
|
|
88
|
+
train_val_dataset_ids["train"], train_val_dataset_ids["val"]
|
|
89
|
+
)
|
|
63
90
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
91
|
+
if train_count > 0 and val_count > 0:
|
|
92
|
+
if train_count == val_count == 1:
|
|
93
|
+
self.validator_text = Text("train and val datasets are detected", status="info")
|
|
94
|
+
else:
|
|
95
|
+
self.validator_text = Text(
|
|
96
|
+
"Multiple train and val datasets are detected. Check manually if selection is correct",
|
|
97
|
+
status="info",
|
|
98
|
+
)
|
|
99
|
+
self.validator_text.show()
|
|
67
100
|
else:
|
|
68
|
-
self.validator_text = Text(
|
|
69
|
-
|
|
70
|
-
status="info",
|
|
71
|
-
)
|
|
72
|
-
self.validator_text.show()
|
|
73
|
-
else:
|
|
74
|
-
self.validator_text = Text("")
|
|
75
|
-
self.validator_text.hide()
|
|
101
|
+
self.validator_text = Text("")
|
|
102
|
+
self.validator_text.hide()
|
|
76
103
|
|
|
77
104
|
self.button = Button("Select")
|
|
78
105
|
self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
|
|
@@ -238,6 +265,68 @@ class TrainValSplitsSelector:
|
|
|
238
265
|
self.validator_text.set("Train and val datasets are selected", status="success")
|
|
239
266
|
return True
|
|
240
267
|
|
|
268
|
+
def validate_based_on_collections():
|
|
269
|
+
train_collection_id = self.train_val_splits.get_train_collections_ids()
|
|
270
|
+
val_collection_id = self.train_val_splits.get_val_collections_ids()
|
|
271
|
+
if train_collection_id is None and val_collection_id is None:
|
|
272
|
+
self.validator_text.set("No collections are selected", status="error")
|
|
273
|
+
return False
|
|
274
|
+
if len(train_collection_id) == 0 or len(val_collection_id) == 0:
|
|
275
|
+
self.validator_text.set("Collections are not selected", status="error")
|
|
276
|
+
return False
|
|
277
|
+
if set(train_collection_id) == set(val_collection_id):
|
|
278
|
+
self.validator_text.set(
|
|
279
|
+
text=f"Same collections are selected for both train and val splits. {ensure_text} {warning_text}",
|
|
280
|
+
status="warning",
|
|
281
|
+
)
|
|
282
|
+
return True
|
|
283
|
+
from supervisely.api.entities_collection_api import CollectionTypeFilter
|
|
284
|
+
|
|
285
|
+
train_items = set()
|
|
286
|
+
empty_train_collections = []
|
|
287
|
+
for collection_id in train_collection_id:
|
|
288
|
+
items = self.api.entities_collection.get_items(
|
|
289
|
+
collection_id=collection_id,
|
|
290
|
+
project_id=self.project_id,
|
|
291
|
+
collection_type=CollectionTypeFilter.DEFAULT,
|
|
292
|
+
)
|
|
293
|
+
train_items.update([item.id for item in items])
|
|
294
|
+
if len(items) == 0:
|
|
295
|
+
empty_train_collections.append(collection_id)
|
|
296
|
+
val_items = set()
|
|
297
|
+
empty_val_collections = []
|
|
298
|
+
for collection_id in val_collection_id:
|
|
299
|
+
items = self.api.entities_collection.get_items(
|
|
300
|
+
collection_id=collection_id,
|
|
301
|
+
project_id=self.project_id,
|
|
302
|
+
collection_type=CollectionTypeFilter.DEFAULT,
|
|
303
|
+
)
|
|
304
|
+
val_items.update([item.id for item in items])
|
|
305
|
+
if len(items) == 0:
|
|
306
|
+
empty_val_collections.append(collection_id)
|
|
307
|
+
if len(train_items) == 0 and len(val_items) == 0:
|
|
308
|
+
self.validator_text.set(
|
|
309
|
+
text="All selected collections are empty. ",
|
|
310
|
+
status="error",
|
|
311
|
+
)
|
|
312
|
+
return False
|
|
313
|
+
if len(empty_train_collections) > 0 or len(empty_val_collections) > 0:
|
|
314
|
+
empty_collections_text = "Selected collections are empty. "
|
|
315
|
+
if len(empty_train_collections) > 0:
|
|
316
|
+
empty_collections_text += f"train: {', '.join(empty_train_collections)}. "
|
|
317
|
+
if len(empty_val_collections) > 0:
|
|
318
|
+
empty_collections_text += f"val: {', '.join(empty_val_collections)}. "
|
|
319
|
+
empty_collections_text += f"{ensure_text}"
|
|
320
|
+
self.validator_text.set(
|
|
321
|
+
text=empty_collections_text,
|
|
322
|
+
status="error",
|
|
323
|
+
)
|
|
324
|
+
return True
|
|
325
|
+
|
|
326
|
+
else:
|
|
327
|
+
self.validator_text.set("Train and val collections are selected", status="success")
|
|
328
|
+
return True
|
|
329
|
+
|
|
241
330
|
if split_method == "Random":
|
|
242
331
|
is_valid = validate_random_split()
|
|
243
332
|
|
|
@@ -246,6 +335,8 @@ class TrainValSplitsSelector:
|
|
|
246
335
|
|
|
247
336
|
elif split_method == "Based on datasets":
|
|
248
337
|
is_valid = validate_based_on_datasets()
|
|
338
|
+
elif split_method == "Based on collections":
|
|
339
|
+
is_valid = validate_based_on_collections()
|
|
249
340
|
|
|
250
341
|
# @TODO: handle button correctly if validation fails. Do not unlock next card until validation passes if returned False
|
|
251
342
|
self.validator_text.show()
|
|
@@ -268,3 +359,15 @@ class TrainValSplitsSelector:
|
|
|
268
359
|
|
|
269
360
|
def set_val_dataset_ids(self, dataset_ids: List[int]) -> None:
|
|
270
361
|
self.train_val_splits._val_ds_select.set_selected_ids(dataset_ids)
|
|
362
|
+
|
|
363
|
+
def get_train_collection_ids(self) -> List[int]:
|
|
364
|
+
return self.train_val_splits._train_collections_select.get_selected_ids()
|
|
365
|
+
|
|
366
|
+
def set_train_collection_ids(self, collection_ids: List[int]) -> None:
|
|
367
|
+
self.train_val_splits._train_collections_select.set_selected_ids(collection_ids)
|
|
368
|
+
|
|
369
|
+
def get_val_collection_ids(self) -> List[int]:
|
|
370
|
+
return self.train_val_splits._val_collections_select.get_selected_ids()
|
|
371
|
+
|
|
372
|
+
def set_val_collection_ids(self, collection_ids: List[int]) -> None:
|
|
373
|
+
self.train_val_splits._val_collections_select.set_selected_ids(collection_ids)
|
|
@@ -36,6 +36,7 @@ class TrainingLogs:
|
|
|
36
36
|
self.display_widgets = []
|
|
37
37
|
self.app_options = app_options
|
|
38
38
|
api = Api.from_env()
|
|
39
|
+
team_id = sly_env.team_id()
|
|
39
40
|
|
|
40
41
|
# GUI Components
|
|
41
42
|
self.validator_text = Text("")
|
|
@@ -73,6 +74,7 @@ class TrainingLogs:
|
|
|
73
74
|
module_info = gui_utils.get_module_info_by_name(api, app_name)
|
|
74
75
|
if module_info is not None:
|
|
75
76
|
self.tensorboard_offline_button = RunAppButton(
|
|
77
|
+
team_id=team_id,
|
|
76
78
|
workspace_id=workspace_id,
|
|
77
79
|
module_id=module_info["id"],
|
|
78
80
|
payload={},
|
|
@@ -81,9 +83,10 @@ class TrainingLogs:
|
|
|
81
83
|
plain=True,
|
|
82
84
|
icon="zmdi zmdi-chart",
|
|
83
85
|
available_in_offline=True,
|
|
84
|
-
visible_by_vue_field=
|
|
86
|
+
visible_by_vue_field=None,
|
|
85
87
|
)
|
|
86
88
|
self.tensorboard_offline_button.disable()
|
|
89
|
+
self.tensorboard_offline_button.hide()
|
|
87
90
|
self.display_widgets.extend([self.tensorboard_offline_button])
|
|
88
91
|
else:
|
|
89
92
|
logger.warning(
|
|
@@ -136,3 +136,26 @@ def get_module_info_by_name(api: Api, app_name: str) -> Union[Dict, None]:
|
|
|
136
136
|
if module["name"] == app_name:
|
|
137
137
|
app_info = api.app.get_info(module["id"])
|
|
138
138
|
return app_info
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def generate_task_check_function_js(folder: str) -> str:
|
|
142
|
+
"""
|
|
143
|
+
Returns JavaScript function code for checking existing tasks.
|
|
144
|
+
|
|
145
|
+
:param folder: Remote folder to check.
|
|
146
|
+
:type folder: str
|
|
147
|
+
:return: JavaScript function code for checking existing tasks.
|
|
148
|
+
:rtype: str
|
|
149
|
+
"""
|
|
150
|
+
escaped_folder = folder.replace("'", "\\'")
|
|
151
|
+
js_code = f"""
|
|
152
|
+
if (!task || !task.meta || !task.meta.params || !task.meta.params.state) {{
|
|
153
|
+
return false;
|
|
154
|
+
}}
|
|
155
|
+
const taskFolder = task.meta.params.state.slyFolder;
|
|
156
|
+
if (!taskFolder || typeof taskFolder !== 'string') {{
|
|
157
|
+
return false;
|
|
158
|
+
}}
|
|
159
|
+
return taskFolder === '{escaped_folder}';
|
|
160
|
+
"""
|
|
161
|
+
return js_code
|
|
@@ -10,6 +10,7 @@ import subprocess
|
|
|
10
10
|
from datetime import datetime
|
|
11
11
|
from os import getcwd, listdir, walk
|
|
12
12
|
from os.path import basename, dirname, exists, expanduser, isdir, isfile, join
|
|
13
|
+
from time import sleep
|
|
13
14
|
from typing import Any, Dict, List, Literal, Optional, Union
|
|
14
15
|
from urllib.request import urlopen
|
|
15
16
|
|
|
@@ -57,6 +58,7 @@ from supervisely.nn.inference import RuntimeType, SessionJSON
|
|
|
57
58
|
from supervisely.nn.inference.inference import Inference
|
|
58
59
|
from supervisely.nn.task_type import TaskType
|
|
59
60
|
from supervisely.nn.training.gui.gui import TrainGUI
|
|
61
|
+
from supervisely.nn.training.gui.utils import generate_task_check_function_js
|
|
60
62
|
from supervisely.nn.training.loggers import setup_train_logger, train_logger
|
|
61
63
|
from supervisely.nn.utils import ModelSource, _get_model_name
|
|
62
64
|
from supervisely.output import set_directory
|
|
@@ -206,16 +208,41 @@ class TrainApp:
|
|
|
206
208
|
def _train_from_api(response: Response, request: Request):
|
|
207
209
|
try:
|
|
208
210
|
state = request.state.state
|
|
211
|
+
wait = state.get("wait", True)
|
|
209
212
|
app_state = state["app_state"]
|
|
210
213
|
self.gui.load_from_app_state(app_state)
|
|
211
214
|
|
|
212
|
-
|
|
215
|
+
if wait:
|
|
216
|
+
self._wrapped_start_training()
|
|
217
|
+
else:
|
|
218
|
+
import threading
|
|
219
|
+
|
|
220
|
+
training_thread = threading.Thread(
|
|
221
|
+
target=self._wrapped_start_training,
|
|
222
|
+
daemon=True,
|
|
223
|
+
)
|
|
224
|
+
training_thread.start()
|
|
225
|
+
return {"result": "model training started"}
|
|
213
226
|
|
|
214
227
|
return {"result": "model was successfully trained"}
|
|
215
228
|
except Exception as e:
|
|
216
229
|
self.gui.training_process.start_button.loading = False
|
|
217
230
|
raise e
|
|
218
231
|
|
|
232
|
+
# # Get training status
|
|
233
|
+
# @self._server.post("/train_status")
|
|
234
|
+
# def _train_status(response: Response, request: Request):
|
|
235
|
+
# """Returns the current training status."""
|
|
236
|
+
# status = self.gui.training_process.validator_text.get_value()
|
|
237
|
+
# if status == "Training is in progress...":
|
|
238
|
+
# try:
|
|
239
|
+
# total_epochs = self.progress_bar_main.total
|
|
240
|
+
# current_epoch = self.progress_bar_main.current
|
|
241
|
+
# status += f" (Epoch {current_epoch}/{total_epochs})"
|
|
242
|
+
# except Exception:
|
|
243
|
+
# pass
|
|
244
|
+
# return {"status": status}
|
|
245
|
+
|
|
219
246
|
def _register_routes(self):
|
|
220
247
|
"""
|
|
221
248
|
Registers API routes for TensorBoard and training endpoints.
|
|
@@ -1868,6 +1895,13 @@ class TrainApp:
|
|
|
1868
1895
|
"val_datasets": self.gui.train_val_splits_selector.train_val_splits.get_val_dataset_ids(),
|
|
1869
1896
|
}
|
|
1870
1897
|
)
|
|
1898
|
+
elif split_method == "Based on collections":
|
|
1899
|
+
train_val_splits.update(
|
|
1900
|
+
{
|
|
1901
|
+
"train_collections": self.gui.train_val_splits_selector.get_train_collection_ids(),
|
|
1902
|
+
"val_collections": self.gui.train_val_splits_selector.get_val_collection_ids(),
|
|
1903
|
+
}
|
|
1904
|
+
)
|
|
1871
1905
|
return train_val_splits
|
|
1872
1906
|
|
|
1873
1907
|
def _get_model_config_for_app_state(self, experiment_info: Dict = None) -> Dict:
|
|
@@ -1949,11 +1983,15 @@ class TrainApp:
|
|
|
1949
1983
|
self.progress_bar_main.hide()
|
|
1950
1984
|
|
|
1951
1985
|
file_info = self._api.file.get_info_by_path(self.team_id, join(remote_dir, "open_app.lnk"))
|
|
1986
|
+
|
|
1952
1987
|
# Set offline tensorboard button payload
|
|
1953
1988
|
if is_production():
|
|
1954
|
-
|
|
1955
|
-
|
|
1956
|
-
|
|
1989
|
+
remote_log_dir = join(remote_dir, "logs")
|
|
1990
|
+
tb_btn_payload = {"state": {"slyFolder": remote_log_dir}}
|
|
1991
|
+
self.gui.training_logs.tensorboard_offline_button.payload = tb_btn_payload
|
|
1992
|
+
self.gui.training_logs.tensorboard_offline_button.set_check_existing_task_cb(
|
|
1993
|
+
generate_task_check_function_js(remote_log_dir)
|
|
1994
|
+
)
|
|
1957
1995
|
self.gui.training_logs.tensorboard_offline_button.enable()
|
|
1958
1996
|
|
|
1959
1997
|
return remote_dir, file_info
|
|
@@ -2104,7 +2142,7 @@ class TrainApp:
|
|
|
2104
2142
|
]
|
|
2105
2143
|
task_type = experiment_info["task_type"]
|
|
2106
2144
|
if task_type not in supported_task_types:
|
|
2107
|
-
logger.
|
|
2145
|
+
logger.warning(
|
|
2108
2146
|
f"Task type: '{task_type}' is not supported for Model Benchmark. "
|
|
2109
2147
|
f"Supported tasks: {', '.join(task_type)}"
|
|
2110
2148
|
)
|
|
@@ -2608,6 +2646,10 @@ class TrainApp:
|
|
|
2608
2646
|
self.gui.training_process.start_button.loading = False
|
|
2609
2647
|
|
|
2610
2648
|
# Shutdown the app after training is finished
|
|
2649
|
+
|
|
2650
|
+
self.gui.training_logs.tensorboard_button.hide()
|
|
2651
|
+
self.gui.training_logs.tensorboard_offline_button.show()
|
|
2652
|
+
sleep(1) # wait for the button to be shown
|
|
2611
2653
|
self.app.shutdown()
|
|
2612
2654
|
except Exception as e:
|
|
2613
2655
|
message = f"Error occurred during finalizing and uploading training artifacts. {check_logs_text}"
|