supervisely 6.73.420__py3-none-any.whl → 6.73.422__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/api.py +10 -5
- supervisely/api/app_api.py +71 -4
- supervisely/api/module_api.py +4 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/project_api.py +35 -6
- supervisely/api/task_api.py +5 -1
- supervisely/app/widgets/__init__.py +8 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/deploy_model/__init__.py +0 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
- supervisely/app/widgets/fast_table/fast_table.py +402 -74
- supervisely/app/widgets/fast_table/script.js +364 -96
- supervisely/app/widgets/fast_table/style.css +24 -0
- supervisely/app/widgets/fast_table/template.html +43 -3
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +160 -94
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
- supervisely/nn/inference/predict_app/gui/gui.py +710 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +282 -0
- supervisely/nn/inference/predict_app/predict_app.py +184 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/prediction.py +2 -0
- supervisely/nn/model/prediction_session.py +20 -3
- supervisely/nn/training/gui/gui.py +131 -44
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
- supervisely/nn/training/gui/training_artifacts.py +0 -5
- supervisely/nn/training/train_app.py +161 -44
- supervisely/template/experiment/experiment.html.jinja +74 -17
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/METADATA +3 -1
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/RECORD +74 -56
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/LICENSE +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/WHEEL +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.422.dist-info}/top_level.txt +0 -0
|
@@ -3,6 +3,7 @@ from typing import List
|
|
|
3
3
|
from supervisely import Api, Project
|
|
4
4
|
from supervisely.app.widgets import Button, Card, Container, Text, TrainValSplits
|
|
5
5
|
from supervisely.api.module_api import ApiField
|
|
6
|
+
from supervisely.api.entities_collection_api import EntitiesCollectionInfo
|
|
6
7
|
|
|
7
8
|
class TrainValSplitsSelector:
|
|
8
9
|
title = "Train / Val Splits"
|
|
@@ -18,6 +19,13 @@ class TrainValSplitsSelector:
|
|
|
18
19
|
self.card = None
|
|
19
20
|
# -------------------------------- #
|
|
20
21
|
|
|
22
|
+
# Automated Splits
|
|
23
|
+
self._all_train_collections = []
|
|
24
|
+
self._all_val_collections = []
|
|
25
|
+
self._latest_train_collection = None
|
|
26
|
+
self._latest_val_collection = None
|
|
27
|
+
# -------------------------------- #
|
|
28
|
+
|
|
21
29
|
self.display_widgets = []
|
|
22
30
|
self.app_options = app_options
|
|
23
31
|
self.api = api
|
|
@@ -32,75 +40,9 @@ class TrainValSplitsSelector:
|
|
|
32
40
|
ds_split = "Based on datasets" in split_methods
|
|
33
41
|
coll_split = "Based on collections" in split_methods
|
|
34
42
|
|
|
35
|
-
self.train_val_splits = TrainValSplits(
|
|
36
|
-
project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
# check for collections with "train" and "val" prefixes
|
|
40
|
-
all_collections = self.api.entities_collection.get_list(self.project_id)
|
|
41
|
-
train_collections = []
|
|
42
|
-
val_collections = []
|
|
43
|
-
collections_found = False
|
|
44
|
-
for collection in all_collections:
|
|
45
|
-
if collection.name.lower().startswith("train"):
|
|
46
|
-
train_collections.append(collection.id)
|
|
47
|
-
elif collection.name.lower().startswith("val"):
|
|
48
|
-
val_collections.append(collection.id)
|
|
49
|
-
|
|
50
|
-
if len(train_collections) > 0 and len(val_collections) > 0:
|
|
51
|
-
self.train_val_splits.set_collections_splits(train_collections, val_collections)
|
|
52
|
-
self.validator_text = Text(
|
|
53
|
-
"Train and val collections are detected", status="info"
|
|
54
|
-
)
|
|
55
|
-
self.validator_text.show()
|
|
56
|
-
collections_found = True
|
|
57
|
-
else:
|
|
58
|
-
self.validator_text = Text("")
|
|
59
|
-
self.validator_text.hide()
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
def _extend_with_nested(root_ds):
|
|
63
|
-
nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
|
|
64
|
-
nested_ids = [ds.id for ds in nested]
|
|
65
|
-
return [root_ds.id] + nested_ids
|
|
66
|
-
|
|
67
|
-
if not collections_found:
|
|
68
|
-
train_val_dataset_ids = {"train": set(), "val": set()}
|
|
69
|
-
for _, dataset in self.api.dataset.tree(self.project_id):
|
|
70
|
-
ds_name = dataset.name.lower()
|
|
71
|
-
|
|
72
|
-
if ds_name in {"train", "training"}:
|
|
73
|
-
for _id in _extend_with_nested(dataset):
|
|
74
|
-
train_val_dataset_ids["train"].add(_id)
|
|
75
|
-
|
|
76
|
-
elif ds_name in {"val", "validation", "test", "testing"}:
|
|
77
|
-
for _id in _extend_with_nested(dataset):
|
|
78
|
-
train_val_dataset_ids["val"].add(_id)
|
|
79
|
-
|
|
80
|
-
train_val_dataset_ids["train"] = list(train_val_dataset_ids["train"])
|
|
81
|
-
train_val_dataset_ids["val"] = list(train_val_dataset_ids["val"])
|
|
82
|
-
|
|
83
|
-
train_count = len(train_val_dataset_ids["train"])
|
|
84
|
-
val_count = len(train_val_dataset_ids["val"])
|
|
85
|
-
|
|
86
|
-
if train_count > 0 and val_count > 0:
|
|
87
|
-
self.train_val_splits.set_datasets_splits(
|
|
88
|
-
train_val_dataset_ids["train"], train_val_dataset_ids["val"]
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
if train_count > 0 and val_count > 0:
|
|
92
|
-
if train_count == val_count == 1:
|
|
93
|
-
self.validator_text = Text("train and val datasets are detected", status="info")
|
|
94
|
-
else:
|
|
95
|
-
self.validator_text = Text(
|
|
96
|
-
"Multiple train and val datasets are detected. Check manually if selection is correct",
|
|
97
|
-
status="info",
|
|
98
|
-
)
|
|
99
|
-
self.validator_text.show()
|
|
100
|
-
else:
|
|
101
|
-
self.validator_text = Text("")
|
|
102
|
-
self.validator_text.hide()
|
|
43
|
+
self.train_val_splits = TrainValSplits(project_id, None, random_split, tag_split, ds_split, collections_splits=coll_split)
|
|
103
44
|
|
|
45
|
+
self._detect_splits(coll_split, ds_split)
|
|
104
46
|
self.button = Button("Select")
|
|
105
47
|
self.display_widgets.extend([self.train_val_splits, self.validator_text, self.button])
|
|
106
48
|
# -------------------------------- #
|
|
@@ -115,6 +57,22 @@ class TrainValSplitsSelector:
|
|
|
115
57
|
)
|
|
116
58
|
self.card.lock()
|
|
117
59
|
|
|
60
|
+
@property
|
|
61
|
+
def all_train_collections(self) -> List[EntitiesCollectionInfo]:
|
|
62
|
+
return self._all_train_collections
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def all_val_collections(self) -> List[EntitiesCollectionInfo]:
|
|
66
|
+
return self._all_val_collections
|
|
67
|
+
|
|
68
|
+
@property
|
|
69
|
+
def latest_train_collection(self) -> EntitiesCollectionInfo:
|
|
70
|
+
return self._latest_train_collection
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def latest_val_collection(self) -> EntitiesCollectionInfo:
|
|
74
|
+
return self._latest_val_collection
|
|
75
|
+
|
|
118
76
|
@property
|
|
119
77
|
def widgets_to_disable(self) -> list:
|
|
120
78
|
return [self.train_val_splits]
|
|
@@ -313,10 +271,11 @@ class TrainValSplitsSelector:
|
|
|
313
271
|
return False
|
|
314
272
|
if len(empty_train_collections) > 0 or len(empty_val_collections) > 0:
|
|
315
273
|
empty_collections_text = "Selected collections are empty. "
|
|
274
|
+
# @TODO: Use collection names instead of ids
|
|
316
275
|
if len(empty_train_collections) > 0:
|
|
317
|
-
empty_collections_text += f"train: {', '.join(empty_train_collections)}. "
|
|
276
|
+
empty_collections_text += f"train: {', '.join([str(collection_id) for collection_id in empty_train_collections])}. "
|
|
318
277
|
if len(empty_val_collections) > 0:
|
|
319
|
-
empty_collections_text += f"val: {', '.join(empty_val_collections)}. "
|
|
278
|
+
empty_collections_text += f"val: {', '.join([str(collection_id) for collection_id in empty_val_collections])}. "
|
|
320
279
|
empty_collections_text += f"{ensure_text}"
|
|
321
280
|
self.validator_text.set(
|
|
322
281
|
text=empty_collections_text,
|
|
@@ -372,3 +331,96 @@ class TrainValSplitsSelector:
|
|
|
372
331
|
|
|
373
332
|
def set_val_collection_ids(self, collection_ids: List[int]) -> None:
|
|
374
333
|
self.train_val_splits._val_collections_select.set_selected_ids(collection_ids)
|
|
334
|
+
|
|
335
|
+
def _detect_splits(self, collections_split: bool, datasets_split: bool) -> bool:
|
|
336
|
+
"""Detect splits based on the selected method"""
|
|
337
|
+
splits_found = False
|
|
338
|
+
if collections_split:
|
|
339
|
+
splits_found = self._detect_collections()
|
|
340
|
+
if not splits_found and datasets_split:
|
|
341
|
+
splits_found = self._detect_datasets()
|
|
342
|
+
return splits_found
|
|
343
|
+
|
|
344
|
+
def _detect_collections(self) -> bool:
    """Pre-select the latest ``train_<N>`` / ``val_<N>`` collection pair.

    Collections of the project are grouped by a case-insensitive ``train_`` or
    ``val_`` name prefix. Within each group, the collection whose remaining
    name is a non-negative integer ``N`` with the largest value is taken as
    the latest one. Names without a purely numeric suffix (e.g. ``train_old``)
    are kept in the group but are never chosen as "latest".

    On success (both a train and a val collection found):
    - their ids are applied via ``train_val_splits.set_collections_splits``;
    - an info message is shown;
    - the grouped and latest collections are cached on the instance.
    Otherwise an empty, hidden ``Text`` replaces ``self.validator_text``.

    :return: True if a train/val collection pair was detected and applied.
    """

    def _get_latest_collection(
        collections: "List[EntitiesCollectionInfo]", prefix: str
    ) -> "EntitiesCollectionInfo":
        # Collection named "<prefix><N>" with the largest integer N, or None.
        latest_collection = None
        latest_idx = -1  # -1 so that "<prefix>0" is also eligible
        for collection in collections:
            suffix = collection.name[len(prefix):]
            # Skip names such as "train_old" instead of letting int() raise
            # and break construction of the whole splits selector.
            if not suffix.isdecimal():
                continue
            collection_idx = int(suffix)
            if collection_idx > latest_idx:
                latest_idx = collection_idx
                latest_collection = collection
        return latest_collection

    all_collections = self.api.entities_collection.get_list(self.project_id)
    train_collections = []
    val_collections = []
    for collection in all_collections:
        lowered_name = collection.name.lower()
        if lowered_name.startswith("train_"):
            train_collections.append(collection)
        elif lowered_name.startswith("val_"):
            val_collections.append(collection)

    train_collection = _get_latest_collection(train_collections, "train_")
    val_collection = _get_latest_collection(val_collections, "val_")

    if train_collection is None or val_collection is None:
        self.validator_text = Text("")
        self.validator_text.hide()
        return False

    self.train_val_splits.set_collections_splits([train_collection.id], [val_collection.id])
    self.validator_text = Text("Train and val collections are detected", status="info")
    self.validator_text.show()
    self._all_train_collections = train_collections
    self._all_val_collections = val_collections
    self._latest_train_collection = train_collection
    self._latest_val_collection = val_collection
    return True
|
|
382
|
+
|
|
383
|
+
def _detect_datasets(self) -> bool:
    """Pre-select train/val datasets by well-known dataset names.

    Every dataset in the project tree whose lowercased name is ``train`` or
    ``training`` contributes its own id plus the ids returned by
    ``api.dataset.get_nested`` for it to the train split. Datasets named
    ``val``, ``validation``, ``test`` or ``testing`` contribute to the val
    split in the same way. The split is applied only if both sides are
    non-empty.

    Side effects:
    - on success, calls ``train_val_splits.set_datasets_splits`` and replaces
      ``self.validator_text`` with a shown info message;
    - otherwise, replaces it with an empty hidden ``Text``.

    :return: True if both train and val datasets were detected.
    """
    train_names = {"train", "training"}
    val_names = {"val", "validation", "test", "testing"}

    def _extend_with_nested(root_ds):
        # The root dataset id followed by the ids of its nested datasets.
        nested = self.api.dataset.get_nested(self.project_id, root_ds.id)
        return [root_ds.id] + [ds.id for ds in nested]

    train_ids = set()
    val_ids = set()
    for _, dataset in self.api.dataset.tree(self.project_id):
        ds_name = dataset.name.lower()
        if ds_name in train_names:
            train_ids.update(_extend_with_nested(dataset))
        elif ds_name in val_names:
            val_ids.update(_extend_with_nested(dataset))

    # NOTE(review): if a val-named dataset sits inside a train-named one (or
    # vice versa), its id can end up in both lists; no overlap check is done.
    train_ids = list(train_ids)
    val_ids = list(val_ids)

    if not train_ids or not val_ids:
        self.validator_text = Text("")
        self.validator_text.hide()
        return False

    self.train_val_splits.set_datasets_splits(train_ids, val_ids)
    if len(train_ids) == len(val_ids) == 1:
        message = "train and val datasets are detected"
    else:
        message = "Multiple train and val datasets are detected. Check manually if selection is correct"
    self.validator_text = Text(message, status="info")
    self.validator_text.show()
    return True
|
|
@@ -65,11 +65,6 @@ class TrainingArtifacts:
|
|
|
65
65
|
|
|
66
66
|
# Outputs
|
|
67
67
|
need_generate_report = self.app_options.get("generate_report", False)
|
|
68
|
-
|
|
69
|
-
# @TODO: temporary code to generate report for dev only
|
|
70
|
-
is_dev = "dev.internal" in api.server_address
|
|
71
|
-
if not is_dev:
|
|
72
|
-
need_generate_report = False
|
|
73
68
|
# ------------------------------------------------------------ #
|
|
74
69
|
|
|
75
70
|
if need_generate_report:
|
|
@@ -56,7 +56,7 @@ from supervisely.nn.benchmark import (
|
|
|
56
56
|
SemanticSegmentationEvaluator,
|
|
57
57
|
)
|
|
58
58
|
from supervisely.nn.inference import RuntimeType, SessionJSON
|
|
59
|
-
from supervisely.nn.inference.inference import Inference
|
|
59
|
+
from supervisely.nn.inference.inference import Inference, torch_load_safe
|
|
60
60
|
from supervisely.nn.task_type import TaskType
|
|
61
61
|
from supervisely.nn.training.gui.gui import TrainGUI
|
|
62
62
|
from supervisely.nn.training.gui.utils import generate_task_check_function_js
|
|
@@ -72,6 +72,7 @@ from supervisely.project.download import (
|
|
|
72
72
|
is_cached,
|
|
73
73
|
)
|
|
74
74
|
from supervisely.template.experiment.experiment_generator import ExperimentGenerator
|
|
75
|
+
from supervisely.api.entities_collection_api import EntitiesCollectionInfo
|
|
75
76
|
|
|
76
77
|
|
|
77
78
|
class TrainApp:
|
|
@@ -159,8 +160,14 @@ class TrainApp:
|
|
|
159
160
|
self.sly_project = None
|
|
160
161
|
# -------------------------- #
|
|
161
162
|
|
|
162
|
-
|
|
163
|
-
self.
|
|
163
|
+
# Train Val Splits
|
|
164
|
+
self._train_split = []
|
|
165
|
+
self._train_split_item_ids = set()
|
|
166
|
+
self._train_collection_id = None
|
|
167
|
+
|
|
168
|
+
self._val_split = []
|
|
169
|
+
self._val_split_item_ids = set()
|
|
170
|
+
self._val_collection_id = None
|
|
164
171
|
# -------------------------- #
|
|
165
172
|
|
|
166
173
|
# Input
|
|
@@ -232,19 +239,33 @@ class TrainApp:
|
|
|
232
239
|
self.gui.training_process.start_button.loading = False
|
|
233
240
|
raise e
|
|
234
241
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
242
|
+
@self._server.post("/train_status")
|
|
243
|
+
def _train_status(response: Response, request: Request):
|
|
244
|
+
"""Returns the current training status."""
|
|
245
|
+
status = self.gui.training_process.validator_text.get_value()
|
|
246
|
+
if status == "Training is in progress...":
|
|
247
|
+
try:
|
|
248
|
+
total_epochs = getattr(self.progress_bar_main, "total", None)
|
|
249
|
+
current_epoch = getattr(self.progress_bar_main, "current", None)
|
|
250
|
+
if total_epochs is not None and current_epoch is not None:
|
|
251
|
+
status += f" (Epoch {current_epoch}/{total_epochs})"
|
|
252
|
+
except Exception:
|
|
253
|
+
pass
|
|
254
|
+
return {"status": status}
|
|
255
|
+
|
|
256
|
+
# Read GUI State when launched from experiment modal
|
|
257
|
+
state = self.gui._extract_state_from_env()
|
|
258
|
+
logger.debug(f"State: {state}")
|
|
259
|
+
gui_state_raw = state.get("guiState")
|
|
260
|
+
if gui_state_raw is not None:
|
|
261
|
+
logger.info("Loading GUI from state")
|
|
262
|
+
logger.debug(f"GUI State: {gui_state_raw}")
|
|
263
|
+
try:
|
|
264
|
+
self.gui.load_from_app_state(gui_state_raw)
|
|
265
|
+
logger.info("Successfully loaded GUI from state")
|
|
266
|
+
except Exception as e:
|
|
267
|
+
raise e
|
|
268
|
+
# ----------------------------------------- #
|
|
248
269
|
|
|
249
270
|
def _register_routes(self):
|
|
250
271
|
"""
|
|
@@ -399,14 +420,14 @@ class TrainApp:
|
|
|
399
420
|
:rtype: str
|
|
400
421
|
"""
|
|
401
422
|
return self.gui.training_process.get_device()
|
|
402
|
-
|
|
423
|
+
|
|
403
424
|
@property
|
|
404
425
|
def base_checkpoint(self) -> str:
|
|
405
426
|
"""
|
|
406
427
|
Returns the name of the base checkpoint.
|
|
407
428
|
"""
|
|
408
429
|
return self.gui.model_selector.get_checkpoint_name()
|
|
409
|
-
|
|
430
|
+
|
|
410
431
|
@property
|
|
411
432
|
def base_checkpoint_link(self) -> str:
|
|
412
433
|
"""
|
|
@@ -596,13 +617,18 @@ class TrainApp:
|
|
|
596
617
|
if self.gui.classes_selector is not None:
|
|
597
618
|
if self.gui.classes_selector.is_convert_class_shapes_enabled():
|
|
598
619
|
self._convert_project_to_model_task()
|
|
620
|
+
|
|
599
621
|
# Step 4. Split Project
|
|
600
622
|
self._split_project()
|
|
601
623
|
# Step 5. Remove classes except selected
|
|
602
624
|
if self.sly_project.type == ProjectType.IMAGES.value:
|
|
603
625
|
self.sly_project.remove_classes_except(self.project_dir, self.classes, True)
|
|
604
626
|
self._read_project()
|
|
605
|
-
|
|
627
|
+
|
|
628
|
+
# Step 6. Create collections
|
|
629
|
+
self._create_collection_splits()
|
|
630
|
+
|
|
631
|
+
# Step 7. Download Model files
|
|
606
632
|
self._download_model()
|
|
607
633
|
|
|
608
634
|
def _finalize(self, experiment_info: dict) -> None:
|
|
@@ -847,7 +873,7 @@ class TrainApp:
|
|
|
847
873
|
"TensorRT": True
|
|
848
874
|
},
|
|
849
875
|
},
|
|
850
|
-
"experiment_name": "
|
|
876
|
+
"experiment_name": "My Experiment",
|
|
851
877
|
}
|
|
852
878
|
"""
|
|
853
879
|
self.gui.load_from_app_state(app_state)
|
|
@@ -1162,13 +1188,15 @@ class TrainApp:
|
|
|
1162
1188
|
self._train_val_split_file = None
|
|
1163
1189
|
self._train_split = []
|
|
1164
1190
|
self._val_split = []
|
|
1191
|
+
self._train_split_item_ids = set()
|
|
1192
|
+
self._val_split_item_ids = set()
|
|
1165
1193
|
return
|
|
1166
1194
|
|
|
1167
1195
|
# Load splits
|
|
1168
1196
|
self.gui.train_val_splits_selector.set_sly_project(self.sly_project)
|
|
1169
|
-
self._train_split, self._val_split = (
|
|
1170
|
-
|
|
1171
|
-
)
|
|
1197
|
+
self._train_split, self._val_split = self.gui.train_val_splits_selector.train_val_splits.get_splits()
|
|
1198
|
+
self._train_split_ids, self._val_split_ids = [], []
|
|
1199
|
+
self._train_split_item_ids, self._val_split_item_ids = set(), set()
|
|
1172
1200
|
|
|
1173
1201
|
# Prepare paths
|
|
1174
1202
|
project_split_path = join(self.work_dir, "splits")
|
|
@@ -1177,11 +1205,13 @@ class TrainApp:
|
|
|
1177
1205
|
"split_path": join(project_split_path, "train"),
|
|
1178
1206
|
"img_dir": join(project_split_path, "train", "img"),
|
|
1179
1207
|
"ann_dir": join(project_split_path, "train", "ann"),
|
|
1208
|
+
"img_info_dir": join(project_split_path, "train", "img_info"),
|
|
1180
1209
|
},
|
|
1181
1210
|
"val": {
|
|
1182
1211
|
"split_path": join(project_split_path, "val"),
|
|
1183
1212
|
"img_dir": join(project_split_path, "val", "img"),
|
|
1184
1213
|
"ann_dir": join(project_split_path, "val", "ann"),
|
|
1214
|
+
"img_info_dir": join(project_split_path, "val", "img_info"),
|
|
1185
1215
|
},
|
|
1186
1216
|
}
|
|
1187
1217
|
|
|
@@ -1199,7 +1229,7 @@ class TrainApp:
|
|
|
1199
1229
|
}
|
|
1200
1230
|
|
|
1201
1231
|
# Utility function to move files
|
|
1202
|
-
def move_files(split, paths, img_name_format, pbar):
|
|
1232
|
+
def move_files(split, split_name, paths, img_name_format, pbar):
|
|
1203
1233
|
"""
|
|
1204
1234
|
Move files to the appropriate directories.
|
|
1205
1235
|
"""
|
|
@@ -1208,6 +1238,19 @@ class TrainApp:
|
|
|
1208
1238
|
ann_name = f"{item_name}.json"
|
|
1209
1239
|
shutil.copy(item.img_path, join(paths["img_dir"], item_name))
|
|
1210
1240
|
shutil.copy(item.ann_path, join(paths["ann_dir"], ann_name))
|
|
1241
|
+
|
|
1242
|
+
# Move img_info
|
|
1243
|
+
img_info_name = f"{sly_fs.get_file_name_with_ext(item.img_path)}.json"
|
|
1244
|
+
img_info_path = join(dirname(dirname(item.img_path)), "img_info", img_info_name)
|
|
1245
|
+
# shutil.copy(img_info_path, join(paths["img_info_dir"], ann_name))
|
|
1246
|
+
|
|
1247
|
+
# Write split ids to img_info
|
|
1248
|
+
img_info = sly_json.load_json_file(img_info_path)
|
|
1249
|
+
if split_name == "train":
|
|
1250
|
+
self._train_split_item_ids.add(img_info["id"])
|
|
1251
|
+
else:
|
|
1252
|
+
self._val_split_item_ids.add(img_info["id"])
|
|
1253
|
+
|
|
1211
1254
|
pbar.update(1)
|
|
1212
1255
|
|
|
1213
1256
|
# Main split processing
|
|
@@ -1216,12 +1259,16 @@ class TrainApp:
|
|
|
1216
1259
|
) as main_pbar:
|
|
1217
1260
|
self.progress_bar_main.show()
|
|
1218
1261
|
for dataset in ["train", "val"]:
|
|
1219
|
-
|
|
1262
|
+
split_name = dataset
|
|
1263
|
+
if split_name == "train":
|
|
1264
|
+
split = self._train_split
|
|
1265
|
+
else:
|
|
1266
|
+
split = self._val_split
|
|
1220
1267
|
with self.progress_bar_secondary(
|
|
1221
1268
|
message=f"Preparing '{dataset}'", total=len(split)
|
|
1222
1269
|
) as second_pbar:
|
|
1223
1270
|
self.progress_bar_secondary.show()
|
|
1224
|
-
move_files(split, paths[dataset], image_name_formats[dataset], second_pbar)
|
|
1271
|
+
move_files(split, split_name, paths[dataset], image_name_formats[dataset], second_pbar)
|
|
1225
1272
|
main_pbar.update(1)
|
|
1226
1273
|
self.progress_bar_secondary.hide()
|
|
1227
1274
|
self.progress_bar_main.hide()
|
|
@@ -1241,10 +1288,7 @@ class TrainApp:
|
|
|
1241
1288
|
with self.progress_bar_main(message="Processing splits", total=2) as pbar:
|
|
1242
1289
|
self.progress_bar_main.show()
|
|
1243
1290
|
for dataset in ["train", "val"]:
|
|
1244
|
-
shutil.move(
|
|
1245
|
-
paths[dataset]["split_path"],
|
|
1246
|
-
train_ds_path if dataset == "train" else val_ds_path,
|
|
1247
|
-
)
|
|
1291
|
+
shutil.move(paths[dataset]["split_path"], train_ds_path if dataset == "train" else val_ds_path)
|
|
1248
1292
|
pbar.update(1)
|
|
1249
1293
|
self.progress_bar_main.hide()
|
|
1250
1294
|
|
|
@@ -1667,7 +1711,8 @@ class TrainApp:
|
|
|
1667
1711
|
try:
|
|
1668
1712
|
# pylint: disable=import-error
|
|
1669
1713
|
import torch
|
|
1670
|
-
|
|
1714
|
+
|
|
1715
|
+
state_dict = torch_load_safe(new_checkpoint_path)
|
|
1671
1716
|
state_dict["model_info"] = {
|
|
1672
1717
|
"task_id": self.task_id,
|
|
1673
1718
|
"model_name": experiment_info["model_name"],
|
|
@@ -1679,9 +1724,7 @@ class TrainApp:
|
|
|
1679
1724
|
state_dict["model_files"] = ckpt_files
|
|
1680
1725
|
torch.save(state_dict, new_checkpoint_path)
|
|
1681
1726
|
except Exception as e:
|
|
1682
|
-
logger.warning(
|
|
1683
|
-
f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}"
|
|
1684
|
-
)
|
|
1727
|
+
logger.warning(f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}")
|
|
1685
1728
|
continue
|
|
1686
1729
|
|
|
1687
1730
|
new_checkpoint_paths.append(new_checkpoint_path)
|
|
@@ -1817,7 +1860,6 @@ class TrainApp:
|
|
|
1817
1860
|
:type export_weights: dict
|
|
1818
1861
|
"""
|
|
1819
1862
|
logger.debug("Updating experiment info")
|
|
1820
|
-
|
|
1821
1863
|
experiment_info = {
|
|
1822
1864
|
"experiment_name": self.gui.training_process.get_experiment_name(),
|
|
1823
1865
|
"framework_name": self.framework_name,
|
|
@@ -1826,6 +1868,7 @@ class TrainApp:
|
|
|
1826
1868
|
"base_checkpoint_link": self.base_checkpoint_link,
|
|
1827
1869
|
"task_type": experiment_info["task_type"],
|
|
1828
1870
|
"project_id": self.project_info.id,
|
|
1871
|
+
"project_version": self.project_info.version,
|
|
1829
1872
|
"task_id": self.task_id,
|
|
1830
1873
|
"model_files": experiment_info["model_files"],
|
|
1831
1874
|
"checkpoints": experiment_info["checkpoints"],
|
|
@@ -1843,6 +1886,8 @@ class TrainApp:
|
|
|
1843
1886
|
"logs": {"type": "tensorboard", "link": f"{remote_dir}logs/"},
|
|
1844
1887
|
"device": self.gui.training_process.get_device_name(),
|
|
1845
1888
|
"training_duration": self._training_duration,
|
|
1889
|
+
"train_collection_id": self._train_collection_id,
|
|
1890
|
+
"val_collection_id": self._val_collection_id,
|
|
1846
1891
|
}
|
|
1847
1892
|
|
|
1848
1893
|
if self._has_splits_selector:
|
|
@@ -1997,13 +2042,7 @@ class TrainApp:
|
|
|
1997
2042
|
:rtype: tuple
|
|
1998
2043
|
"""
|
|
1999
2044
|
need_generate_report = self._app_options.get("generate_report", False)
|
|
2000
|
-
|
|
2001
|
-
is_dev = "dev.internal" in self._api.server_address
|
|
2002
|
-
if not is_dev:
|
|
2003
|
-
need_generate_report = False
|
|
2004
|
-
# ------------------------------------------------------------ #
|
|
2005
|
-
|
|
2006
|
-
if need_generate_report: # link to experiment page
|
|
2045
|
+
if need_generate_report: # link to experiment page
|
|
2007
2046
|
try:
|
|
2008
2047
|
output_file_info = self._generate_experiment_report(experiment_info, model_meta)
|
|
2009
2048
|
experiment_info["has_report"] = True
|
|
@@ -2011,7 +2050,7 @@ class TrainApp:
|
|
|
2011
2050
|
logger.error(f"Error generating experiment report: {e}")
|
|
2012
2051
|
output_file_info = session_link_file_info
|
|
2013
2052
|
experiment_info["has_report"] = False
|
|
2014
|
-
else:
|
|
2053
|
+
else: # link to artifacts directory
|
|
2015
2054
|
output_file_info = session_link_file_info
|
|
2016
2055
|
experiment_info["has_report"] = False
|
|
2017
2056
|
return output_file_info, experiment_info
|
|
@@ -3089,4 +3128,82 @@ class TrainApp:
|
|
|
3089
3128
|
# 4. Match splits with original project
|
|
3090
3129
|
gt_split_data = self._postprocess_splits(gt_project_info.id)
|
|
3091
3130
|
return gt_project_info.id, gt_split_data
|
|
3092
|
-
|
|
3131
|
+
|
|
3132
|
+
def _create_collection_splits(self):
|
|
3133
|
+
def _check_match(current_selected_collection_ids: List[int], all_split_collections: List[EntitiesCollectionInfo]):
|
|
3134
|
+
if len(current_selected_collection_ids) > 0:
|
|
3135
|
+
if len(current_selected_collection_ids) == 1:
|
|
3136
|
+
current_selected_collection_id = current_selected_collection_ids[0]
|
|
3137
|
+
for collection in all_split_collections:
|
|
3138
|
+
if collection.id == current_selected_collection_id:
|
|
3139
|
+
return True
|
|
3140
|
+
return False
|
|
3141
|
+
|
|
3142
|
+
# Case 1: Use existing collections for training. No need to create new collections
|
|
3143
|
+
split_method = self.gui.train_val_splits_selector.get_split_method()
|
|
3144
|
+
all_train_collections = self.gui.train_val_splits_selector.all_train_collections
|
|
3145
|
+
all_val_collections = self.gui.train_val_splits_selector.all_val_collections
|
|
3146
|
+
if split_method == "Based on collections":
|
|
3147
|
+
current_selected_train_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_train_collections_ids()
|
|
3148
|
+
train_match = _check_match(current_selected_train_collection_ids, all_train_collections)
|
|
3149
|
+
if train_match:
|
|
3150
|
+
current_selected_val_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_val_collections_ids()
|
|
3151
|
+
val_match = _check_match(current_selected_val_collection_ids, all_val_collections)
|
|
3152
|
+
if val_match:
|
|
3153
|
+
self._train_collection_id = current_selected_train_collection_ids[0]
|
|
3154
|
+
self._val_collection_id = current_selected_val_collection_ids[0]
|
|
3155
|
+
self._update_project_custom_data(self._train_collection_id, self._val_collection_id)
|
|
3156
|
+
return
|
|
3157
|
+
# ------------------------------------------------------------ #
|
|
3158
|
+
|
|
3159
|
+
# Case 2: Create new collections for selected train val splits. Need to create new collections
|
|
3160
|
+
item_type = self.project_info.type
|
|
3161
|
+
experiment_name = self.gui.training_process.get_experiment_name()
|
|
3162
|
+
|
|
3163
|
+
train_collection_idx = 1
|
|
3164
|
+
val_collection_idx = 1
|
|
3165
|
+
|
|
3166
|
+
# Get train collection with max idx
|
|
3167
|
+
if len(all_train_collections) > 0:
|
|
3168
|
+
train_collection_idx = max([int(collection.name.split("_")[1]) for collection in all_train_collections])
|
|
3169
|
+
train_collection_idx += 1
|
|
3170
|
+
# Get val collection with max idx
|
|
3171
|
+
if len(all_val_collections) > 0:
|
|
3172
|
+
val_collection_idx = max([int(collection.name.split("_")[1]) for collection in all_val_collections])
|
|
3173
|
+
val_collection_idx += 1
|
|
3174
|
+
# -------------------------------- #
|
|
3175
|
+
|
|
3176
|
+
# Create Train Collection
|
|
3177
|
+
train_img_ids = list(self._train_split_item_ids)
|
|
3178
|
+
train_collection_description = f"Collection with train {item_type} for experiment: {experiment_name}"
|
|
3179
|
+
train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx}", train_collection_description)
|
|
3180
|
+
train_collection_id = getattr(train_collection, "id", None)
|
|
3181
|
+
if train_collection_id is None:
|
|
3182
|
+
raise AttributeError("Train EntitiesCollectionInfo object does not have 'id' attribute")
|
|
3183
|
+
self._api.entities_collection.add_items(train_collection_id, train_img_ids)
|
|
3184
|
+
self._train_collection_id = train_collection_id
|
|
3185
|
+
|
|
3186
|
+
# Create Val Collection
|
|
3187
|
+
val_img_ids = list(self._val_split_item_ids)
|
|
3188
|
+
val_collection_description = f"Collection with val {item_type} for experiment: {experiment_name}"
|
|
3189
|
+
val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx}", val_collection_description)
|
|
3190
|
+
val_collection_id = getattr(val_collection, "id", None)
|
|
3191
|
+
if val_collection_id is None:
|
|
3192
|
+
raise AttributeError("Val EntitiesCollectionInfo object does not have 'id' attribute")
|
|
3193
|
+
self._api.entities_collection.add_items(val_collection_id, val_img_ids)
|
|
3194
|
+
self._val_collection_id = val_collection_id
|
|
3195
|
+
|
|
3196
|
+
# Update Project Custom Data
|
|
3197
|
+
self._update_project_custom_data(train_collection_id, val_collection_id)
|
|
3198
|
+
|
|
3199
|
+
def _update_project_custom_data(self, train_collection_id: int, val_collection_id: int):
|
|
3200
|
+
train_info = {
|
|
3201
|
+
"task_id": self.task_id,
|
|
3202
|
+
"framework_name": self.framework_name,
|
|
3203
|
+
"splits": {"train_collection": train_collection_id, "val_collection": val_collection_id}
|
|
3204
|
+
}
|
|
3205
|
+
custom_data = self._api.project.get_info_by_id(self.project_id).custom_data
|
|
3206
|
+
train_info_list = custom_data.get("train_info", [])
|
|
3207
|
+
train_info_list.append(train_info)
|
|
3208
|
+
custom_data.update({"train_info": train_info_list})
|
|
3209
|
+
self._api.project.update_custom_data(self.project_id, custom_data)
|