supervisely 6.73.253__py3-none-any.whl → 6.73.255__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of supervisely might be problematic.
- supervisely/api/file_api.py +16 -5
- supervisely/api/task_api.py +4 -2
- supervisely/app/widgets/field/field.py +10 -7
- supervisely/app/widgets/grid_gallery_v2/grid_gallery_v2.py +3 -1
- supervisely/convert/image/sly/sly_image_converter.py +1 -1
- supervisely/nn/benchmark/base_benchmark.py +33 -35
- supervisely/nn/benchmark/base_evaluator.py +27 -1
- supervisely/nn/benchmark/base_visualizer.py +8 -11
- supervisely/nn/benchmark/comparison/base_visualizer.py +147 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/__init__.py +1 -1
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/avg_precision_by_class.py +5 -7
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +4 -6
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/{explore_predicttions.py → explore_predictions.py} +17 -17
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +7 -9
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +11 -22
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +3 -5
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +22 -20
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +12 -6
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +31 -76
- supervisely/nn/benchmark/comparison/model_comparison.py +112 -19
- supervisely/nn/benchmark/comparison/semantic_segmentation/__init__.py +0 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/text_templates.py +128 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/__init__.py +21 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/classwise_error_analysis.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/explore_predictions.py +141 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/frequently_confused.py +71 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/iou_eou.py +68 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/overview.py +223 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/renormalized_error_ou.py +57 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/vis_metrics/speedtest.py +314 -0
- supervisely/nn/benchmark/comparison/semantic_segmentation/visualizer.py +159 -0
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/evaluator.py +1 -1
- supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +3 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +1 -1
- supervisely/nn/benchmark/object_detection/visualizer.py +5 -10
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +12 -2
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +8 -9
- supervisely/nn/benchmark/semantic_segmentation/text_templates.py +2 -2
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +31 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -3
- supervisely/nn/benchmark/semantic_segmentation/visualizer.py +7 -6
- supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +3 -21
- supervisely/nn/benchmark/visualization/renderer.py +25 -10
- supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +1 -0
- supervisely/nn/inference/inference.py +1 -0
- supervisely/nn/training/gui/gui.py +32 -10
- supervisely/nn/training/gui/training_artifacts.py +145 -0
- supervisely/nn/training/gui/training_process.py +3 -19
- supervisely/nn/training/train_app.py +179 -70
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/METADATA +1 -1
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/RECORD +59 -47
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/vis_metric.py +0 -19
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/LICENSE +0 -0
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/WHEEL +0 -0
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.253.dist-info → supervisely-6.73.255.dist-info}/top_level.txt +0 -0
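
The training GUI changes below revolve around a handful of `app_options` keys read by the new `TrainingArtifacts` step and the reworked `TrainApp` (`model_benchmark`, `demo`, `export_onnx_supported`, `export_tensorrt_supported`, `device_selector`, `train_logger`). A minimal sketch of such an options dict, with placeholder URLs rather than real demo links and behavior inferred only from the hunks below:

```python
# Sketch of the app_options keys referenced by the diff below; URL values are placeholders.
app_options = {
    "device_selector": True,            # adds the "Select CUDA device" field to the training step
    "model_benchmark": True,            # enables the benchmark report / failure widgets
    "export_onnx_supported": True,      # allows the ONNX demo instruction to be shown
    "export_tensorrt_supported": True,  # allows the TensorRT demo instruction to be shown
    "train_logger": None,               # when None, TrainApp drives the workspace progress status itself
    "demo": {
        "overview": "https://example.com/demo/README.md",       # placeholder
        "pytorch": "https://example.com/demo/pytorch_demo.py",  # placeholder
        "onnx": "https://example.com/demo/onnx_demo.py",        # placeholder
        "tensorrt": "https://example.com/demo/trt_demo.py",     # placeholder
    },
}
```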
supervisely/nn/training/gui/training_artifacts.py (new file)

@@ -0,0 +1,145 @@
+from typing import Any, Dict
+
+from supervisely import Api
+from supervisely.app.widgets import (
+    Card,
+    Container,
+    Empty,
+    Field,
+    Flexbox,
+    FolderThumbnail,
+    ReportThumbnail,
+    Text,
+)
+
+PYTORCH_ICON = "https://img.icons8.com/?size=100&id=jH4BpkMnRrU5&format=png&color=000000"
+ONNX_ICON = "https://artwork.lfaidata.foundation/projects/onnx/icon/color/onnx-icon-color.png"
+TRT_ICON = "https://img.icons8.com/?size=100&id=yqf95864UzeQ&format=png&color=000000"
+
+
+class TrainingArtifacts:
+    title = "Training Artifacts"
+    description = "All outputs of the training process will appear here"
+    lock_message = "Artifacts will be available after training is completed"
+
+    def __init__(self, app_options: Dict[str, Any]):
+        self.display_widgets = []
+        self.success_message_text = (
+            "Training completed. Training artifacts were uploaded to Team Files. "
+            "You can find and open tensorboard logs in the artifacts folder via the "
+            "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
+        )
+        self.app_options = app_options
+
+        # GUI Components
+        self.validator_text = Text("")
+        self.validator_text.hide()
+        self.display_widgets.extend([self.validator_text])
+
+        # Outputs
+        self.artifacts_thumbnail = FolderThumbnail()
+        self.artifacts_thumbnail.hide()
+
+        self.artifacts_field = Field(
+            title="Artifacts",
+            description="Contains all outputs of the training process",
+            content=self.artifacts_thumbnail,
+        )
+        self.artifacts_field.hide()
+        self.display_widgets.extend([self.artifacts_field])
+
+        # Optional Model Benchmark
+        if app_options.get("model_benchmark", False):
+            self.model_benchmark_report_thumbnail = ReportThumbnail()
+            self.model_benchmark_report_thumbnail.hide()
+
+            self.model_benchmark_fail_text = Text(
+                text="Model evaluation did not finish successfully. Please check the app logs for details.",
+                status="error",
+            )
+            self.model_benchmark_fail_text.hide()
+
+            self.model_benchmark_widgets = Container(
+                [self.model_benchmark_report_thumbnail, self.model_benchmark_fail_text]
+            )
+
+            self.model_benchmark_report_field = Field(
+                title="Model Benchmark",
+                description="Evaluation report of the trained model",
+                content=self.model_benchmark_widgets,
+            )
+            self.model_benchmark_report_field.hide()
+            self.display_widgets.extend([self.model_benchmark_report_field])
+        # -------------------------------- #
+
+        # PyTorch, ONNX, TensorRT demo
+        self.inference_demo_field = []
+        model_demo = self.app_options.get("demo", None)
+        if model_demo is not None:
+            pytorch_demo_link = model_demo.get("pytorch", None)
+            if pytorch_demo_link is not None:
+                pytorch_icon = Field.Icon(image_url=PYTORCH_ICON, bg_color_rgb=[255, 255, 255])
+                self.pytorch_instruction = Field(
+                    title="PyTorch",
+                    description="Open file",
+                    description_url=pytorch_demo_link,
+                    icon=pytorch_icon,
+                    content=Empty(),
+                )
+                self.pytorch_instruction.hide()
+                self.inference_demo_field.extend([self.pytorch_instruction])
+
+            onnx_demo_link = model_demo.get("onnx", None)
+            if onnx_demo_link is not None:
+                if self.app_options.get("export_onnx_supported", False):
+                    onnx_icon = Field.Icon(image_url=ONNX_ICON, bg_color_rgb=[255, 255, 255])
+                    self.onnx_instruction = Field(
+                        title="ONNX",
+                        description="Open file",
+                        description_url=onnx_demo_link,
+                        icon=onnx_icon,
+                        content=Empty(),
+                    )
+                    self.onnx_instruction.hide()
+                    self.inference_demo_field.extend([self.onnx_instruction])
+
+            trt_demo_link = model_demo.get("tensorrt", None)
+            if trt_demo_link is not None:
+                if self.app_options.get("export_tensorrt_supported", False):
+                    trt_icon = Field.Icon(image_url=TRT_ICON, bg_color_rgb=[255, 255, 255])
+                    self.trt_instruction = Field(
+                        title="TensorRT",
+                        description="Open file",
+                        description_url=trt_demo_link,
+                        icon=trt_icon,
+                        content=Empty(),
+                    )
+                    self.trt_instruction.hide()
+                    self.inference_demo_field.extend([self.trt_instruction])
+
+            demo_overview_link = model_demo.get("overview", None)
+            self.inference_demo_field = Field(
+                title="How to run inference",
+                description="Instructions on how to use your checkpoints outside of Supervisely Platform",
+                content=Flexbox(self.inference_demo_field),
+                title_url=demo_overview_link,
+            )
+            self.inference_demo_field.hide()
+            self.display_widgets.extend([self.inference_demo_field])
+        # -------------------------------- #
+
+        self.container = Container(self.display_widgets)
+        self.card = Card(
+            title=self.title,
+            description=self.description,
+            content=self.container,
+            lock_message=self.lock_message,
+        )
+        self.card.lock()
+
+    @property
+    def widgets_to_disable(self) -> list:
+        return []
+
+    def validate_step(self) -> bool:
+        return True
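
Based only on the calls visible in the train_app.py hunks further down, the new step is meant to be driven roughly as follows once artifacts are uploaded. This is a hypothetical wiring sketch, not the actual `TrainApp` code; `app_options` and `file_info` are assumed to come from the surrounding app:

```python
# Hypothetical wiring, inferred from TrainApp._set_training_output() below.
artifacts = TrainingArtifacts(app_options)   # card starts locked (constructor calls card.lock())

# ...training runs, artifacts get uploaded to Team Files...
artifacts.artifacts_thumbnail.set(file_info)  # file_info: FileInfo of the uploaded open_app.lnk
artifacts.artifacts_thumbnail.show()
artifacts.artifacts_field.show()
artifacts.validator_text.set(artifacts.success_message_text, "success")
artifacts.validator_text.show()
artifacts.card.unlock()
```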
supervisely/nn/training/gui/training_process.py

@@ -23,23 +23,18 @@ class TrainingProcess:

     def __init__(self, app_options: Dict[str, Any]):
         self.display_widgets = []
-        self.success_message_text = (
-            "Training completed. Training artifacts were uploaded to Team Files. "
-            "You can find and open tensorboard logs in the artifacts folder via the "
-            "<a href='https://ecosystem.supervisely.com/apps/tensorboard-logs-viewer' target='_blank'>Tensorboard</a> app."
-        )
         self.app_options = app_options

         # GUI Components
         # Optional Select CUDA device
         if self.app_options.get("device_selector", False):
             self.select_device = SelectCudaDevice()
-            self.
+            self.select_device_field = Field(
                 title="Select CUDA device",
                 description="The device on which the model will be trained",
                 content=self.select_device,
             )
-            self.display_widgets.extend([self.
+            self.display_widgets.extend([self.select_device_field])
         # -------------------------------- #

         self.experiment_name_input = Input("Enter experiment name")

@@ -63,26 +58,15 @@ class TrainingProcess:
         self.validator_text = Text("")
         self.validator_text.hide()

-        self.artifacts_thumbnail = FolderThumbnail()
-        self.artifacts_thumbnail.hide()
-
         self.display_widgets.extend(
             [
                 self.experiment_name_field,
                 button_container,
                 self.validator_text,
-                self.artifacts_thumbnail,
             ]
         )
         # -------------------------------- #

-        # Optional Model Benchmark
-        if app_options.get("model_benchmark", False):
-            self.model_benchmark_report_thumbnail = ReportThumbnail()
-            self.model_benchmark_report_thumbnail.hide()
-            self.display_widgets.extend([self.model_benchmark_report_thumbnail])
-        # -------------------------------- #
-
         self.container = Container(self.display_widgets)
         self.card = Card(
             title=self.title,

@@ -96,7 +80,7 @@ class TrainingProcess:
     def widgets_to_disable(self) -> list:
         widgets = [self.experiment_name_input]
         if self.app_options.get("device_selector", False):
-            widgets.
+            widgets.extend([self.select_device, self.select_device_field])
         return widgets

     def validate_step(self) -> bool:
supervisely/nn/training/train_app.py

@@ -114,18 +114,20 @@ class TrainApp:
         self._tensorboard_port = 6006

         if is_production():
+            self._app_name = sly_env.app_name()
             self.task_id = sly_env.task_id()
         else:
-            self.
+            self._app_name = sly_env.app_name(raise_not_found=False)
+            self.task_id = sly_env.task_id(raise_not_found=False)
+            if self.task_id is None:
+                self.task_id = "debug-session"
             logger.info("TrainApp is running in debug mode")

         self.framework_name = framework_name
         self._team_id = sly_env.team_id()
         self._workspace_id = sly_env.workspace_id()
-        self._app_name = sly_env.app_name(raise_not_found=False)
         self._tensorboard_process = None

-        # TODO: read files
         self._models = self._load_models(models)
         self._hyperparameters = self._load_hyperparameters(hyperparameters)
         self._app_options = self._load_app_options(app_options)

@@ -511,13 +513,25 @@ class TrainApp:
         splits_data = self._postprocess_splits()

         # Step 3. Upload artifacts
+        self._set_text_status("uploading")
         remote_dir, file_info = self._upload_artifacts()

         # Step 4. Run Model Benchmark
-
+        mb_eval_lnk_file_info, mb_eval_report, mb_eval_report_id, eval_metrics = (
+            None,
+            None,
+            None,
+            {},
+        )
         if self.is_model_benchmark_enabled:
             try:
-
+                self._set_text_status("benchmark")
+                (
+                    mb_eval_lnk_file_info,
+                    mb_eval_report,
+                    mb_eval_report_id,
+                    eval_metrics,
+                ) = self._run_model_benchmark(
                     self.output_dir, remote_dir, experiment_info, splits_data
                 )
             except Exception as e:

@@ -528,12 +542,12 @@ class TrainApp:
         if self.gui.hyperparameters_selector.is_export_required():
             try:
                 export_weights = self._export_weights(experiment_info)
-                self._set_progress_status("finalizing")
                 export_weights = self._upload_export_weights(export_weights, remote_dir)
             except Exception as e:
                 logger.error(f"Export weights failed: {e}")

         # Step 6. Generate and upload additional files
+        self._set_text_status("metadata")
         self._generate_experiment_info(
             remote_dir, experiment_info, eval_metrics, mb_eval_report_id, export_weights
         )

@@ -543,16 +557,16 @@ class TrainApp:
         self._generate_model_meta(remote_dir, experiment_info)

         # Step 7. Set output widgets
-        self.
+        self._set_text_status("reset")
+        self._set_training_output(remote_dir, file_info, mb_eval_report)
+        self._set_ws_progress_status("completed")

         # Step 8. Workflow output
         if is_production():
-            self._workflow_output(remote_dir, file_info,
-
-        self._set_progress_status("completed")
+            self._workflow_output(remote_dir, file_info, mb_eval_lnk_file_info, mb_eval_report_id)

     def register_inference_class(
-        self, inference_class: Inference, inference_settings: dict =
+        self, inference_class: Inference, inference_settings: dict = None
     ) -> None:
         """
         Registers an inference class for the training application to do model benchmarking.

@@ -1582,30 +1596,83 @@ class TrainApp:
         file_info = self._api.file.get_info_by_path(self._team_id, join(remote_dir, "open_app.lnk"))
         return remote_dir, file_info

-    def _set_training_output(
+    def _set_training_output(
+        self, remote_dir: str, file_info: FileInfo, mb_eval_report=None
+    ) -> None:
         """
         Sets the training output in the GUI.
         """
+        self.gui.set_next_step()
         logger.info("All training artifacts uploaded successfully")
         self.gui.training_process.start_button.loading = False
         self.gui.training_process.start_button.disable()
         self.gui.training_process.stop_button.disable()
         # self.gui.training_logs.tensorboard_button.disable()

+        # Set artifacts to GUI
         set_directory(remote_dir)
-        self.gui.
-        self.gui.
-        self.gui.
-
+        self.gui.training_artifacts.artifacts_thumbnail.set(file_info)
+        self.gui.training_artifacts.artifacts_thumbnail.show()
+        self.gui.training_artifacts.artifacts_field.show()
+        # ---------------------------- #
+
+        # Set model benchmark to GUI
+        if self._app_options.get("model_benchmark", False):
+            if mb_eval_report is not None:
+                self.gui.training_artifacts.model_benchmark_report_thumbnail.set(mb_eval_report)
+                self.gui.training_artifacts.model_benchmark_report_thumbnail.show()
+                self.gui.training_artifacts.model_benchmark_report_field.show()
+            else:
+                self.gui.training_artifacts.model_benchmark_fail_text.show()
+                self.gui.training_artifacts.model_benchmark_report_field.show()
+        # ---------------------------- #
+
+        # Set instruction to GUI
+        demo_options = self._app_options.get("demo", {})
+        if demo_options:
+            # Show PyTorch demo if available
+            pytorch_demo = demo_options.get("pytorch")
+            if pytorch_demo:
+                self.gui.training_artifacts.pytorch_instruction.show()
+
+            # Show ONNX demo if supported and available
+            onnx_demo = demo_options.get("onnx")
+            if (
+                self._app_options.get("export_onnx_supported", False)
+                and self.gui.hyperparameters_selector.get_export_onnx_checkbox_value()
+                and onnx_demo
+            ):
+                self.gui.training_artifacts.onnx_instruction.show()
+
+            # Show TensorRT demo if supported and available
+            tensorrt_demo = demo_options.get("tensorrt")
+            if (
+                self._app_options.get("export_tensorrt_supported", False)
+                and self.gui.hyperparameters_selector.get_export_tensorrt_checkbox_value()
+                and tensorrt_demo
+            ):
+                self.gui.training_artifacts.trt_instruction.show()
+
+            # Show the inference demo widget if overview or any demo is available
+            demo_overview = self._app_options.get("overview", {})
+            if demo_overview or any([pytorch_demo, onnx_demo, tensorrt_demo]):
+                self.gui.training_artifacts.inference_demo_field.show()
+        # ---------------------------- #
+
+        # Set status to completed and unlock
+        self.gui.training_artifacts.validator_text.set(
+            self.gui.training_artifacts.success_message_text, "success"
         )
+        self.gui.training_artifacts.validator_text.show()
+        self.gui.training_artifacts.card.unlock()
+        # ---------------------------- #

     # Model Benchmark
     def _get_eval_results_dir_name(self) -> str:
         """
         Returns the evaluation results path.
         """
-
-        task_dir = f"{self.task_id}_{task_info['meta']['app']['name']}"
+        task_dir = f"{self.task_id}_{self._app_name}"
         eval_res_dir = (
             f"/model-benchmark/{self.project_info.id}_{self.project_info.name}/{task_dir}/"
         )

@@ -1633,13 +1700,13 @@ class TrainApp:
         :return: Evaluation report, report ID and evaluation metrics.
         :rtype: tuple
         """
-
+        lnk_file_info, report, report_id, eval_metrics = None, None, None, {}
         if self._inference_class is None:
-            logger.
+            logger.warning(
                 "Inference class is not registered, model benchmark disabled. "
                 "Use 'register_inference_class' method to register inference class."
             )
-            return
+            return lnk_file_info, report, report_id, eval_metrics

         # Can't get task type from session. requires before session init
         supported_task_types = [

@@ -1652,7 +1719,7 @@ class TrainApp:
                 f"Task type: '{task_type}' is not supported for Model Benchmark. "
                 f"Supported tasks: {', '.join(task_type)}"
             )
-            return
+            return lnk_file_info, report, report_id, eval_metrics

         logger.info("Running Model Benchmark evaluation")
         try:

@@ -1661,14 +1728,6 @@ class TrainApp:
             best_filename = sly_fs.get_file_name_with_ext(best_checkpoint)
             remote_best_checkpoint = join(remote_checkpoints_dir, best_filename)

-            config_path = experiment_info["model_files"].get("config")
-            if config_path is not None:
-                remote_config_path = join(
-                    remote_artifacts_dir, sly_fs.get_file_name_with_ext(config_path)
-                )
-            else:
-                remote_config_path = None
-
             logger.info(f"Creating the report for the best model: {best_filename!r}")
             self.gui.training_process.validator_text.set(
                 f"Creating evaluation report for the best model: {best_filename!r}",

@@ -1788,50 +1847,40 @@ class TrainApp:

             # 6. Speed test
             if self.gui.hyperparameters_selector.get_speedtest_checkbox_value() is True:
+                self.progress_bar_secondary.show()
                 bm.run_speedtest(session, self.project_info.id)
-            self.progress_bar_secondary.hide()
+                self.progress_bar_secondary.hide()
                 bm.upload_speedtest_results(eval_res_dir + "/speedtest/")

             # 7. Prepare visualizations, report and upload
             bm.visualize()
-
-
-
+            _ = bm.upload_visualizations(eval_res_dir + "/visualizations/")
+            lnk_file_info = bm.lnk
+            report = bm.report
+            report_id = bm.report.id
             eval_metrics = bm.key_metrics

             # 8. UI updates
-            benchmark_report_template = self._api.file.get_info_by_path(
-                self._team_id, remote_dir + "template.vue"
-            )
-
-            self.gui.training_process.model_benchmark_report_thumbnail.set(
-                benchmark_report_template
-            )
-            self.gui.training_process.model_benchmark_report_thumbnail.show()
             self.progress_bar_main.hide()
             self.progress_bar_secondary.hide()
             logger.info("Model benchmark evaluation completed successfully")
             logger.info(
                 f"Predictions project name: {bm.dt_project_info.name}. Workspace_id: {bm.dt_project_info.workspace_id}"
             )
-            logger.info(
-                f"Differences project name: {bm.diff_project_info.name}. Workspace_id: {bm.diff_project_info.workspace_id}"
-            )
         except Exception as e:
             logger.error(f"Model benchmark failed. {repr(e)}", exc_info=True)
-            self.
-                "Finalizing and uploading training artifacts...", "info"
-            )
+            self._set_text_status("finalizing")
             self.progress_bar_main.hide()
             self.progress_bar_secondary.hide()
             try:
                 if bm.dt_project_info:
                     self._api.project.remove(bm.dt_project_info.id)
-
-
+                diff_project_info = bm.get_diff_project_info()
+                if diff_project_info:
+                    self._api.project.remove(diff_project_info.id)
             except Exception as e2:
-                return
-            return
+                return lnk_file_info, report, report_id, eval_metrics
+            return lnk_file_info, report, report_id, eval_metrics

     # ----------------------------------------- #

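The net effect of the `_run_model_benchmark` changes above is a stable return contract: the method now always hands back the 4-tuple `(lnk_file_info, report, report_id, eval_metrics)`, including when benchmarking is skipped or fails, so the calling code can unpack it unconditionally. A condensed sketch of that pattern (not the real method body; `inference_class` and `run_evaluation` are stand-ins):

```python
def run_model_benchmark_sketch(inference_class, run_evaluation):
    """Condensed sketch of the return contract introduced above; not the real method."""
    lnk_file_info, report, report_id, eval_metrics = None, None, None, {}
    if inference_class is None:
        # Benchmark disabled: the caller still receives a well-formed tuple.
        return lnk_file_info, report, report_id, eval_metrics
    try:
        lnk_file_info, report, report_id, eval_metrics = run_evaluation()
    except Exception:
        # On failure the defaults are returned instead of a bare `return` (the old behavior).
        return lnk_file_info, report, report_id, eval_metrics
    return lnk_file_info, report, report_id, eval_metrics
```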
@@ -1932,7 +1981,8 @@ class TrainApp:
             )

             meta = WorkflowMeta(
-                relation_settings=mb_relation_settings,
+                relation_settings=mb_relation_settings,
+                node_settings=node_settings,
             )
             self._api.app.workflow.add_output_file(model_benchmark_report, meta=meta)
         else:

@@ -2072,11 +2122,12 @@ class TrainApp:
             message = "Error occurred during training initialization. Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
-            self.
+            self._set_ws_progress_status("reset")
             return

         try:
-            self.
+            self._set_text_status("preparing")
+            self._set_ws_progress_status("preparing")
             self._prepare()
         except Exception as e:
             message = (

@@ -2084,30 +2135,31 @@ class TrainApp:
             )
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
-            self.
+            self._set_ws_progress_status("reset")
             return

         try:
-            self.
+            self._set_text_status("training")
             if self._app_options.get("train_logger", None) is None:
-                self.
+                self._set_ws_progress_status("training")
             experiment_info = self._train_func()
         except Exception as e:
             message = "Error occurred during training. Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
-            self.
+            self._set_ws_progress_status("reset")
             return

         try:
-            self.
+            self._set_text_status("finalizing")
+            self._set_ws_progress_status("finalizing")
             self._finalize(experiment_info)
             self.gui.training_process.start_button.loading = False
         except Exception as e:
             message = "Error occurred during finalizing and uploading training artifacts . Please check the logs for more details."
             self._show_error(message, e)
             self._restore_train_widgets_state_on_error()
-            self.
+            self._set_ws_progress_status("reset")
             return

     def _show_error(self, message: str, e=None):

@@ -2121,12 +2173,18 @@ class TrainApp:
         self._restore_train_widgets_state_on_error()

     def _set_train_widgets_state_on_start(self):
+        self.gui.training_artifacts.validator_text.hide()
         self._validate_experiment_name()
         self.gui.training_process.experiment_name_input.disable()
         if self._app_options.get("device_selector", False):
             self.gui.training_process.select_device._select.disable()
             self.gui.training_process.select_device.disable()

+        if self._app_options.get("model_benchmark", False):
+            self.gui.training_artifacts.model_benchmark_report_thumbnail.hide()
+            self.gui.training_artifacts.model_benchmark_fail_text.hide()
+            self.gui.training_artifacts.model_benchmark_report_field.hide()
+
         self.gui.training_logs.card.unlock()
         self.gui.stepper.set_active_step(7)
         self.gui.training_process.validator_text.set("Training has been started...", "info")

@@ -2152,8 +2210,56 @@ class TrainApp:
             raise ValueError(f"Experiment name contains invalid characters: {invalid_chars}")
         return True

-    def
-        self,
+    def _set_text_status(
+        self,
+        status: Literal[
+            "reset",
+            "completed",
+            "training",
+            "finalizing",
+            "preparing",
+            "uploading",
+            "benchmark",
+            "metadata",
+            "export_onnx",
+            "export_trt",
+        ],
+    ):
+
+        if status == "reset":
+            self.gui.training_process.validator_text.set("", "text")
+        elif status == "completed":
+            self.gui.training_process.validator_text.set("Training completed", "success")
+        elif status == "training":
+            self.gui.training_process.validator_text.set("Training is in progress...", "info")
+        elif status == "finalizing":
+            self.gui.training_process.validator_text.set(
+                "Finalizing and preparing training artifacts...", "info"
+            )
+        elif status == "preparing":
+            self.gui.training_process.validator_text.set("Preparing data for training...", "info")
+        elif status == "export_onnx":
+            self.gui.training_process.validator_text.set(
+                f"Converting to {RuntimeType.ONNXRUNTIME}", "info"
+            )
+        elif status == "export_trt":
+            self.gui.training_process.validator_text.set(
+                f"Converting to {RuntimeType.TENSORRT}", "info"
+            )
+        elif status == "uploading":
+            self.gui.training_process.validator_text.set("Uploading training artifacts...", "info")
+        elif status == "benchmark":
+            self.gui.training_process.validator_text.set(
+                "Running Model Benchmark evaluation...", "info"
+            )
+        elif status == "validating":
+            self.gui.training_process.validator_text.set("Validating experiment...", "info")
+        elif status == "metadata":
+            self.gui.training_process.validator_text.set("Generating training metadata...", "info")
+
+    def _set_ws_progress_status(
+        self,
+        status: Literal["reset", "completed", "training", "finalizing", "preparing"],
     ):
         message = ""
         if status == "reset":
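The new `_set_text_status` helper centralizes the per-stage messages behind a `Literal`-typed status argument, so call sites in the hunks below collapse to one-liners like `self._set_text_status("export_onnx")`. A self-contained sketch of the same idea, using a dict dispatch instead of the `elif` chain and a `print` in place of the validator text widget (status values copied from the diff; everything else is illustrative):

```python
from typing import Literal

Status = Literal[
    "reset", "completed", "training", "finalizing", "preparing",
    "uploading", "benchmark", "metadata", "export_onnx", "export_trt",
]

MESSAGES = {
    "reset": "",
    "completed": "Training completed",
    "training": "Training is in progress...",
    "finalizing": "Finalizing and preparing training artifacts...",
    "preparing": "Preparing data for training...",
    "uploading": "Uploading training artifacts...",
    "benchmark": "Running Model Benchmark evaluation...",
    "metadata": "Generating training metadata...",
}

def set_text_status(status: Status) -> None:
    # Stand-in for self.gui.training_process.validator_text.set(message, "info")
    print(MESSAGES.get(status, status))

set_text_status("uploading")  # a type checker rejects values outside the Literal
```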
@@ -2180,9 +2286,7 @@ class TrainApp:
             self.gui.hyperparameters_selector.get_export_onnx_checkbox_value() is True
             and self._convert_onnx_func is not None
         ):
-            self.
-                f"Converting to {RuntimeType.ONNXRUNTIME}", "info"
-            )
+            self._set_text_status("export_onnx")
             onnx_path = self._convert_onnx_func(experiment_info)
             export_weights[RuntimeType.ONNXRUNTIME] = onnx_path

@@ -2190,9 +2294,7 @@ class TrainApp:
             self.gui.hyperparameters_selector.get_export_tensorrt_checkbox_value() is True
             and self._convert_tensorrt_func is not None
         ):
-            self.
-                f"Converting to {RuntimeType.TENSORRT}", "info"
-            )
+            self._set_text_status("export_trt")
             tensorrt_path = self._convert_tensorrt_func(experiment_info)
             export_weights[RuntimeType.TENSORRT] = tensorrt_path
         return export_weights

@@ -2214,12 +2316,19 @@ class TrainApp:
                 unit="bytes",
                 unit_scale=True,
             ) as export_upload_secondary_pbar:
+                self.progress_bar_secondary.show()
                 destination_path = join(remote_dir, self._export_dir_name, file_name)
                 self._api.file.upload(
-                    self._team_id,
+                    self._team_id,
+                    path,
+                    destination_path,
+                    export_upload_secondary_pbar,
                 )
                 export_upload_main_pbar.update(1)

+        self.progress_bar_main.hide()
+        self.progress_bar_secondary.hide()
+
         remote_export_weights = {
             runtime: join(self._export_dir_name, sly_fs.get_file_name_with_ext(path))
             for runtime, path in export_weights.items()