supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +136 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/json_geometries_map.py +2 -0
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +9 -9
- supervisely/api/api.py +67 -43
- supervisely/api/app_api.py +72 -5
- supervisely/api/dataset_api.py +108 -33
- supervisely/api/entity_annotation/figure_api.py +113 -49
- supervisely/api/image_api.py +82 -0
- supervisely/api/module_api.py +10 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/pointcloud/pointcloud_api.py +38 -0
- supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
- supervisely/api/project_api.py +213 -6
- supervisely/api/task_api.py +11 -1
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +79 -1
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/api/volume/volume_api.py +38 -0
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +14 -6
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +88 -0
- supervisely/app/fastapi/subapp.py +175 -42
- supervisely/app/fastapi/templating.py +1 -1
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +11 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
- supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
- supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
- supervisely/app/widgets/fast_table/fast_table.py +713 -126
- supervisely/app/widgets/fast_table/script.js +492 -95
- supervisely/app/widgets/fast_table/style.css +54 -0
- supervisely/app/widgets/fast_table/template.html +45 -5
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +523 -0
- supervisely/app/widgets/heatmap/script.js +378 -0
- supervisely/app/widgets/heatmap/style.css +227 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/input_tag/input_tag.py +102 -15
- supervisely/app/widgets/input_tag_list/__init__.py +0 -0
- supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
- supervisely/app/widgets/input_tag_list/template.html +70 -0
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/app/widgets/transfer/style.css +3 -0
- supervisely/app/widgets/transfer/template.html +3 -1
- supervisely/app/widgets/transfer/transfer.py +48 -45
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
- supervisely/convert/video/video_converter.py +2 -2
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/io/env.py +161 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/cache.py +37 -17
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +953 -211
- supervisely/nn/inference/inference_request.py +15 -8
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
- supervisely/nn/inference/predict_app/gui/gui.py +915 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +399 -0
- supervisely/nn/inference/predict_app/predict_app.py +176 -0
- supervisely/nn/inference/session.py +47 -39
- supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +4 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/model_api.py +44 -22
- supervisely/nn/model/prediction.py +15 -1
- supervisely/nn/model/prediction_session.py +70 -14
- supervisely/nn/prediction_dto.py +7 -0
- supervisely/nn/tracker/__init__.py +6 -8
- supervisely/nn/tracker/base_tracker.py +54 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
- supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +273 -0
- supervisely/nn/tracker/calculate_metrics.py +264 -0
- supervisely/nn/tracker/utils.py +273 -0
- supervisely/nn/tracker/visualize.py +520 -0
- supervisely/nn/training/gui/gui.py +152 -49
- supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
- supervisely/nn/training/gui/training_artifacts.py +3 -1
- supervisely/nn/training/train_app.py +225 -46
- supervisely/project/pointcloud_episode_project.py +12 -8
- supervisely/project/pointcloud_project.py +12 -8
- supervisely/project/project.py +221 -75
- supervisely/template/experiment/experiment.html.jinja +105 -55
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- supervisely/versions.json +3 -1
- supervisely/video/sampling.py +42 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
- supervisely_lib/__init__.py +6 -1
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
|
@@ -6,20 +6,21 @@ training workflows in Supervisely.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import os
|
|
9
|
+
import json
|
|
9
10
|
from os import environ, getenv
|
|
10
11
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
11
12
|
|
|
12
|
-
from supervisely import logger
|
|
13
|
-
import supervisely.io.fs as sly_fs
|
|
14
13
|
import supervisely.io.env as sly_env
|
|
14
|
+
import supervisely.io.fs as sly_fs
|
|
15
15
|
import supervisely.io.json as sly_json
|
|
16
|
-
from supervisely import Api, ProjectMeta
|
|
16
|
+
from supervisely import Api, ProjectMeta, logger
|
|
17
17
|
from supervisely._utils import is_production
|
|
18
18
|
from supervisely.app.widgets import Button, Card, Stepper, Widget
|
|
19
19
|
from supervisely.geometry.bitmap import Bitmap
|
|
20
20
|
from supervisely.geometry.graph import GraphNodes
|
|
21
21
|
from supervisely.geometry.polygon import Polygon
|
|
22
22
|
from supervisely.geometry.rectangle import Rectangle
|
|
23
|
+
from supervisely.nn.experiments import ExperimentInfo
|
|
23
24
|
from supervisely.nn.task_type import TaskType
|
|
24
25
|
from supervisely.nn.training.gui.classes_selector import ClassesSelector
|
|
25
26
|
from supervisely.nn.training.gui.hyperparameters_selector import HyperparametersSelector
|
|
@@ -32,7 +33,6 @@ from supervisely.nn.training.gui.training_logs import TrainingLogs
|
|
|
32
33
|
from supervisely.nn.training.gui.training_process import TrainingProcess
|
|
33
34
|
from supervisely.nn.training.gui.utils import set_stepper_step, wrap_button_click
|
|
34
35
|
from supervisely.nn.utils import ModelSource, RuntimeType
|
|
35
|
-
from supervisely.nn.experiments import ExperimentInfo
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
class StepFlow:
|
|
@@ -254,6 +254,7 @@ class TrainGUI:
|
|
|
254
254
|
self.app_options = app_options
|
|
255
255
|
self.collapsable = self.app_options.get("collapsable", False)
|
|
256
256
|
self.need_convert_shapes = False
|
|
257
|
+
self._start_training = False
|
|
257
258
|
|
|
258
259
|
self.team_id = sly_env.team_id(raise_not_found=False)
|
|
259
260
|
self.workspace_id = sly_env.workspace_id(raise_not_found=False)
|
|
@@ -303,7 +304,9 @@ class TrainGUI:
|
|
|
303
304
|
# 3. Classes selector
|
|
304
305
|
self.classes_selector = None
|
|
305
306
|
if self.show_classes_selector:
|
|
306
|
-
self.classes_selector = ClassesSelector(
|
|
307
|
+
self.classes_selector = ClassesSelector(
|
|
308
|
+
self.project_id, [], self.model_selector, self.app_options
|
|
309
|
+
)
|
|
307
310
|
self.steps.append(self.classes_selector.card)
|
|
308
311
|
|
|
309
312
|
# 4. Tags selector
|
|
@@ -355,16 +358,19 @@ class TrainGUI:
|
|
|
355
358
|
experiment_name = "Enter experiment name"
|
|
356
359
|
else:
|
|
357
360
|
if self.task_id == -1:
|
|
358
|
-
experiment_name = f"
|
|
361
|
+
experiment_name = f"debug {self.project_info.name} {model_name}"
|
|
359
362
|
else:
|
|
360
|
-
experiment_name = f"{self.task_id}
|
|
363
|
+
experiment_name = f"{self.task_id} {self.project_info.name} {model_name}"
|
|
361
364
|
|
|
362
365
|
if experiment_name == self.training_process.get_experiment_name():
|
|
363
366
|
return
|
|
364
367
|
self.training_process.set_experiment_name(experiment_name)
|
|
365
368
|
|
|
366
369
|
def need_convert_class_shapes() -> bool:
|
|
367
|
-
if
|
|
370
|
+
if (
|
|
371
|
+
self.hyperparameters_selector.run_model_benchmark_checkbox is None
|
|
372
|
+
or not self.hyperparameters_selector.run_model_benchmark_checkbox.is_checked()
|
|
373
|
+
):
|
|
368
374
|
self.hyperparameters_selector.model_benchmark_auto_convert_warning.hide()
|
|
369
375
|
self.need_convert_shapes = False
|
|
370
376
|
return False
|
|
@@ -376,14 +382,22 @@ class TrainGUI:
|
|
|
376
382
|
|
|
377
383
|
# Exclude classes with no annotations to avoid unnecessary conversion
|
|
378
384
|
data = self.classes_selector.classes_table._table_data
|
|
379
|
-
empty_classes = {
|
|
385
|
+
empty_classes = {
|
|
386
|
+
r[0]["data"] for r in data if r[2]["data"] == 0 and r[3]["data"] == 0
|
|
387
|
+
}
|
|
380
388
|
need_conversion = bool(wrong_shapes - empty_classes)
|
|
381
389
|
else:
|
|
382
390
|
# Classes selector disabled – check entire project meta
|
|
383
391
|
if task_type == TaskType.OBJECT_DETECTION:
|
|
384
|
-
need_conversion = any(
|
|
392
|
+
need_conversion = any(
|
|
393
|
+
obj_cls.geometry_type != Rectangle
|
|
394
|
+
for obj_cls in self.project_meta.obj_classes
|
|
395
|
+
)
|
|
385
396
|
elif task_type in [TaskType.INSTANCE_SEGMENTATION, TaskType.SEMANTIC_SEGMENTATION]:
|
|
386
|
-
need_conversion = any(
|
|
397
|
+
need_conversion = any(
|
|
398
|
+
obj_cls.geometry_type == Polygon
|
|
399
|
+
for obj_cls in self.project_meta.obj_classes
|
|
400
|
+
)
|
|
387
401
|
else:
|
|
388
402
|
need_conversion = False
|
|
389
403
|
|
|
@@ -394,6 +408,7 @@ class TrainGUI:
|
|
|
394
408
|
|
|
395
409
|
self.need_convert_shapes = need_conversion
|
|
396
410
|
return need_conversion
|
|
411
|
+
|
|
397
412
|
# ------------------------------------------------- #
|
|
398
413
|
|
|
399
414
|
self.step_flow = StepFlow(self.stepper, self.app_options)
|
|
@@ -420,7 +435,7 @@ class TrainGUI:
|
|
|
420
435
|
self.model_selector.widgets_to_disable,
|
|
421
436
|
self.model_selector.validator_text,
|
|
422
437
|
self.model_selector.validate_step,
|
|
423
|
-
position=position
|
|
438
|
+
position=position,
|
|
424
439
|
).add_on_select_actions("model_selector", [set_experiment_name])
|
|
425
440
|
position += 1
|
|
426
441
|
|
|
@@ -517,7 +532,9 @@ class TrainGUI:
|
|
|
517
532
|
has_model_selector = self.show_model_selector and self.model_selector is not None
|
|
518
533
|
has_classes_selector = self.show_classes_selector and self.classes_selector is not None
|
|
519
534
|
has_tags_selector = self.show_tags_selector and self.tags_selector is not None
|
|
520
|
-
has_train_val_splits =
|
|
535
|
+
has_train_val_splits = (
|
|
536
|
+
self.show_train_val_splits_selector and self.train_val_splits_selector is not None
|
|
537
|
+
)
|
|
521
538
|
|
|
522
539
|
# Set step dependency chain
|
|
523
540
|
prev_step = "input_selector"
|
|
@@ -571,11 +588,13 @@ class TrainGUI:
|
|
|
571
588
|
@self.hyperparameters_selector.run_model_benchmark_checkbox.value_changed
|
|
572
589
|
def show_mb_speedtest(is_checked: bool):
|
|
573
590
|
self.hyperparameters_selector.toggle_mb_speedtest(is_checked)
|
|
591
|
+
|
|
574
592
|
# ------------------------------------------------- #
|
|
575
593
|
|
|
576
594
|
self.layout: Widget = self.stepper
|
|
577
595
|
|
|
578
596
|
# Run from experiment page
|
|
597
|
+
|
|
579
598
|
train_task_id = getenv("modal.state.trainTaskId", None)
|
|
580
599
|
if train_task_id is not None:
|
|
581
600
|
train_task_id = int(train_task_id)
|
|
@@ -584,7 +603,6 @@ class TrainGUI:
|
|
|
584
603
|
self._run_from_experiment(train_task_id, train_mode)
|
|
585
604
|
# ----------------------------------------- #
|
|
586
605
|
|
|
587
|
-
|
|
588
606
|
def set_next_step(self):
|
|
589
607
|
current_step = self.stepper.get_active_step()
|
|
590
608
|
self.stepper.set_active_step(current_step + 1)
|
|
@@ -605,6 +623,8 @@ class TrainGUI:
|
|
|
605
623
|
"""
|
|
606
624
|
if self.input_selector is not None:
|
|
607
625
|
self.input_selector.button.enable()
|
|
626
|
+
if self.model_selector is not None:
|
|
627
|
+
self.model_selector.button.enable()
|
|
608
628
|
if self.train_val_splits_selector is not None:
|
|
609
629
|
self.train_val_splits_selector.button.enable()
|
|
610
630
|
if self.classes_selector is not None:
|
|
@@ -622,6 +642,8 @@ class TrainGUI:
|
|
|
622
642
|
"""
|
|
623
643
|
if self.input_selector is not None:
|
|
624
644
|
self.input_selector.button.disable()
|
|
645
|
+
if self.model_selector is not None:
|
|
646
|
+
self.model_selector.button.disable()
|
|
625
647
|
if self.train_val_splits_selector is not None:
|
|
626
648
|
self.train_val_splits_selector.button.disable()
|
|
627
649
|
if self.classes_selector is not None:
|
|
@@ -775,7 +797,9 @@ class TrainGUI:
|
|
|
775
797
|
)
|
|
776
798
|
return app_state
|
|
777
799
|
|
|
778
|
-
def load_from_app_state(
|
|
800
|
+
def load_from_app_state(
|
|
801
|
+
self, app_state: Union[str, dict], click_cb: bool = True, validate_steps: bool = True
|
|
802
|
+
) -> None:
|
|
779
803
|
"""
|
|
780
804
|
Load the GUI state from app state dictionary or path to the state file.
|
|
781
805
|
|
|
@@ -820,26 +844,25 @@ class TrainGUI:
|
|
|
820
844
|
"TensorRT": True
|
|
821
845
|
},
|
|
822
846
|
},
|
|
823
|
-
"experiment_name": "
|
|
847
|
+
"experiment_name": "My Experiment",
|
|
848
|
+
"start_training": False,
|
|
824
849
|
}
|
|
825
850
|
"""
|
|
826
851
|
if isinstance(app_state, str):
|
|
827
|
-
|
|
828
|
-
|
|
852
|
+
if os.path.isfile(app_state):
|
|
853
|
+
app_state = sly_json.load_json_file(app_state)
|
|
854
|
+
else:
|
|
855
|
+
app_state = json.loads(app_state)
|
|
856
|
+
|
|
829
857
|
app_state = self.validate_app_state(app_state)
|
|
830
858
|
options = app_state.get("options", {})
|
|
831
|
-
|
|
832
|
-
# Set experiment name
|
|
833
|
-
experiment_name = app_state.get("experiment_name")
|
|
834
|
-
if experiment_name is not None:
|
|
835
|
-
self.training_process.set_experiment_name(experiment_name)
|
|
836
859
|
|
|
837
860
|
# Run init-steps and stop on validation failure
|
|
838
861
|
def _run_step(init_fn, settings) -> bool:
|
|
839
862
|
if not init_fn(settings, options, click_cb, validate_steps):
|
|
840
863
|
return False
|
|
841
864
|
return True
|
|
842
|
-
|
|
865
|
+
|
|
843
866
|
# GUI init steps
|
|
844
867
|
_steps = [
|
|
845
868
|
(self._init_input, app_state.get("input"), "Input project"),
|
|
@@ -856,12 +879,28 @@ class TrainGUI:
|
|
|
856
879
|
logger.warning(f"Step '{step_name}' {idx}/{len(_steps)} failed to validate")
|
|
857
880
|
return
|
|
858
881
|
if validate_steps:
|
|
859
|
-
logger.info(
|
|
882
|
+
logger.info(
|
|
883
|
+
f"Step '{step_name}' {idx}/{len(_steps)} has been validated successfully"
|
|
884
|
+
)
|
|
885
|
+
|
|
886
|
+
# Set experiment name
|
|
887
|
+
experiment_name = app_state.get("experiment_name")
|
|
888
|
+
if experiment_name is not None and experiment_name != "":
|
|
889
|
+
self.training_process.set_experiment_name(experiment_name)
|
|
890
|
+
|
|
860
891
|
if validate_steps:
|
|
861
892
|
logger.info(f"All steps have been validated successfully")
|
|
893
|
+
|
|
894
|
+
self._start_training = app_state.get("start_training", False)
|
|
862
895
|
# ------------------------------------------------------------------ #
|
|
863
896
|
|
|
864
|
-
def _init_input(
|
|
897
|
+
def _init_input(
|
|
898
|
+
self,
|
|
899
|
+
input_settings: Union[dict, None],
|
|
900
|
+
options: dict,
|
|
901
|
+
click_cb: bool = True,
|
|
902
|
+
validate: bool = True,
|
|
903
|
+
) -> bool:
|
|
865
904
|
"""
|
|
866
905
|
Initialize the input selector with the given settings.
|
|
867
906
|
|
|
@@ -885,7 +924,13 @@ class TrainGUI:
|
|
|
885
924
|
return is_valid
|
|
886
925
|
# ----------------------------------------- #
|
|
887
926
|
|
|
888
|
-
def _init_model(
|
|
927
|
+
def _init_model(
|
|
928
|
+
self,
|
|
929
|
+
model_settings: dict,
|
|
930
|
+
options: dict = None,
|
|
931
|
+
click_cb: bool = True,
|
|
932
|
+
validate: bool = True,
|
|
933
|
+
) -> bool:
|
|
889
934
|
"""
|
|
890
935
|
Initialize the model selector with the given settings.
|
|
891
936
|
|
|
@@ -909,14 +954,18 @@ class TrainGUI:
|
|
|
909
954
|
# Custom
|
|
910
955
|
elif model_settings["source"] == ModelSource.CUSTOM:
|
|
911
956
|
self.model_selector.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
|
|
912
|
-
self.model_selector.experiment_selector.
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
957
|
+
self.model_selector.experiment_selector.set_selected_row_by_task_id(
|
|
958
|
+
model_settings["task_id"]
|
|
959
|
+
)
|
|
960
|
+
experiment_info = self.model_selector.experiment_selector.get_selected_experiment_info()
|
|
961
|
+
if model_settings["checkpoint"] not in experiment_info.checkpoints:
|
|
962
|
+
if f"checkpoints/{model_settings['checkpoint']}" not in experiment_info.checkpoints:
|
|
963
|
+
raise ValueError(
|
|
964
|
+
f"Checkpoint '{model_settings['checkpoint']}' not found in selected task"
|
|
965
|
+
)
|
|
966
|
+
self.model_selector.experiment_selector.set_selected_checkpoint_by_name(
|
|
967
|
+
model_settings["checkpoint"]
|
|
968
|
+
)
|
|
920
969
|
|
|
921
970
|
is_valid = True
|
|
922
971
|
if validate:
|
|
@@ -926,8 +975,10 @@ class TrainGUI:
|
|
|
926
975
|
self.set_next_step()
|
|
927
976
|
return is_valid
|
|
928
977
|
# ----------------------------------------- #
|
|
929
|
-
|
|
930
|
-
def _init_classes(
|
|
978
|
+
|
|
979
|
+
def _init_classes(
|
|
980
|
+
self, classes_settings: list, options: dict, click_cb: bool = True, validate: bool = True
|
|
981
|
+
) -> bool:
|
|
931
982
|
"""
|
|
932
983
|
Initialize the classes selector with the given settings.
|
|
933
984
|
|
|
@@ -941,13 +992,20 @@ class TrainGUI:
|
|
|
941
992
|
:type validate: bool
|
|
942
993
|
"""
|
|
943
994
|
if self.classes_selector is None:
|
|
944
|
-
return True
|
|
995
|
+
return True # Selector disabled by app options
|
|
945
996
|
|
|
946
997
|
convert_class_shapes = options.get("convert_class_shapes", True)
|
|
947
998
|
if convert_class_shapes:
|
|
948
999
|
self.classes_selector.convert_class_shapes_checkbox.check()
|
|
949
1000
|
|
|
950
1001
|
# Set Classes
|
|
1002
|
+
if all(isinstance(c, int) for c in classes_settings):
|
|
1003
|
+
project_classes = []
|
|
1004
|
+
for obj_class in self.project_meta.obj_classes:
|
|
1005
|
+
if obj_class.sly_id in classes_settings:
|
|
1006
|
+
project_classes.append(obj_class.name)
|
|
1007
|
+
classes_settings = project_classes
|
|
1008
|
+
|
|
951
1009
|
self.classes_selector.set_classes(classes_settings)
|
|
952
1010
|
is_valid = True
|
|
953
1011
|
if validate:
|
|
@@ -958,7 +1016,9 @@ class TrainGUI:
|
|
|
958
1016
|
return is_valid
|
|
959
1017
|
# ----------------------------------------- #
|
|
960
1018
|
|
|
961
|
-
def _init_tags(
|
|
1019
|
+
def _init_tags(
|
|
1020
|
+
self, tags_settings: list, options: dict, click_cb: bool = True, validate: bool = True
|
|
1021
|
+
) -> bool:
|
|
962
1022
|
"""
|
|
963
1023
|
Initialize the tags selector with the given settings.
|
|
964
1024
|
|
|
@@ -972,7 +1032,7 @@ class TrainGUI:
|
|
|
972
1032
|
:type validate: bool
|
|
973
1033
|
"""
|
|
974
1034
|
if self.tags_selector is None:
|
|
975
|
-
return True
|
|
1035
|
+
return True # Selector disabled by app options
|
|
976
1036
|
|
|
977
1037
|
# Set Tags
|
|
978
1038
|
self.tags_selector.set_tags(tags_settings)
|
|
@@ -985,7 +1045,13 @@ class TrainGUI:
|
|
|
985
1045
|
return is_valid
|
|
986
1046
|
# ----------------------------------------- #
|
|
987
1047
|
|
|
988
|
-
def _init_train_val_splits(
|
|
1048
|
+
def _init_train_val_splits(
|
|
1049
|
+
self,
|
|
1050
|
+
train_val_splits_settings: dict,
|
|
1051
|
+
options: dict,
|
|
1052
|
+
click_cb: bool = True,
|
|
1053
|
+
validate: bool = True,
|
|
1054
|
+
) -> bool:
|
|
989
1055
|
"""
|
|
990
1056
|
Initialize the train/val splits selector with the given settings.
|
|
991
1057
|
|
|
@@ -999,7 +1065,7 @@ class TrainGUI:
|
|
|
999
1065
|
:type validate: bool
|
|
1000
1066
|
"""
|
|
1001
1067
|
if self.train_val_splits_selector is None:
|
|
1002
|
-
return True
|
|
1068
|
+
return True # Selector disabled by app options
|
|
1003
1069
|
|
|
1004
1070
|
if train_val_splits_settings == {}:
|
|
1005
1071
|
available_methods = self.app_options.get("train_val_splits_methods", [])
|
|
@@ -1059,8 +1125,14 @@ class TrainGUI:
|
|
|
1059
1125
|
self.train_val_splits_selector_cb()
|
|
1060
1126
|
self.set_next_step()
|
|
1061
1127
|
return is_valid
|
|
1062
|
-
|
|
1063
|
-
def _init_hyperparameters(
|
|
1128
|
+
|
|
1129
|
+
def _init_hyperparameters(
|
|
1130
|
+
self,
|
|
1131
|
+
hyperparameters_settings: dict,
|
|
1132
|
+
options: dict,
|
|
1133
|
+
click_cb: bool = True,
|
|
1134
|
+
validate: bool = True,
|
|
1135
|
+
) -> bool:
|
|
1064
1136
|
"""
|
|
1065
1137
|
Initialize the hyperparameters selector with the given settings.
|
|
1066
1138
|
|
|
@@ -1101,6 +1173,7 @@ class TrainGUI:
|
|
|
1101
1173
|
self.hyperparameters_selector_cb()
|
|
1102
1174
|
self.set_next_step()
|
|
1103
1175
|
return is_valid
|
|
1176
|
+
|
|
1104
1177
|
# ----------------------------------------- #
|
|
1105
1178
|
|
|
1106
1179
|
# Run from experiment page
|
|
@@ -1111,10 +1184,12 @@ class TrainGUI:
|
|
|
1111
1184
|
app_state = sly_json.load_json_file(local_app_state_path)
|
|
1112
1185
|
sly_fs.silent_remove(local_app_state_path)
|
|
1113
1186
|
return app_state
|
|
1114
|
-
|
|
1187
|
+
|
|
1115
1188
|
def _download_experiment_hparams(self, experiment_info: ExperimentInfo) -> dict:
|
|
1116
1189
|
local_hparams_path = f"./{experiment_info.hyperparameters}"
|
|
1117
|
-
remote_hparams_path = os.path.join(
|
|
1190
|
+
remote_hparams_path = os.path.join(
|
|
1191
|
+
experiment_info.artifacts_dir, experiment_info.hyperparameters
|
|
1192
|
+
)
|
|
1118
1193
|
self._api.file.download(self.team_id, remote_hparams_path, local_hparams_path)
|
|
1119
1194
|
with open(local_hparams_path, "r") as f:
|
|
1120
1195
|
hparams = f.read()
|
|
@@ -1129,11 +1204,14 @@ class TrainGUI:
|
|
|
1129
1204
|
model_settings = {
|
|
1130
1205
|
"source": ModelSource.CUSTOM,
|
|
1131
1206
|
"task_id": train_task_id,
|
|
1132
|
-
"checkpoint": experiment_info.best_checkpoint
|
|
1207
|
+
"checkpoint": experiment_info.best_checkpoint,
|
|
1133
1208
|
}
|
|
1134
1209
|
|
|
1135
1210
|
if experiment_state is not None:
|
|
1136
|
-
self.input_selector.validator_text.set(
|
|
1211
|
+
self.input_selector.validator_text.set(
|
|
1212
|
+
f"Training configuration is loaded from the experiment: {experiment_info.experiment_name}.",
|
|
1213
|
+
"success",
|
|
1214
|
+
)
|
|
1137
1215
|
self.input_selector.validator_text.show()
|
|
1138
1216
|
experiment_state = self._download_experiment_state(experiment_info)
|
|
1139
1217
|
if train_mode == "continue":
|
|
@@ -1142,7 +1220,7 @@ class TrainGUI:
|
|
|
1142
1220
|
else:
|
|
1143
1221
|
self.input_selector.validator_text.set(
|
|
1144
1222
|
f"Couldn't load full training configuration from the experiment: {experiment_info.experiment_name}. Only model and hyperparameters are loaded.",
|
|
1145
|
-
"warning"
|
|
1223
|
+
"warning",
|
|
1146
1224
|
)
|
|
1147
1225
|
self.input_selector.validator_text.show()
|
|
1148
1226
|
hparams = self._download_experiment_hparams(experiment_info)
|
|
@@ -1150,3 +1228,28 @@ class TrainGUI:
|
|
|
1150
1228
|
if train_mode == "continue":
|
|
1151
1229
|
self._init_model(model_settings, {}, click_cb=False, validate=False)
|
|
1152
1230
|
# ----------------------------------------- #
|
|
1231
|
+
|
|
1232
|
+
def _extract_state_from_env(self):
|
|
1233
|
+
import ast
|
|
1234
|
+
import os
|
|
1235
|
+
|
|
1236
|
+
base = "modal.state"
|
|
1237
|
+
state = {}
|
|
1238
|
+
for key, value in os.environ.items():
|
|
1239
|
+
state_part = state
|
|
1240
|
+
if key.startswith(base):
|
|
1241
|
+
key = key.replace(base + ".", "")
|
|
1242
|
+
parts = key.split(".")
|
|
1243
|
+
while len(parts) > 1:
|
|
1244
|
+
part = parts.pop(0)
|
|
1245
|
+
state_part.setdefault(part, {})
|
|
1246
|
+
state_part = state_part[part]
|
|
1247
|
+
part = parts.pop(0)
|
|
1248
|
+
if value and (value[0] == "[" or value.isdigit()):
|
|
1249
|
+
state_part[part] = ast.literal_eval(value)
|
|
1250
|
+
elif value in ["True", "true", "False", "false"]:
|
|
1251
|
+
state_part[part] = value in ["True", "true"]
|
|
1252
|
+
else:
|
|
1253
|
+
state_part[part] = value
|
|
1254
|
+
return state
|
|
1255
|
+
# ----------------------------------------- #
|
|
@@ -48,7 +48,7 @@ class HyperparametersSelector:
|
|
|
48
48
|
self.run_model_benchmark_checkbox = Checkbox(
|
|
49
49
|
content="Run Model Benchmark evaluation", checked=True
|
|
50
50
|
)
|
|
51
|
-
self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=
|
|
51
|
+
self.run_speedtest_checkbox = Checkbox(content="Run speed test", checked=False)
|
|
52
52
|
|
|
53
53
|
self.model_benchmark_field = Field(
|
|
54
54
|
title="Model Evaluation Benchmark",
|
|
@@ -26,6 +26,7 @@ class ModelSelector:
|
|
|
26
26
|
|
|
27
27
|
def __init__(self, api: Api, framework: str, models: list, app_options: dict = {}):
|
|
28
28
|
# Init widgets
|
|
29
|
+
self.api = api
|
|
29
30
|
self.pretrained_models_table = None
|
|
30
31
|
self.experiment_selector = None
|
|
31
32
|
self.model_source_tabs = None
|
|
@@ -50,7 +51,7 @@ class ModelSelector:
|
|
|
50
51
|
|
|
51
52
|
# GUI Components
|
|
52
53
|
self.pretrained_models_table = PretrainedModelsSelector(self.models)
|
|
53
|
-
experiment_infos = get_experiment_infos(api, self.team_id, framework)
|
|
54
|
+
experiment_infos = get_experiment_infos(self.api, self.team_id, framework)
|
|
54
55
|
if self.app_options.get("legacy_checkpoints", False):
|
|
55
56
|
try:
|
|
56
57
|
framework_cls = FrameworkMapper.get_framework_cls(framework, self.team_id)
|
|
@@ -59,7 +60,7 @@ class ModelSelector:
|
|
|
59
60
|
except:
|
|
60
61
|
logger.warning(f"Legacy checkpoints are not available for '{framework}'")
|
|
61
62
|
|
|
62
|
-
self.experiment_selector = ExperimentSelector(self.team_id, experiment_infos)
|
|
63
|
+
self.experiment_selector = ExperimentSelector(self.api, self.team_id, experiment_infos)
|
|
63
64
|
|
|
64
65
|
tab_titles = []
|
|
65
66
|
tab_descriptions = []
|
|
@@ -85,6 +86,7 @@ class ModelSelector:
|
|
|
85
86
|
self.validator_text = Text("")
|
|
86
87
|
self.validator_text.hide()
|
|
87
88
|
self.button = Button("Select")
|
|
89
|
+
|
|
88
90
|
self.display_widgets.extend([self.model_source_tabs, self.validator_text, self.button])
|
|
89
91
|
# -------------------------------- #
|
|
90
92
|
|
|
@@ -118,14 +120,14 @@ class ModelSelector:
|
|
|
118
120
|
model_name = _get_model_name(selected_row)
|
|
119
121
|
else:
|
|
120
122
|
selected_row = self.experiment_selector.get_selected_experiment_info()
|
|
121
|
-
model_name = selected_row.
|
|
123
|
+
model_name = selected_row.model_name
|
|
122
124
|
return model_name
|
|
123
125
|
|
|
124
126
|
def get_model_info(self) -> dict:
|
|
125
127
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
126
128
|
return self.pretrained_models_table.get_selected_row()
|
|
127
129
|
else:
|
|
128
|
-
return self.experiment_selector.get_selected_experiment_info()
|
|
130
|
+
return self.experiment_selector.get_selected_experiment_info().to_json()
|
|
129
131
|
|
|
130
132
|
def get_checkpoint_name(self) -> str:
|
|
131
133
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
@@ -146,7 +148,7 @@ class ModelSelector:
|
|
|
146
148
|
else:
|
|
147
149
|
checkpoint_name = self.experiment_selector.get_selected_checkpoint_name()
|
|
148
150
|
return checkpoint_name
|
|
149
|
-
|
|
151
|
+
|
|
150
152
|
def get_checkpoint_link(self) -> str:
|
|
151
153
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
152
154
|
selected_row = self.pretrained_models_table.get_selected_row()
|
|
@@ -182,4 +184,4 @@ class ModelSelector:
|
|
|
182
184
|
if self.get_model_source() == ModelSource.PRETRAINED:
|
|
183
185
|
return self.pretrained_models_table.get_selected_task_type()
|
|
184
186
|
else:
|
|
185
|
-
return self.experiment_selector.
|
|
187
|
+
return self.experiment_selector.get_selected_experiment_info().task_type
|