supervisely-6.73.410-py3-none-any.whl → supervisely-6.73.470-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of supervisely has been flagged as potentially problematic.
- supervisely/__init__.py +136 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/json_geometries_map.py +2 -0
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +9 -9
- supervisely/api/api.py +67 -43
- supervisely/api/app_api.py +72 -5
- supervisely/api/dataset_api.py +108 -33
- supervisely/api/entity_annotation/figure_api.py +113 -49
- supervisely/api/image_api.py +82 -0
- supervisely/api/module_api.py +10 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/pointcloud/pointcloud_api.py +38 -0
- supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
- supervisely/api/project_api.py +213 -6
- supervisely/api/task_api.py +11 -1
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +79 -1
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/api/volume/volume_api.py +38 -0
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +14 -6
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +88 -0
- supervisely/app/fastapi/subapp.py +175 -42
- supervisely/app/fastapi/templating.py +1 -1
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +11 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
- supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
- supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
- supervisely/app/widgets/fast_table/fast_table.py +713 -126
- supervisely/app/widgets/fast_table/script.js +492 -95
- supervisely/app/widgets/fast_table/style.css +54 -0
- supervisely/app/widgets/fast_table/template.html +45 -5
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +523 -0
- supervisely/app/widgets/heatmap/script.js +378 -0
- supervisely/app/widgets/heatmap/style.css +227 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/input_tag/input_tag.py +102 -15
- supervisely/app/widgets/input_tag_list/__init__.py +0 -0
- supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
- supervisely/app/widgets/input_tag_list/template.html +70 -0
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/app/widgets/transfer/style.css +3 -0
- supervisely/app/widgets/transfer/template.html +3 -1
- supervisely/app/widgets/transfer/transfer.py +48 -45
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
- supervisely/convert/video/video_converter.py +2 -2
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/io/env.py +161 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/cache.py +37 -17
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +953 -211
- supervisely/nn/inference/inference_request.py +15 -8
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
- supervisely/nn/inference/predict_app/gui/gui.py +915 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +399 -0
- supervisely/nn/inference/predict_app/predict_app.py +176 -0
- supervisely/nn/inference/session.py +47 -39
- supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +4 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/model_api.py +44 -22
- supervisely/nn/model/prediction.py +15 -1
- supervisely/nn/model/prediction_session.py +70 -14
- supervisely/nn/prediction_dto.py +7 -0
- supervisely/nn/tracker/__init__.py +6 -8
- supervisely/nn/tracker/base_tracker.py +54 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
- supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +273 -0
- supervisely/nn/tracker/calculate_metrics.py +264 -0
- supervisely/nn/tracker/utils.py +273 -0
- supervisely/nn/tracker/visualize.py +520 -0
- supervisely/nn/training/gui/gui.py +152 -49
- supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
- supervisely/nn/training/gui/training_artifacts.py +3 -1
- supervisely/nn/training/train_app.py +225 -46
- supervisely/project/pointcloud_episode_project.py +12 -8
- supervisely/project/pointcloud_project.py +12 -8
- supervisely/project/project.py +221 -75
- supervisely/template/experiment/experiment.html.jinja +105 -55
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- supervisely/versions.json +3 -1
- supervisely/video/sampling.py +42 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
- supervisely_lib/__init__.py +6 -1
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
supervisely/nn/training/train_app.py

```diff
@@ -56,7 +56,7 @@ from supervisely.nn.benchmark import (
     SemanticSegmentationEvaluator,
 )
 from supervisely.nn.inference import RuntimeType, SessionJSON
-from supervisely.nn.inference.inference import Inference
+from supervisely.nn.inference.inference import Inference, torch_load_safe
 from supervisely.nn.task_type import TaskType
 from supervisely.nn.training.gui.gui import TrainGUI
 from supervisely.nn.training.gui.utils import generate_task_check_function_js
@@ -72,6 +72,7 @@ from supervisely.project.download import (
     is_cached,
 )
 from supervisely.template.experiment.experiment_generator import ExperimentGenerator
+from supervisely.api.entities_collection_api import EntitiesCollectionInfo
 
 
 class TrainApp:
@@ -159,8 +160,14 @@ class TrainApp:
         self.sly_project = None
         # -------------------------- #
 
-
-        self.
+        # Train Val Splits
+        self._train_split = []
+        self._train_split_item_ids = set()
+        self._train_collection_id = None
+
+        self._val_split = []
+        self._val_split_item_ids = set()
+        self._val_collection_id = None
         # -------------------------- #
 
         # Input
@@ -232,19 +239,33 @@ class TrainApp:
                 self.gui.training_process.start_button.loading = False
                 raise e
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+        @self._server.post("/train_status")
+        def _train_status(response: Response, request: Request):
+            """Returns the current training status."""
+            status = self.gui.training_process.validator_text.get_value()
+            if status == "Training is in progress...":
+                try:
+                    total_epochs = getattr(self.progress_bar_main, "total", None)
+                    current_epoch = getattr(self.progress_bar_main, "current", None)
+                    if total_epochs is not None and current_epoch is not None:
+                        status += f" (Epoch {current_epoch}/{total_epochs})"
+                except Exception:
+                    pass
+            return {"status": status}
+
+        # Read GUI State when launched from experiment modal
+        state = self.gui._extract_state_from_env()
+        logger.debug(f"State: {state}")
+        gui_state_raw = state.get("guiState")
+        if gui_state_raw is not None:
+            logger.info("Loading GUI from state")
+            logger.debug(f"GUI State: {gui_state_raw}")
+            try:
+                self.gui.load_from_app_state(gui_state_raw)
+                logger.info("Successfully loaded GUI from state")
+            except Exception as e:
+                raise e
+        # ----------------------------------------- #
 
     def _register_routes(self):
         """
```
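The new `/train_status` route exposes training progress over HTTP. Below is a minimal polling sketch; the base URL and any response fields beyond `status` are assumptions for illustration, not part of the diff. The route is registered with POST, so the probe uses `requests.post`:

```python
import time

import requests  # third-party HTTP client

# Hypothetical base URL of a running training app; adjust to your deployment.
BASE_URL = "http://localhost:8000"

def wait_for_training(poll_interval_sec: float = 10.0, timeout_sec: float = 3600.0) -> str:
    """Poll the /train_status route until the status no longer reports progress."""
    deadline = time.monotonic() + timeout_sec
    while time.monotonic() < deadline:
        resp = requests.post(f"{BASE_URL}/train_status", timeout=10)
        resp.raise_for_status()
        status = resp.json()["status"]  # e.g. "Training is in progress... (Epoch 3/100)"
        print(status)
        if not status.startswith("Training is in progress"):
            return status
        time.sleep(poll_interval_sec)
    raise TimeoutError("Training did not finish within the timeout")
```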
```diff
@@ -284,6 +305,12 @@ class TrainApp:
 
     # Properties
     # General
+    @property
+    def auto_start(self) -> bool:
+        """
+        If True, the training will start automatically after the GUI is loaded and train server is started.
+        """
+        return self.gui._start_training
     # ----------------------------------------- #
 
     # Input Data
@@ -399,14 +426,14 @@ class TrainApp:
         :rtype: str
         """
         return self.gui.training_process.get_device()
-
+
     @property
     def base_checkpoint(self) -> str:
         """
         Returns the name of the base checkpoint.
         """
         return self.gui.model_selector.get_checkpoint_name()
-
+
     @property
     def base_checkpoint_link(self) -> str:
         """
@@ -596,13 +623,18 @@ class TrainApp:
         if self.gui.classes_selector is not None:
             if self.gui.classes_selector.is_convert_class_shapes_enabled():
                 self._convert_project_to_model_task()
+
         # Step 4. Split Project
         self._split_project()
         # Step 5. Remove classes except selected
         if self.sly_project.type == ProjectType.IMAGES.value:
             self.sly_project.remove_classes_except(self.project_dir, self.classes, True)
             self._read_project()
-
+
+        # Step 6. Create collections
+        self._create_collection_splits()
+
+        # Step 7. Download Model files
         self._download_model()
 
     def _finalize(self, experiment_info: dict) -> None:
@@ -702,11 +734,7 @@ class TrainApp:
         self._upload_demo_files(remote_dir)
 
         # Step 10. Generate training output
-
-        if need_generate_report:
-            output_file_info = self._generate_experiment_report(experiment_info, model_meta)
-        else:  # output artifacts directory
-            output_file_info = session_link_file_info
+        output_file_info, experiment_info = self._generate_experiment_output(experiment_info, model_meta, session_link_file_info)
 
         # Step 11. Set output widgets
         self._set_text_status("reset")
@@ -851,7 +879,7 @@ class TrainApp:
                     "TensorRT": True
                 },
             },
-            "experiment_name": "
+            "experiment_name": "My Experiment",
         }
         """
         self.gui.load_from_app_state(app_state)
@@ -1074,6 +1102,7 @@ class TrainApp:
             project_id=self.project_id,
             dest_dir=self.project_dir,
             dataset_ids=[ds_info.id for ds_info in dataset_infos],
+            save_image_info=True,
             log_progress=True,
             progress_cb=pbar.update,
         )
@@ -1166,13 +1195,15 @@ class TrainApp:
             self._train_val_split_file = None
             self._train_split = []
             self._val_split = []
+            self._train_split_item_ids = set()
+            self._val_split_item_ids = set()
             return
 
         # Load splits
         self.gui.train_val_splits_selector.set_sly_project(self.sly_project)
-        self._train_split, self._val_split = (
-            self.gui.train_val_splits_selector.train_val_splits.get_splits()
-        )
+        self._train_split, self._val_split = self.gui.train_val_splits_selector.train_val_splits.get_splits()
+        self._train_split_ids, self._val_split_ids = [], []
+        self._train_split_item_ids, self._val_split_item_ids = set(), set()
 
         # Prepare paths
         project_split_path = join(self.work_dir, "splits")
@@ -1181,11 +1212,13 @@ class TrainApp:
                 "split_path": join(project_split_path, "train"),
                 "img_dir": join(project_split_path, "train", "img"),
                 "ann_dir": join(project_split_path, "train", "ann"),
+                "img_info_dir": join(project_split_path, "train", "img_info"),
             },
             "val": {
                 "split_path": join(project_split_path, "val"),
                 "img_dir": join(project_split_path, "val", "img"),
                 "ann_dir": join(project_split_path, "val", "ann"),
+                "img_info_dir": join(project_split_path, "val", "img_info"),
             },
         }
 
@@ -1203,7 +1236,7 @@ class TrainApp:
         }
 
         # Utility function to move files
-        def move_files(split, paths, img_name_format, pbar):
+        def move_files(split, split_name, paths, img_name_format, pbar):
             """
             Move files to the appropriate directories.
             """
@@ -1212,6 +1245,19 @@ class TrainApp:
                 ann_name = f"{item_name}.json"
                 shutil.copy(item.img_path, join(paths["img_dir"], item_name))
                 shutil.copy(item.ann_path, join(paths["ann_dir"], ann_name))
+
+                # Move img_info
+                img_info_name = f"{sly_fs.get_file_name_with_ext(item.img_path)}.json"
+                img_info_path = join(dirname(dirname(item.img_path)), "img_info", img_info_name)
+                # shutil.copy(img_info_path, join(paths["img_info_dir"], ann_name))
+
+                # Write split ids to img_info
+                img_info = sly_json.load_json_file(img_info_path)
+                if split_name == "train":
+                    self._train_split_item_ids.add(img_info["id"])
+                else:
+                    self._val_split_item_ids.add(img_info["id"])
+
                 pbar.update(1)
 
         # Main split processing
```
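The `img_info` lookup above relies on the standard layout of a downloaded Supervisely dataset, where `img/`, `ann/`, and (when downloaded with `save_image_info=True`) `img_info/` sit side by side under the dataset root. A small sketch of the path arithmetic; the sample paths are illustrative only:

```python
from os.path import dirname, join

# Illustrative item path inside a downloaded Supervisely dataset.
img_path = "/data/project/ds0/img/0001.jpg"

# dirname(dirname(img_path)) climbs from .../img/0001.jpg up to the dataset root.
dataset_root = dirname(dirname(img_path))           # "/data/project/ds0"
img_info_name = "0001.jpg.json"                     # file name with extension + ".json"
img_info_path = join(dataset_root, "img_info", img_info_name)
assert img_info_path == "/data/project/ds0/img_info/0001.jpg.json"
```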
```diff
@@ -1220,12 +1266,16 @@ class TrainApp:
         ) as main_pbar:
             self.progress_bar_main.show()
             for dataset in ["train", "val"]:
-
+                split_name = dataset
+                if split_name == "train":
+                    split = self._train_split
+                else:
+                    split = self._val_split
                 with self.progress_bar_secondary(
                     message=f"Preparing '{dataset}'", total=len(split)
                 ) as second_pbar:
                     self.progress_bar_secondary.show()
-                    move_files(split, paths[dataset], image_name_formats[dataset], second_pbar)
+                    move_files(split, split_name, paths[dataset], image_name_formats[dataset], second_pbar)
                 main_pbar.update(1)
             self.progress_bar_secondary.hide()
         self.progress_bar_main.hide()
@@ -1245,10 +1295,7 @@ class TrainApp:
         with self.progress_bar_main(message="Processing splits", total=2) as pbar:
             self.progress_bar_main.show()
             for dataset in ["train", "val"]:
-                shutil.move(
-                    paths[dataset]["split_path"],
-                    train_ds_path if dataset == "train" else val_ds_path,
-                )
+                shutil.move(paths[dataset]["split_path"], train_ds_path if dataset == "train" else val_ds_path)
                 pbar.update(1)
         self.progress_bar_main.hide()
 
@@ -1551,13 +1598,18 @@ class TrainApp:
         project_id = self.project_id
 
         dataset_infos = [dataset for _, dataset in self._api.dataset.tree(project_id)]
+        id_to_info = {ds.id: ds for ds in dataset_infos}
         ds_infos_dict = {}
         for dataset in dataset_infos:
-
-
-
-
-
+            name_parts = [dataset.name]
+            parent_id = dataset.parent_id
+            while parent_id is not None:
+                parent_ds = id_to_info.get(parent_id)
+                if parent_ds is None:
+                    parent_ds = self._api.dataset.get_info_by_id(parent_id)
+                name_parts.append(parent_ds.name)
+                parent_id = parent_ds.parent_id
+            dataset_name = "/".join(reversed(name_parts))
             ds_infos_dict[dataset_name] = dataset
 
         def get_image_infos_by_split(ds_infos_dict: dict, split: list):
```
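The parent-chain walk above turns nested datasets into slash-separated keys. A self-contained sketch of the same idea, with plain named tuples standing in for `DatasetInfo` objects; the sample data is invented for illustration:

```python
from collections import namedtuple

DatasetStub = namedtuple("DatasetStub", ["id", "name", "parent_id"])

datasets = [
    DatasetStub(1, "root_ds", None),
    DatasetStub(2, "child_ds", 1),
    DatasetStub(3, "grandchild_ds", 2),
]
id_to_info = {ds.id: ds for ds in datasets}

def full_name(ds: DatasetStub) -> str:
    """Build "parent/child" style names by walking parent_id links to the root."""
    parts = [ds.name]
    parent_id = ds.parent_id
    while parent_id is not None:
        parent = id_to_info[parent_id]
        parts.append(parent.name)
        parent_id = parent.parent_id
    return "/".join(reversed(parts))

assert full_name(id_to_info[3]) == "root_ds/child_ds/grandchild_ds"
```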
```diff
@@ -1671,7 +1723,8 @@ class TrainApp:
             try:
                 # pylint: disable=import-error
                 import torch
-                state_dict = torch.load(new_checkpoint_path)
+
+                state_dict = torch_load_safe(new_checkpoint_path)
                 state_dict["model_info"] = {
                     "task_id": self.task_id,
                     "model_name": experiment_info["model_name"],
```
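`torch_load_safe` is imported from `supervisely.nn.inference.inference` (see the import hunk above); the diff shows only its call site, not its body. A plausible sketch of such a wrapper, assuming it guards `torch.load` against unpickling arbitrary objects with a fallback for older PyTorch; this is a guess at the behavior, not the SDK's actual implementation:

```python
import pickle

import torch

def torch_load_safe_sketch(path: str):
    """Hypothetical stand-in for torch_load_safe: prefer the restricted loader."""
    try:
        # weights_only=True (PyTorch >= 1.13) refuses to unpickle arbitrary objects.
        return torch.load(path, map_location="cpu", weights_only=True)
    except (TypeError, pickle.UnpicklingError):
        # Older PyTorch without the weights_only kwarg, or a checkpoint that
        # stores custom classes and needs a full unpickling load.
        return torch.load(path, map_location="cpu")
```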
```diff
@@ -1683,9 +1736,7 @@ class TrainApp:
                 state_dict["model_files"] = ckpt_files
                 torch.save(state_dict, new_checkpoint_path)
             except Exception as e:
-                logger.warning(
-                    f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}"
-                )
+                logger.warning(f"Error writing info to checkpoint: '{checkpoint_name}'. Error:{e}")
                 continue
 
             new_checkpoint_paths.append(new_checkpoint_path)
@@ -1821,7 +1872,6 @@ class TrainApp:
         :type export_weights: dict
         """
         logger.debug("Updating experiment info")
-
         experiment_info = {
             "experiment_name": self.gui.training_process.get_experiment_name(),
             "framework_name": self.framework_name,
@@ -1830,6 +1880,7 @@ class TrainApp:
             "base_checkpoint_link": self.base_checkpoint_link,
             "task_type": experiment_info["task_type"],
             "project_id": self.project_info.id,
+            "project_version": self.project_info.version,
             "task_id": self.task_id,
             "model_files": experiment_info["model_files"],
             "checkpoints": experiment_info["checkpoints"],
@@ -1847,6 +1898,8 @@ class TrainApp:
             "logs": {"type": "tensorboard", "link": f"{remote_dir}logs/"},
             "device": self.gui.training_process.get_device_name(),
             "training_duration": self._training_duration,
+            "train_collection_id": self._train_collection_id,
+            "val_collection_id": self._val_collection_id,
         }
 
         if self._has_splits_selector:
@@ -1986,6 +2039,37 @@ class TrainApp:
 
         self.progress_bar_main.hide()
 
+    def _generate_experiment_output(self, experiment_info: dict, model_meta: ProjectMeta, session_link_file_info: FileInfo) -> tuple:
+        """
+        Generates and uploads the experiment page to the output directory, if report generation is successful.
+        Otherwise, artifacts directory link will be used for output.
+
+        :param experiment_info: Information about the experiment results.
+        :type experiment_info: dict
+        :param model_meta: Model meta with object classes.
+        :type model_meta: ProjectMeta
+        :param session_link_file_info: Artifacts directory link, used if report is not generated.
+        :type session_link_file_info: FileInfo
+        :return: Output file info and experiment info.
+        :rtype: tuple
+        """
+        need_generate_report = self._app_options.get("generate_report", False)
+        if need_generate_report:  # link to experiment page
+            try:
+                output_file_info = self._generate_experiment_report(experiment_info, model_meta)
+                experiment_info["has_report"] = True
+                experiment_info["experiment_report_id"] = output_file_info.id
+            except Exception as e:
+                logger.error(f"Error generating experiment report: {e}")
+                output_file_info = session_link_file_info
+                experiment_info["has_report"] = False
+                experiment_info["experiment_report_id"] = None
+        else:  # link to artifacts directory
+            output_file_info = session_link_file_info
+            experiment_info["has_report"] = False
+            experiment_info["experiment_report_id"] = None
+        return output_file_info, experiment_info
+
     def _get_train_val_splits_for_app_state(self) -> Dict:
         """
         Gets the train and val splits information for app_state.json.
@@ -2719,6 +2803,12 @@ class TrainApp:
         train_logger.add_on_step_finished_callback(step_callback)
 
         # ----------------------------------------- #
+    def start_in_thread(self):
+        def auto_train():
+            import threading
+            threading.Thread(target=self._wrapped_start_training, daemon=True).start()
+        self._server.add_event_handler("startup", auto_train)
+
     def _wrapped_start_training(self):
         """
         Wrapper function to wrap the training process.
```
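Together with the `auto_start` property added earlier, `start_in_thread` registers a FastAPI `startup` event handler that launches `_wrapped_start_training` in a daemon thread, so the HTTP routes (including `/train_status`) stay responsive while training runs. The same mechanism in isolation, as a runnable FastAPI sketch; `long_running_job` is a placeholder for the real training entrypoint:

```python
import threading
import time

from fastapi import FastAPI

app = FastAPI()

def long_running_job() -> None:
    # Placeholder for TrainApp._wrapped_start_training.
    time.sleep(60)

def auto_start_job() -> None:
    # Daemon thread: does not block server startup or shutdown.
    threading.Thread(target=long_running_job, daemon=True).start()

# Equivalent to self._server.add_event_handler("startup", auto_train) in the diff.
app.add_event_handler("startup", auto_start_job)
```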
```diff
@@ -3059,4 +3149,93 @@ class TrainApp:
         # 4. Match splits with original project
         gt_split_data = self._postprocess_splits(gt_project_info.id)
         return gt_project_info.id, gt_split_data
-
+
+    def _create_collection_splits(self):
+        def _check_match(current_selected_collection_ids: List[int], all_split_collections: List[EntitiesCollectionInfo]):
+            if len(current_selected_collection_ids) > 0:
+                if len(current_selected_collection_ids) == 1:
+                    current_selected_collection_id = current_selected_collection_ids[0]
+                    for collection in all_split_collections:
+                        if collection.id == current_selected_collection_id:
+                            return True
+            return False
+
+        # Case 1: Use existing collections for training. No need to create new collections
+        split_method = self.gui.train_val_splits_selector.get_split_method()
+        all_train_collections = self.gui.train_val_splits_selector.all_train_collections
+        all_val_collections = self.gui.train_val_splits_selector.all_val_collections
+        if split_method == "Based on collections":
+            current_selected_train_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_train_collections_ids()
+            train_match = _check_match(current_selected_train_collection_ids, all_train_collections)
+            if train_match:
+                current_selected_val_collection_ids = self.gui.train_val_splits_selector.train_val_splits.get_val_collections_ids()
+                val_match = _check_match(current_selected_val_collection_ids, all_val_collections)
+                if val_match:
+                    self._train_collection_id = current_selected_train_collection_ids[0]
+                    self._val_collection_id = current_selected_val_collection_ids[0]
+                    self._update_project_custom_data(self._train_collection_id, self._val_collection_id)
+                    return
+        # ------------------------------------------------------------ #
+
+        # Case 2: Create new collections for selected train val splits. Need to create new collections
+        item_type = self.project_info.type
+        experiment_name = self.gui.training_process.get_experiment_name()
+
+        train_collection_idx = 1
+        val_collection_idx = 1
+
+        def _extract_index_from_col_name(name: str, expected_prefix: str) -> Optional[int]:
+            parts = name.split("_")
+            if len(parts) == 2 and parts[0] == expected_prefix and parts[1].isdigit():
+                return int(parts[1])
+            return None
+
+        # Get train collection with max idx
+        if len(all_train_collections) > 0:
+            train_indices = [_extract_index_from_col_name(collection.name, "train") for collection in all_train_collections]
+            train_indices = [idx for idx in train_indices if idx is not None]
+            if len(train_indices) > 0:
+                train_collection_idx = max(train_indices) + 1
+
+        # Get val collection with max idx
+        if len(all_val_collections) > 0:
+            val_indices = [_extract_index_from_col_name(collection.name, "val") for collection in all_val_collections]
+            val_indices = [idx for idx in val_indices if idx is not None]
+            if len(val_indices) > 0:
+                val_collection_idx = max(val_indices) + 1
+        # -------------------------------- #
+
+        # Create Train Collection
+        train_img_ids = list(self._train_split_item_ids)
+        train_collection_description = f"Collection with train {item_type} for experiment: {experiment_name}"
+        train_collection = self._api.entities_collection.create(self.project_id, f"train_{train_collection_idx:03d}", train_collection_description)
+        train_collection_id = getattr(train_collection, "id", None)
+        if train_collection_id is None:
+            raise AttributeError("Train EntitiesCollectionInfo object does not have 'id' attribute")
+        self._api.entities_collection.add_items(train_collection_id, train_img_ids)
+        self._train_collection_id = train_collection_id
+
+        # Create Val Collection
+        val_img_ids = list(self._val_split_item_ids)
+        val_collection_description = f"Collection with val {item_type} for experiment: {experiment_name}"
+        val_collection = self._api.entities_collection.create(self.project_id, f"val_{val_collection_idx:03d}", val_collection_description)
+        val_collection_id = getattr(val_collection, "id", None)
+        if val_collection_id is None:
+            raise AttributeError("Val EntitiesCollectionInfo object does not have 'id' attribute")
+        self._api.entities_collection.add_items(val_collection_id, val_img_ids)
+        self._val_collection_id = val_collection_id
+
+        # Update Project Custom Data
+        self._update_project_custom_data(train_collection_id, val_collection_id)
+
+    def _update_project_custom_data(self, train_collection_id: int, val_collection_id: int):
+        train_info = {
+            "task_id": self.task_id,
+            "framework_name": self.framework_name,
+            "splits": {"train_collection": train_collection_id, "val_collection": val_collection_id}
+        }
+        custom_data = self._api.project.get_info_by_id(self.project_id).custom_data
+        train_info_list = custom_data.get("train_info", [])
+        train_info_list.append(train_info)
+        custom_data.update({"train_info": train_info_list})
+        self._api.project.update_custom_data(self.project_id, custom_data)
```
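After a run, the train/val membership is recorded twice: as entities collections (`train_NNN` / `val_NNN`) and under the `train_info` key of the project's custom data. A short sketch of reading it back with the Supervisely SDK; the project id and server credentials are placeholders, and the `train_info` schema follows `_update_project_custom_data` above:

```python
import supervisely as sly

api = sly.Api.from_env()  # expects SERVER_ADDRESS and API_TOKEN env vars
project_id = 12345  # placeholder

custom_data = api.project.get_info_by_id(project_id).custom_data
for record in custom_data.get("train_info", []):
    splits = record["splits"]
    print(
        f"task={record['task_id']} framework={record['framework_name']} "
        f"train_collection={splits['train_collection']} val_collection={splits['val_collection']}"
    )
```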
supervisely/project/pointcloud_episode_project.py

```diff
@@ -740,15 +740,19 @@ def download_pointcloud_episode_project(
     if progress_cb is not None:
         log_progress = False
 
-
+    filter_fn = lambda x: True
     if dataset_ids is not None:
-
-
-
-
-
-
-
+        filter_fn = lambda ds: ds.id in dataset_ids
+
+    for parents, dataset in api.dataset.tree(project_id):
+        if not filter_fn(dataset):
+            continue
+        dataset_path = None
+        if parents:
+            dataset_path = "/datasets/".join(parents + [dataset.name])
+        dataset_fs: PointcloudEpisodeDataset = project_fs.create_dataset(
+            dataset.name, ds_path=dataset_path
+        )
         pointclouds = api.pointcloud_episode.get_list(dataset.id)
 
         # Download annotation to project_path/dataset_path/annotation.json
```
supervisely/project/pointcloud_project.py

```diff
@@ -994,15 +994,19 @@ def download_pointcloud_project(
     if progress_cb is not None:
         log_progress = False
 
-
+    filter_fn = lambda ds: True
     if dataset_ids is not None:
-
-
-
-
-
-
-
+        filter_fn = lambda ds: ds.id in dataset_ids
+
+    for parents, dataset in api.dataset.tree(project_id):
+        if not filter_fn(dataset):
+            continue
+        dataset_path = None
+        if parents:
+            dataset_path = "/datasets/".join(parents + [dataset.name])
+        dataset_fs: PointcloudDataset = project_fs.create_dataset(
+            ds_name=dataset.name, ds_path=dataset_path
+        )
         pointclouds = api.pointcloud.get_list(dataset.id)
 
         ds_progress = progress_cb
```