supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +136 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/json_geometries_map.py +2 -0
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +9 -9
- supervisely/api/api.py +67 -43
- supervisely/api/app_api.py +72 -5
- supervisely/api/dataset_api.py +108 -33
- supervisely/api/entity_annotation/figure_api.py +113 -49
- supervisely/api/image_api.py +82 -0
- supervisely/api/module_api.py +10 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/pointcloud/pointcloud_api.py +38 -0
- supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
- supervisely/api/project_api.py +213 -6
- supervisely/api/task_api.py +11 -1
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +79 -1
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/api/volume/volume_api.py +38 -0
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +14 -6
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +88 -0
- supervisely/app/fastapi/subapp.py +175 -42
- supervisely/app/fastapi/templating.py +1 -1
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +11 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
- supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
- supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
- supervisely/app/widgets/fast_table/fast_table.py +713 -126
- supervisely/app/widgets/fast_table/script.js +492 -95
- supervisely/app/widgets/fast_table/style.css +54 -0
- supervisely/app/widgets/fast_table/template.html +45 -5
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +523 -0
- supervisely/app/widgets/heatmap/script.js +378 -0
- supervisely/app/widgets/heatmap/style.css +227 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/input_tag/input_tag.py +102 -15
- supervisely/app/widgets/input_tag_list/__init__.py +0 -0
- supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
- supervisely/app/widgets/input_tag_list/template.html +70 -0
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/app/widgets/transfer/style.css +3 -0
- supervisely/app/widgets/transfer/template.html +3 -1
- supervisely/app/widgets/transfer/transfer.py +48 -45
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
- supervisely/convert/video/video_converter.py +2 -2
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/io/env.py +161 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/cache.py +37 -17
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +953 -211
- supervisely/nn/inference/inference_request.py +15 -8
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
- supervisely/nn/inference/predict_app/gui/gui.py +915 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +399 -0
- supervisely/nn/inference/predict_app/predict_app.py +176 -0
- supervisely/nn/inference/session.py +47 -39
- supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +4 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/model_api.py +44 -22
- supervisely/nn/model/prediction.py +15 -1
- supervisely/nn/model/prediction_session.py +70 -14
- supervisely/nn/prediction_dto.py +7 -0
- supervisely/nn/tracker/__init__.py +6 -8
- supervisely/nn/tracker/base_tracker.py +54 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
- supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +273 -0
- supervisely/nn/tracker/calculate_metrics.py +264 -0
- supervisely/nn/tracker/utils.py +273 -0
- supervisely/nn/tracker/visualize.py +520 -0
- supervisely/nn/training/gui/gui.py +152 -49
- supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
- supervisely/nn/training/gui/training_artifacts.py +3 -1
- supervisely/nn/training/train_app.py +225 -46
- supervisely/project/pointcloud_episode_project.py +12 -8
- supervisely/project/pointcloud_project.py +12 -8
- supervisely/project/project.py +221 -75
- supervisely/template/experiment/experiment.html.jinja +105 -55
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- supervisely/versions.json +3 -1
- supervisely/video/sampling.py +42 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
- supervisely_lib/__init__.py +6 -1
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
|
@@ -3,6 +3,7 @@ from typing import List
|
|
|
3
3
|
from supervisely import Api, Project
|
|
4
4
|
from supervisely.app.widgets import Button, Card, Container, Text, TrainValSplits
|
|
5
5
|
from supervisely.api.module_api import ApiField
|
|
6
|
+
from supervisely.api.entities_collection_api import EntitiesCollectionInfo
|
|
6
7
|
|
|
7
8
|
class TrainValSplitsSelector:
|
|
8
9
|
title = "Train / Val Splits"
|
|
@@ -18,6 +19,13 @@ class TrainValSplitsSelector:
|
|
|
18
19
|
self.card = None
|
|
19
20
|
# -------------------------------- #
|
|
20
21
|
|
|
22
|
+
# Automated Splits
|
|
23
|
+
self._all_train_collections = []
|
|
24
|
+
self._all_val_collections = []
|
|
25
|
+
self._latest_train_collection = None
|
|
26
|
+
self._latest_val_collection = None
|
|
27
|
+
# -------------------------------- #
|
|
28
|
+
|
|
21
29
|
self.display_widgets = []
|
|
22
30
|
self.app_options = app_options
|
|
23
31
|
self.api = api
|
|
@@ -32,75 +40,9 @@ class TrainValSplitsSelector:
|
|
|
32
40
|
ds_split = "Based on datasets" in split_methods
|
|
33
41
|
coll_split = "Based on collections" in split_methods
|
|
34
42
|
|
|
35
|
-
self.train_val_splits = TrainValSplits(
|
|
36
|
-
project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
# check for collections with "train" and "val" prefixes
|
|
40
|
-
all_collections = self.api.entities_collection.get_list(self.project_id)
|
|
41
|
-
train_collections = []
|
|
42
|
-
val_collections = []
|
|
43
|
-
collections_found = False
|
|
44
|
-
for collection in all_collections:
|
|
45
|
-
if collection.name.lower().startswith("train"):
|
|
46
|
-
train_collections.append(collection.id)
|
|
47
|
-
elif collection.name.lower().startswith("val"):
|
|
48
|
-
val_collections.append(collection.id)
|
|
49
|
-
|
|
50
|
-
if len(train_collections) > 0 and len(val_collections) > 0:
|
|
51
|
-
self.train_val_splits.set_collections_splits(train_collections, val_collections)
|
|
52
|
-
self.validator_text = Text(
|
|
53
|
-
"Train and val collections are detected", status="info"
|
|
54
|
-
)
|
|
55
|
-
self.validator_text.show()
|
|
56
|
-
collections_found = True
|
|
57
|
-
else:
|
|
58
|
-
self.validator_text = Text("")
|
|
59
|
-
self.validator_text.hide()
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
def _extend_with_nested(root_ds):
|
|
63
|
-
nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
|
|
64
|
-
nested_ids = [ds.id for ds in nested]
|
|
65
|
-
return [root_ds.id] + nested_ids
|
|
66
|
-
|
|
67
|
-
if not collections_found:
|
|
68
|
-
train_val_dataset_ids = {"train": set(), "val": set()}
|
|
69
|
-
for _, dataset in self.api.dataset.tree(self.project_id):
|
|
70
|
-
ds_name = dataset.name.lower()
|
|
71
|
-
|
|
72
|
-
if ds_name in {"train", "training"}:
|
|
73
|
-
for _id in _extend_with_nested(dataset):
|
|
74
|
-
train_val_dataset_ids["train"].add(_id)
|
|
75
|
-
|
|
76
|
-
elif ds_name in {"val", "validation", "test", "testing"}:
|
|
77
|
-
for _id in _extend_with_nested(dataset):
|
|
78
|
-
train_val_dataset_ids["val"].add(_id)
|
|
79
|
-
|
|
80
|
-
train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
|
|
81
|
-
train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
|
|
82
|
-
|
|
83
|
-
train_count = len(train_val_dataset_ids["train"])
|
|
84
|
-
val_count = len(train_val_dataset_ids["val"])
|
|
85
|
-
|
|
86
|
-
if train_count > 0 and val_count > 0:
|
|
87
|
-
self.train_val_splits.set_datasets_splits(
|
|
88
|
-
train_val_dataset_ids["train"], train_val_dataset_ids["val"]
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
if train_count > 0 and val_count > 0:
|
|
92
|
-
if train_count == val_count == 1:
|
|
93
|
-
self.validator_text = Text("train and val datasets are detected", status="info")
|
|
94
|
-
else:
|
|
95
|
-
self.validator_text = Text(
|
|
96
|
-
"Multiple train and val datasets are detected. Check manually if selection is correct",
|
|
97
|
-
status="info",
|
|
98
|
-
)
|
|
99
|
-
self.validator_text.show()
|
|
100
|
-
else:
|
|
101
|
-
self.validator_text = Text("")
|
|
102
|
-
self.validator_text.hide()
|
|
43
|
+
self.train_val_splits = TrainValSplits(project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split)
|
|
103
44
|
|
|
45
|
+
self._detect_splits(coll_split, ds_split)
|
|
104
46
|
self.button = Button("Select")
|
|
105
47
|
self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
|
|
106
48
|
# -------------------------------- #
|
|
@@ -115,6 +57,22 @@ class TrainValSplitsSelector:
|
|
|
115
57
|
)
|
|
116
58
|
self.card.lock()
|
|
117
59
|
|
|
60
|
+
@property
|
|
61
|
+
def all_train_collections(self) -> List[EntitiesCollectionInfo]:
|
|
62
|
+
return self._all_train_collections
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def all_val_collections(self) -> List[EntitiesCollectionInfo]:
|
|
66
|
+
return self._all_val_collections
|
|
67
|
+
|
|
68
|
+
@property
|
|
69
|
+
def latest_train_collection(self) -> EntitiesCollectionInfo:
|
|
70
|
+
return self._latest_train_collection
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def latest_val_collection(self) -> EntitiesCollectionInfo:
|
|
74
|
+
return self._latest_val_collection
|
|
75
|
+
|
|
118
76
|
@property
|
|
119
77
|
def widgets_to_disable(self) -> list:
|
|
120
78
|
return [self.train_val_splits]
|
|
@@ -222,7 +180,13 @@ class TrainValSplitsSelector:
|
|
|
222
180
|
return False
|
|
223
181
|
|
|
224
182
|
# Check if datasets are not empty
|
|
225
|
-
filters = [
|
|
183
|
+
filters = [
|
|
184
|
+
{
|
|
185
|
+
ApiField.FIELD: ApiField.ID,
|
|
186
|
+
ApiField.OPERATOR: "in",
|
|
187
|
+
ApiField.VALUE: train_dataset_id + val_dataset_id,
|
|
188
|
+
}
|
|
189
|
+
]
|
|
226
190
|
selected_datasets = self.api.dataset.get_list(self.project_id, filters, recursive=True)
|
|
227
191
|
datasets_count = {}
|
|
228
192
|
for dataset in selected_datasets:
|
|
@@ -313,10 +277,11 @@ class TrainValSplitsSelector:
|
|
|
313
277
|
return False
|
|
314
278
|
if len(empty_train_collections) > 0 or len(empty_val_collections) > 0:
|
|
315
279
|
empty_collections_text = "Selected collections are empty. "
|
|
280
|
+
# @TODO: Use collection names instead of ids
|
|
316
281
|
if len(empty_train_collections) > 0:
|
|
317
|
-
empty_collections_text += f"train: {', '.join(empty_train_collections)}. "
|
|
282
|
+
empty_collections_text += f"train: {', '.join([str(collection_id) for collection_id in empty_train_collections])}. "
|
|
318
283
|
if len(empty_val_collections) > 0:
|
|
319
|
-
empty_collections_text += f"val: {', '.join(empty_val_collections)}. "
|
|
284
|
+
empty_collections_text += f"val: {', '.join([str(collection_id) for collection_id in empty_val_collections])}. "
|
|
320
285
|
empty_collections_text += f"{ensure_text}"
|
|
321
286
|
self.validator_text.set(
|
|
322
287
|
text=empty_collections_text,
|
|
@@ -372,3 +337,111 @@ class TrainValSplitsSelector:
|
|
|
372
337
|
|
|
373
338
|
def set_val_collection_ids(self, collection_ids: List[int]) -> None:
|
|
374
339
|
self.train_val_splits._val_collections_select.set_selected_ids(collection_ids)
|
|
340
|
+
|
|
341
|
+
def _detect_splits(self, collections_split: bool, datasets_split: bool) -> bool:
|
|
342
|
+
"""Detect splits based on the selected method"""
|
|
343
|
+
self._parse_collections()
|
|
344
|
+
splits_found = False
|
|
345
|
+
if collections_split:
|
|
346
|
+
splits_found = self._detect_collections()
|
|
347
|
+
if not splits_found and datasets_split:
|
|
348
|
+
splits_found = self._detect_datasets()
|
|
349
|
+
return splits_found
|
|
350
|
+
|
|
351
|
+
def _parse_collections(self) -> None:
|
|
352
|
+
"""Parse collections with train and val prefixes and set them to train_val_splits variables"""
|
|
353
|
+
all_collections = self.api.entities_collection.get_list(self.project_id)
|
|
354
|
+
existing_train_collections = [
|
|
355
|
+
collection for collection in all_collections if collection.name.startswith("train_")
|
|
356
|
+
]
|
|
357
|
+
existing_val_collections = [
|
|
358
|
+
collection for collection in all_collections if collection.name.startswith("val_")
|
|
359
|
+
]
|
|
360
|
+
|
|
361
|
+
self._all_train_collections = existing_train_collections
|
|
362
|
+
self._all_val_collections = existing_val_collections
|
|
363
|
+
self._latest_train_collection = self._get_latest_collection(existing_train_collections, "train")
|
|
364
|
+
self._latest_val_collection = self._get_latest_collection(existing_val_collections, "val")
|
|
365
|
+
|
|
366
|
+
def _get_latest_collection(
|
|
367
|
+
self, collections: List[EntitiesCollectionInfo], expected_prefix: str
|
|
368
|
+
) -> EntitiesCollectionInfo:
|
|
369
|
+
curr_collection = None
|
|
370
|
+
curr_idx = 0
|
|
371
|
+
for collection in collections:
|
|
372
|
+
parts = collection.name.split("_")
|
|
373
|
+
if len(parts) == 2:
|
|
374
|
+
prefix = parts[0].lower()
|
|
375
|
+
if prefix == expected_prefix:
|
|
376
|
+
if parts[1].isdigit():
|
|
377
|
+
collection_idx = int(parts[1])
|
|
378
|
+
if collection_idx > curr_idx:
|
|
379
|
+
curr_idx = collection_idx
|
|
380
|
+
curr_collection = collection
|
|
381
|
+
return curr_collection
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _detect_collections(self) -> bool:
|
|
385
|
+
"""Find collections with train and val prefixes and set them to train_val_splits"""
|
|
386
|
+
|
|
387
|
+
collections_found = False
|
|
388
|
+
if self._latest_train_collection is not None and self._latest_val_collection is not None:
|
|
389
|
+
self.train_val_splits.set_collections_splits(
|
|
390
|
+
[self._latest_train_collection.id], [self._latest_val_collection.id]
|
|
391
|
+
)
|
|
392
|
+
self.validator_text = Text("Train and val collections are detected", status="info")
|
|
393
|
+
self.validator_text.show()
|
|
394
|
+
collections_found = True
|
|
395
|
+
else:
|
|
396
|
+
self.validator_text = Text("")
|
|
397
|
+
self.validator_text.hide()
|
|
398
|
+
collections_found = False
|
|
399
|
+
return collections_found
|
|
400
|
+
|
|
401
|
+
def _detect_datasets(self) -> bool:
|
|
402
|
+
"""Find datasets with train and val prefixes and set them to train_val_splits"""
|
|
403
|
+
|
|
404
|
+
def _extend_with_nested(root_ds):
|
|
405
|
+
nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
|
|
406
|
+
nested_ids = [ds.id for ds in nested]
|
|
407
|
+
return [root_ds.id] + nested_ids
|
|
408
|
+
|
|
409
|
+
datasets_found = False
|
|
410
|
+
train_val_dataset_ids = {"train": set(), "val": set()}
|
|
411
|
+
for _, dataset in self.api.dataset.tree(self.project_id):
|
|
412
|
+
ds_name = dataset.name.lower()
|
|
413
|
+
|
|
414
|
+
if ds_name in {"train", "training"}:
|
|
415
|
+
for _id in _extend_with_nested(dataset):
|
|
416
|
+
train_val_dataset_ids["train"].add(_id)
|
|
417
|
+
|
|
418
|
+
elif ds_name in {"val", "validation", "test", "testing"}:
|
|
419
|
+
for _id in _extend_with_nested(dataset):
|
|
420
|
+
train_val_dataset_ids["val"].add(_id)
|
|
421
|
+
|
|
422
|
+
train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
|
|
423
|
+
train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
|
|
424
|
+
|
|
425
|
+
train_count = len(train_val_dataset_ids["train"])
|
|
426
|
+
val_count = len(train_val_dataset_ids["val"])
|
|
427
|
+
|
|
428
|
+
if train_count > 0 and val_count > 0:
|
|
429
|
+
self.train_val_splits.set_datasets_splits(
|
|
430
|
+
train_val_dataset_ids["train"], train_val_dataset_ids["val"]
|
|
431
|
+
)
|
|
432
|
+
datasets_found = True
|
|
433
|
+
|
|
434
|
+
if train_count > 0 and val_count > 0:
|
|
435
|
+
if train_count == val_count == 1:
|
|
436
|
+
message = "train and val datasets are detected"
|
|
437
|
+
else:
|
|
438
|
+
message = "Multiple train and val datasets are detected. Check manually if selection is correct"
|
|
439
|
+
|
|
440
|
+
self.validator_text = Text(message, status="info")
|
|
441
|
+
self.validator_text.show()
|
|
442
|
+
datasets_found = True
|
|
443
|
+
else:
|
|
444
|
+
self.validator_text = Text("")
|
|
445
|
+
self.validator_text.hide()
|
|
446
|
+
datasets_found = False
|
|
447
|
+
return datasets_found
|
|
@@ -64,7 +64,9 @@ class TrainingArtifacts:
|
|
|
64
64
|
self.display_widgets.extend([self.validator_text])
|
|
65
65
|
|
|
66
66
|
# Outputs
|
|
67
|
-
need_generate_report = self.app_options.get("generate_report",
|
|
67
|
+
need_generate_report = self.app_options.get("generate_report", False)
|
|
68
|
+
# ------------------------------------------------------------ #
|
|
69
|
+
|
|
68
70
|
if need_generate_report:
|
|
69
71
|
self.artifacts_thumbnail = ReportThumbnail(
|
|
70
72
|
title="Experiment Report",
|