supervisely 6.73.242__py3-none-any.whl → 6.73.244__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +18 -0
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/card/card.py +3 -0
- supervisely/app/widgets/classes_table/classes_table.py +15 -1
- supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
- supervisely/app/widgets/custom_models_selector/template.html +1 -1
- supervisely/app/widgets/experiment_selector/__init__.py +0 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
- supervisely/app/widgets/experiment_selector/style.css +27 -0
- supervisely/app/widgets/experiment_selector/template.html +82 -0
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
- supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
- supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/__init__.py +3 -1
- supervisely/nn/artifacts/artifacts.py +10 -0
- supervisely/nn/artifacts/detectron2.py +2 -0
- supervisely/nn/artifacts/hrda.py +3 -0
- supervisely/nn/artifacts/mmclassification.py +2 -0
- supervisely/nn/artifacts/mmdetection.py +6 -3
- supervisely/nn/artifacts/mmsegmentation.py +2 -0
- supervisely/nn/artifacts/ritm.py +3 -1
- supervisely/nn/artifacts/rtdetr.py +2 -0
- supervisely/nn/artifacts/unet.py +2 -0
- supervisely/nn/artifacts/yolov5.py +3 -0
- supervisely/nn/artifacts/yolov8.py +7 -1
- supervisely/nn/experiments.py +113 -0
- supervisely/nn/inference/gui/__init__.py +3 -1
- supervisely/nn/inference/gui/gui.py +31 -232
- supervisely/nn/inference/gui/serving_gui.py +223 -0
- supervisely/nn/inference/gui/serving_gui_template.py +240 -0
- supervisely/nn/inference/inference.py +225 -24
- supervisely/nn/training/__init__.py +0 -0
- supervisely/nn/training/gui/__init__.py +1 -0
- supervisely/nn/training/gui/classes_selector.py +100 -0
- supervisely/nn/training/gui/gui.py +539 -0
- supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
- supervisely/nn/training/gui/input_selector.py +70 -0
- supervisely/nn/training/gui/model_selector.py +95 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
- supervisely/nn/training/gui/training_logs.py +93 -0
- supervisely/nn/training/gui/training_process.py +114 -0
- supervisely/nn/training/gui/utils.py +128 -0
- supervisely/nn/training/loggers/__init__.py +0 -0
- supervisely/nn/training/loggers/base_train_logger.py +58 -0
- supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
- supervisely/nn/training/train_app.py +2038 -0
- supervisely/nn/utils.py +5 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/RECORD +57 -35
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
|
@@ -13,6 +13,7 @@ from dataclasses import asdict
|
|
|
13
13
|
from functools import partial, wraps
|
|
14
14
|
from queue import Queue
|
|
15
15
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
16
|
+
from urllib.request import urlopen
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
import requests
|
|
@@ -25,10 +26,13 @@ import supervisely.app.development as sly_app_development
|
|
|
25
26
|
import supervisely.imaging.image as sly_image
|
|
26
27
|
import supervisely.io.env as env
|
|
27
28
|
import supervisely.io.fs as fs
|
|
29
|
+
import supervisely.io.fs as sly_fs
|
|
30
|
+
import supervisely.io.json as sly_json
|
|
28
31
|
import supervisely.nn.inference.gui as GUI
|
|
29
32
|
from supervisely import DatasetInfo, ProjectInfo, VideoAnnotation, batched
|
|
30
33
|
from supervisely._utils import (
|
|
31
34
|
add_callback,
|
|
35
|
+
get_filename_from_headers,
|
|
32
36
|
is_debug_with_sly_net,
|
|
33
37
|
is_production,
|
|
34
38
|
rand_str,
|
|
@@ -59,7 +63,13 @@ from supervisely.geometry.any_geometry import AnyGeometry
|
|
|
59
63
|
from supervisely.imaging.color import get_predefined_colors
|
|
60
64
|
from supervisely.nn.inference.cache import InferenceImageCache
|
|
61
65
|
from supervisely.nn.prediction_dto import Prediction
|
|
62
|
-
from supervisely.nn.utils import
|
|
66
|
+
from supervisely.nn.utils import (
|
|
67
|
+
CheckpointInfo,
|
|
68
|
+
DeployInfo,
|
|
69
|
+
ModelPrecision,
|
|
70
|
+
ModelSource,
|
|
71
|
+
RuntimeType,
|
|
72
|
+
)
|
|
63
73
|
from supervisely.project import ProjectType
|
|
64
74
|
from supervisely.project.download import download_to_cache, read_from_cached_project
|
|
65
75
|
from supervisely.project.project_meta import ProjectMeta
|
|
@@ -74,6 +84,12 @@ except ImportError:
|
|
|
74
84
|
|
|
75
85
|
|
|
76
86
|
class Inference:
|
|
87
|
+
FRAMEWORK_NAME: str = None
|
|
88
|
+
"""Name of framework to register models in Supervisely"""
|
|
89
|
+
MODELS: str = None
|
|
90
|
+
"""Path to file with list of models"""
|
|
91
|
+
APP_OPTIONS: str = None
|
|
92
|
+
"""Path to file with app options"""
|
|
77
93
|
DEFAULT_BATCH_SIZE = 16
|
|
78
94
|
|
|
79
95
|
def __init__(
|
|
@@ -85,6 +101,7 @@ class Inference:
|
|
|
85
101
|
sliding_window_mode: Optional[Literal["basic", "advanced", "none"]] = "basic",
|
|
86
102
|
use_gui: Optional[bool] = False,
|
|
87
103
|
multithread_inference: Optional[bool] = True,
|
|
104
|
+
use_serving_gui_template: Optional[bool] = False,
|
|
88
105
|
):
|
|
89
106
|
if model_dir is None:
|
|
90
107
|
model_dir = os.path.join(get_data_dir(), "models")
|
|
@@ -92,8 +109,10 @@ class Inference:
|
|
|
92
109
|
self.device: str = None
|
|
93
110
|
self.runtime: str = None
|
|
94
111
|
self.model_precision: str = None
|
|
112
|
+
self.model_source: str = None
|
|
95
113
|
self.checkpoint_info: CheckpointInfo = None
|
|
96
114
|
self.max_batch_size: int = None # set it only if a model has a limit on the batch size
|
|
115
|
+
self.classes: List[str] = None
|
|
97
116
|
self._model_dir = model_dir
|
|
98
117
|
self._model_served = False
|
|
99
118
|
self._deploy_params: dict = None
|
|
@@ -117,6 +136,7 @@ class Inference:
|
|
|
117
136
|
self._custom_inference_settings = custom_inference_settings
|
|
118
137
|
|
|
119
138
|
self._use_gui = use_gui
|
|
139
|
+
self._use_serving_gui_template = use_serving_gui_template
|
|
120
140
|
self._gui = None
|
|
121
141
|
|
|
122
142
|
self.load_on_device = LOAD_ON_DEVICE_DECORATOR(self.load_on_device)
|
|
@@ -124,20 +144,48 @@ class Inference:
|
|
|
124
144
|
|
|
125
145
|
self.load_model = LOAD_MODEL_DECORATOR(self.load_model)
|
|
126
146
|
|
|
127
|
-
if
|
|
147
|
+
if self._use_gui:
|
|
128
148
|
initialize_custom_gui_method = getattr(self, "initialize_custom_gui", None)
|
|
129
149
|
original_initialize_custom_gui_method = getattr(
|
|
130
150
|
Inference, "initialize_custom_gui", None
|
|
131
151
|
)
|
|
132
|
-
if
|
|
152
|
+
if self._use_serving_gui_template:
|
|
153
|
+
if self.FRAMEWORK_NAME is None:
|
|
154
|
+
raise ValueError("FRAMEWORK_NAME is not defined")
|
|
155
|
+
self._gui = GUI.ServingGUITemplate(
|
|
156
|
+
self.FRAMEWORK_NAME, self.MODELS, self.APP_OPTIONS
|
|
157
|
+
)
|
|
158
|
+
self._user_layout = self._gui.widgets
|
|
159
|
+
self._user_layout_card = self._gui.card
|
|
160
|
+
elif initialize_custom_gui_method.__func__ is not original_initialize_custom_gui_method:
|
|
133
161
|
self._gui = GUI.ServingGUI()
|
|
134
162
|
self._user_layout = self.initialize_custom_gui()
|
|
135
163
|
else:
|
|
136
|
-
self
|
|
164
|
+
initialize_custom_gui_method = getattr(self, "initialize_custom_gui", None)
|
|
165
|
+
original_initialize_custom_gui_method = getattr(
|
|
166
|
+
Inference, "initialize_custom_gui", None
|
|
167
|
+
)
|
|
168
|
+
if (
|
|
169
|
+
initialize_custom_gui_method.__func__
|
|
170
|
+
is not original_initialize_custom_gui_method
|
|
171
|
+
):
|
|
172
|
+
self._gui = GUI.ServingGUI()
|
|
173
|
+
self._user_layout = self.initialize_custom_gui()
|
|
174
|
+
else:
|
|
175
|
+
self.initialize_gui()
|
|
137
176
|
|
|
138
|
-
def on_serve_callback(
|
|
177
|
+
def on_serve_callback(
|
|
178
|
+
gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
|
|
179
|
+
):
|
|
139
180
|
Progress("Deploying model ...", 1)
|
|
140
|
-
if isinstance(self.gui, GUI.
|
|
181
|
+
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
182
|
+
deploy_params = self.get_params_from_gui()
|
|
183
|
+
model_files = self._download_model_files(
|
|
184
|
+
deploy_params["model_source"], deploy_params["model_files"]
|
|
185
|
+
)
|
|
186
|
+
deploy_params["model_files"] = model_files
|
|
187
|
+
self._load_model_headless(**deploy_params)
|
|
188
|
+
elif isinstance(self.gui, GUI.ServingGUI):
|
|
141
189
|
deploy_params = self.get_params_from_gui()
|
|
142
190
|
self._load_model(deploy_params)
|
|
143
191
|
else: # GUI.InferenceGUI
|
|
@@ -146,9 +194,11 @@ class Inference:
|
|
|
146
194
|
self.load_on_device(self._model_dir, device)
|
|
147
195
|
gui.show_deployed_model_info(self)
|
|
148
196
|
|
|
149
|
-
def on_change_model_callback(
|
|
197
|
+
def on_change_model_callback(
|
|
198
|
+
gui: Union[GUI.InferenceGUI, GUI.ServingGUI, GUI.ServingGUITemplate]
|
|
199
|
+
):
|
|
150
200
|
self.shutdown_model()
|
|
151
|
-
if isinstance(self.gui, GUI.ServingGUI):
|
|
201
|
+
if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
|
|
152
202
|
self._api_request_model_layout.unlock()
|
|
153
203
|
self._api_request_model_layout.hide()
|
|
154
204
|
self.update_gui(self._model_served)
|
|
@@ -198,7 +248,7 @@ class Inference:
|
|
|
198
248
|
raise NotImplementedError("Have to be implemented in child class after inheritance")
|
|
199
249
|
|
|
200
250
|
def update_gui(self, is_model_deployed: bool = True) -> None:
|
|
201
|
-
if isinstance(self.gui, GUI.ServingGUI):
|
|
251
|
+
if isinstance(self.gui, (GUI.ServingGUI, GUI.ServingGUITemplate)):
|
|
202
252
|
if is_model_deployed:
|
|
203
253
|
self._user_layout_card.lock()
|
|
204
254
|
else:
|
|
@@ -211,6 +261,8 @@ class Inference:
|
|
|
211
261
|
self._api_request_model_layout.show()
|
|
212
262
|
|
|
213
263
|
def get_params_from_gui(self) -> dict:
    """
    Return the deploy parameters currently selected in the GUI.

    The base class can only read them from a ``GUI.ServingGUITemplate`` (it
    delegates to that widget's own ``get_params_from_gui``). For any other GUI
    type a child class has to override this method.

    :raises NotImplementedError: If ``self.gui`` is not a ``ServingGUITemplate``.
    """
    if not isinstance(self.gui, GUI.ServingGUITemplate):
        raise NotImplementedError("Have to be implemented in child class after inheritance")
    template_gui = self.gui
    return template_gui.get_params_from_gui()
|
|
215
267
|
|
|
216
268
|
def initialize_gui(self) -> None:
|
|
@@ -237,6 +289,25 @@ class Inference:
|
|
|
237
289
|
)
|
|
238
290
|
|
|
239
291
|
def _initialize_app_layout(self):
|
|
292
|
+
self._api_request_model_info = Editor(
|
|
293
|
+
height_lines=12,
|
|
294
|
+
language_mode="json",
|
|
295
|
+
readonly=True,
|
|
296
|
+
restore_default_button=False,
|
|
297
|
+
auto_format=True,
|
|
298
|
+
)
|
|
299
|
+
self._api_request_model_layout = Card(
|
|
300
|
+
title="Model was deployed from API request with the following settings",
|
|
301
|
+
content=self._api_request_model_info,
|
|
302
|
+
)
|
|
303
|
+
self._api_request_model_layout.hide()
|
|
304
|
+
|
|
305
|
+
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
306
|
+
self._app_layout = Container(
|
|
307
|
+
[self._user_layout_card, self._api_request_model_layout, self.get_ui()], gap=5
|
|
308
|
+
)
|
|
309
|
+
return
|
|
310
|
+
|
|
240
311
|
if hasattr(self, "_user_layout"):
|
|
241
312
|
self._user_layout_card = Card(
|
|
242
313
|
title="Select Model",
|
|
@@ -251,20 +322,9 @@ class Inference:
|
|
|
251
322
|
content=self._gui,
|
|
252
323
|
lock_message="Model is deployed. To change the model, stop the serving first.",
|
|
253
324
|
)
|
|
254
|
-
|
|
255
|
-
height_lines=12,
|
|
256
|
-
language_mode="json",
|
|
257
|
-
readonly=True,
|
|
258
|
-
restore_default_button=False,
|
|
259
|
-
auto_format=True,
|
|
260
|
-
)
|
|
261
|
-
self._api_request_model_layout = Card(
|
|
262
|
-
title="Model was deployed from API request with the following settings",
|
|
263
|
-
content=self._api_request_model_info,
|
|
264
|
-
)
|
|
265
|
-
self._api_request_model_layout.hide()
|
|
325
|
+
|
|
266
326
|
self._app_layout = Container(
|
|
267
|
-
[self._user_layout_card, self._api_request_model_layout, self.get_ui()]
|
|
327
|
+
[self._user_layout_card, self._api_request_model_layout, self.get_ui()], gap=5
|
|
268
328
|
)
|
|
269
329
|
|
|
270
330
|
def support_custom_models(self) -> bool:
|
|
@@ -427,7 +487,74 @@ class Inference:
|
|
|
427
487
|
def load_model_meta(self, model_tab: str, local_weights_path: str):
|
|
428
488
|
raise NotImplementedError("Have to be implemented in child class after inheritance")
|
|
429
489
|
|
|
490
|
+
def _download_model_files(self, model_source: str, model_files: dict) -> dict:
    """
    Download the model files, dispatching on where the model comes from.

    :param model_source: ``ModelSource.PRETRAINED`` or ``ModelSource.CUSTOM``.
    :param model_files: Mapping of file key to its remote location. Pretrained
        models use URLs (non-HTTP values are passed through unchanged); custom
        models use Team Files paths.
    :return: Mapping of the same keys to local file paths.
    :raises ValueError: If ``model_source`` is not one of the supported sources.
    """
    if model_source == ModelSource.PRETRAINED:
        return self._download_pretrained_model(model_files)
    if model_source == ModelSource.CUSTOM:
        return self._download_custom_model(model_files)
    # Fail fast: returning None here only surfaces later as an obscure error
    # when the (missing) model files are used to load the model.
    raise ValueError(
        f"Unsupported model source: {model_source!r}. "
        f"Expected {ModelSource.PRETRAINED!r} or {ModelSource.CUSTOM!r}."
    )
|
|
495
|
+
|
|
496
|
+
def _download_pretrained_model(self, model_files: dict):
    """
    Downloads the pretrained model data.

    Values starting with ``"http"`` are downloaded into ``self.model_dir``,
    showing progress in ``self.gui.download_progress``. Any other value is
    treated as an already-local path and returned unchanged.

    :param model_files: Mapping of file key to URL or local path.
    :return: Mapping of the same keys to local file paths.
    """
    local_model_files = {}
    try:
        for file, file_url in model_files.items():
            if not file_url.startswith("http"):
                local_model_files[file] = file_url
                continue

            # Only the response size is needed from this request; close the
            # connection before the actual download starts.
            with urlopen(file_url) as response:
                file_size = response.length

            file_name = get_filename_from_headers(file_url)
            if file_name is None:
                # Fall back to the mapping key when no file name can be resolved.
                file_name = file
            # Build the path only after the fallback so it is never joined with None.
            file_path = os.path.join(self.model_dir, file_name)

            with self.gui.download_progress(
                message=f"Downloading: '{file_name}'",
                total=file_size,
                unit="bytes",
                unit_scale=True,
            ) as download_pbar:
                self.gui.download_progress.show()
                sly_fs.download(
                    url=file_url, save_path=file_path, progress=download_pbar.update
                )
            local_model_files[file] = file_path
    finally:
        # Hide the progress bar even if a download fails.
        self.gui.download_progress.hide()
    return local_model_files
|
|
527
|
+
|
|
528
|
+
def _download_custom_model(self, model_files: dict):
    """
    Downloads the custom model data.

    Each value of ``model_files`` is a Team Files path in the team returned by
    ``env.team_id()``. It is downloaded into ``self.model_dir`` under its base
    name, with progress shown in ``self.gui.download_progress``.

    :param model_files: Mapping of file key to Team Files path.
    :return: Mapping of the same keys to local file paths.
    :raises FileNotFoundError: If no file info is found for a path.
    """
    team_id = env.team_id()
    local_model_files = {}
    try:
        for file, file_url in model_files.items():
            file_info = self.api.file.get_info_by_path(team_id, file_url)
            if file_info is None:
                # NOTE(review): assumes get_info_by_path returns None for a missing
                # path; without this check that surfaces as AttributeError on `.sizeb`.
                raise FileNotFoundError(
                    f"File '{file_url}' not found in Team Files (team id: {team_id})"
                )
            file_name = os.path.basename(file_url)
            file_path = os.path.join(self.model_dir, file_name)
            with self.gui.download_progress(
                message=f"Downloading: '{file_name}'",
                total=file_info.sizeb,
                unit="bytes",
                unit_scale=True,
            ) as download_pbar:
                self.gui.download_progress.show()
                self.api.file.download(
                    team_id, file_url, file_path, progress_cb=download_pbar.update
                )
            local_model_files[file] = file_path
    finally:
        # Hide the progress bar even if a download fails.
        self.gui.download_progress.hide()
    return local_model_files
|
|
555
|
+
|
|
430
556
|
def _load_model(self, deploy_params: dict):
|
|
557
|
+
self.model_source = deploy_params.get("model_source")
|
|
431
558
|
self.device = deploy_params.get("device")
|
|
432
559
|
self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
|
|
433
560
|
self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
|
|
@@ -439,6 +566,64 @@ class Inference:
|
|
|
439
566
|
self.update_gui(self._model_served)
|
|
440
567
|
self.gui.show_deployed_model_info(self)
|
|
441
568
|
|
|
569
|
+
def _load_model_headless(
    self,
    model_files: dict,
    model_source: str,
    model_info: dict,
    device: str,
    runtime: str,
    **kwargs,
):
    """
    Assemble deploy parameters from the arguments and load the model.

    For ``ModelSource.CUSTOM`` the model meta and checkpoint info are first
    filled from ``model_info``. After ``self._load_model`` runs, a model meta
    is built from ``self.get_classes()`` if none is set yet.

    :param model_files: Mapping of file key to local file path.
    :param model_source: Where the model comes from (a ``ModelSource`` value).
    :param model_info: Model description, used for custom models.
    :param device: Device to load the model on.
    :param runtime: Runtime to use.
    :param kwargs: Any extra deploy parameters, forwarded unchanged.
    """
    deploy_params = dict(
        model_files=model_files,
        model_source=model_source,
        model_info=model_info,
        device=device,
        runtime=runtime,
        **kwargs,
    )
    is_custom_model = model_source == ModelSource.CUSTOM
    if is_custom_model:
        self._set_model_meta_custom_model(model_info)
        self._set_checkpoint_info_custom_model(deploy_params)

    self._load_model(deploy_params)

    meta_missing = self._model_meta is None
    if meta_missing:
        self._set_model_meta_from_classes()
|
|
592
|
+
|
|
593
|
+
def _set_model_meta_custom_model(self, model_info: dict):
    """
    Set ``self._model_meta`` and ``self.classes`` from a custom model's info.

    ``model_info["model_meta"]`` may be:
      * missing or ``None``: nothing is set;
      * a dict: parsed directly with ``ProjectMeta.from_json``;
      * a str: a file name inside ``model_info["artifacts_dir"]``. The file is
        fetched with ``self.download`` and loaded as JSON before parsing.

    :raises ValueError: If ``model_meta`` has any other type.
    """
    meta_source = model_info.get("model_meta")
    if meta_source is None:
        return

    if isinstance(meta_source, str):
        meta_remote_path = os.path.join(model_info["artifacts_dir"], meta_source)
        meta_local_path = self.download(meta_remote_path)
        meta_json = sly_json.load_json_file(meta_local_path)
    elif isinstance(meta_source, dict):
        meta_json = meta_source
    else:
        raise ValueError(
            "model_meta should be a dict or a name of '.json' file in experiment artifacts folder in Team Files"
        )

    self._model_meta = ProjectMeta.from_json(meta_json)
    self._get_confidence_tag_meta()
    self.classes = [obj_class.name for obj_class in self._model_meta.obj_classes]
|
|
611
|
+
|
|
612
|
+
def _set_checkpoint_info_custom_model(self, deploy_params: dict):
    """
    Fill ``self.checkpoint_info`` for a custom model.

    Nothing happens when ``deploy_params`` has no (or an empty) ``"model_info"``.

    :param deploy_params: Deploy parameters. Reads ``"model_files"["checkpoint"]``
        and, from ``"model_info"``, ``model_name``, ``framework_name`` and
        ``artifacts_dir``.
    :raises ValueError: If ``model_info`` is present but the checkpoint path or
        ``artifacts_dir`` is missing (previously an opaque ``TypeError`` from
        ``os.path``).
    """
    model_info = deploy_params.get("model_info", {})
    if not model_info:
        return

    # `or {}` also tolerates an explicit None value for "model_files".
    model_files = deploy_params.get("model_files") or {}
    checkpoint_path = model_files.get("checkpoint")
    if checkpoint_path is None:
        raise ValueError(
            "Can't set checkpoint info: 'checkpoint' is missing in model_files "
            f"(available keys: {sorted(model_files)})"
        )
    artifacts_dir = model_info.get("artifacts_dir")
    if artifacts_dir is None:
        raise ValueError("Can't set checkpoint info: 'artifacts_dir' is missing in model_info")

    checkpoint_name = os.path.basename(checkpoint_path)
    self.checkpoint_info = CheckpointInfo(
        checkpoint_name=checkpoint_name,
        model_name=model_info.get("model_name"),
        architecture=model_info.get("framework_name"),
        custom_checkpoint_path=os.path.join(artifacts_dir, checkpoint_name),
        model_source=ModelSource.CUSTOM,
    )
|
|
626
|
+
|
|
442
627
|
def shutdown_model(self):
|
|
443
628
|
self._model_served = False
|
|
444
629
|
self.device = None
|
|
@@ -453,7 +638,7 @@ class Inference:
|
|
|
453
638
|
pass
|
|
454
639
|
|
|
455
640
|
def get_classes(self) -> List[str]:
    """
    Return the class names of the served model.

    This implementation exposes ``self.classes``. That attribute starts as
    ``None`` and is filled e.g. by ``_set_model_meta_custom_model``.
    """
    class_names = self.classes
    return class_names
|
|
457
642
|
|
|
458
643
|
def get_info(self) -> Dict[str, Any]:
|
|
459
644
|
num_classes = None
|
|
@@ -550,6 +735,14 @@ class Inference:
|
|
|
550
735
|
self._model_meta = ProjectMeta(classes)
|
|
551
736
|
self._get_confidence_tag_meta()
|
|
552
737
|
|
|
738
|
+
def _set_model_meta_from_classes(self):
    """
    Build ``self._model_meta`` from the names returned by ``self.get_classes()``.

    Every class gets the same shape, taken from ``self._get_obj_class_shape()``.
    Afterwards ``self._get_confidence_tag_meta()`` is called.

    :raises ValueError: If ``get_classes()`` returns ``None`` or an empty list.
    """
    class_names = self.get_classes()
    if not class_names:
        raise ValueError("Can't create model meta. Please, set the `self.classes` attribute.")
    geometry = self._get_obj_class_shape()
    obj_classes = [ObjClass(class_name, geometry) for class_name in class_names]
    self._model_meta = ProjectMeta(obj_classes)
    self._get_confidence_tag_meta()
|
|
745
|
+
|
|
553
746
|
@property
|
|
554
747
|
def task_id(self) -> int:
|
|
555
748
|
return self._task_id
|
|
@@ -2420,7 +2613,15 @@ class Inference:
|
|
|
2420
2613
|
self.shutdown_model()
|
|
2421
2614
|
state = request.state.state
|
|
2422
2615
|
deploy_params = state["deploy_params"]
|
|
2423
|
-
self.
|
|
2616
|
+
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
2617
|
+
model_files = self._download_model_files(
|
|
2618
|
+
deploy_params["model_source"], deploy_params["model_files"]
|
|
2619
|
+
)
|
|
2620
|
+
deploy_params["model_files"] = model_files
|
|
2621
|
+
self._load_model_headless(**deploy_params)
|
|
2622
|
+
elif isinstance(self.gui, GUI.ServingGUI):
|
|
2623
|
+
self._load_model(deploy_params)
|
|
2624
|
+
|
|
2424
2625
|
self.set_params_to_gui(deploy_params)
|
|
2425
2626
|
# update to set correct device
|
|
2426
2627
|
device = deploy_params.get("device", "cpu")
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from supervisely._utils import abs_url, is_debug_with_sly_net, is_development
|
|
2
|
+
from supervisely.app.widgets import Button, Card, ClassesTable, Container, Text
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ClassesSelector:
    """Training-GUI step for choosing which project classes to train on.

    Shows a ``ClassesTable`` for the project, a link to its QA & Stats page, a
    validation message and a "Select" button inside a ``Card``. The card starts
    locked.
    """

    title = "Classes Selector"
    description = (
        "Select classes that will be used for training. "
        "Supported shapes are Bitmap, Polygon, Rectangle."
    )
    lock_message = "Select training and validation splits to unlock"

    def __init__(self, project_id: int, classes: list, app_options: dict = None):
        """
        :param project_id: ID of the project whose classes are listed.
        :param classes: Class names to pre-select (from app options). If empty
            or ``None``, all classes are selected.
        :param app_options: Optional app options. Only ``"collapsable"`` is read.
        """
        if app_options is None:
            # Avoid a shared mutable default argument.
            app_options = {}

        self.classes_table = ClassesTable(project_id=project_id)  # use dataset_ids
        if classes:
            self.classes_table.select_classes(classes)  # from app options
        else:
            self.classes_table.select_all()

        # Absolute URL in development / sly-net debugging, relative path otherwise.
        if is_development() or is_debug_with_sly_net():
            qa_stats_link = abs_url(f"projects/{project_id}/stats/datasets")
        else:
            qa_stats_link = f"/projects/{project_id}/stats/datasets"

        qa_stats_text = Text(
            text=f"<i class='zmdi zmdi-chart-donut' style='color: #7f858e'></i> <a href='{qa_stats_link}' target='_blank'> <b> QA & Stats </b></a>"
        )

        self.validator_text = Text("")
        self.validator_text.hide()
        self.button = Button("Select")
        container = Container(
            [
                qa_stats_text,
                self.classes_table,
                self.validator_text,
                self.button,
            ]
        )
        self.card = Card(
            title=self.title,
            description=self.description,
            content=container,
            lock_message=self.lock_message,
            collapsable=app_options.get("collapsable", False),
        )
        self.card.lock()

    @property
    def widgets_to_disable(self) -> list:
        """Widgets that should be disabled while this step is locked/confirmed."""
        return [self.classes_table]

    def get_selected_classes(self) -> list:
        """Return the class names currently selected in the table."""
        return self.classes_table.get_selected_classes()

    def set_classes(self, classes) -> None:
        """Select the given class names in the table."""
        self.classes_table.select_classes(classes)

    def select_all_classes(self) -> None:
        """Select every class in the table."""
        self.classes_table.select_all()

    def validate_step(self) -> bool:
        """Validate the current selection and show the result in ``validator_text``.

        Selected classes without annotations only produce a warning; they do not
        fail validation.

        :return: ``True`` if at least one class is selected. ``False`` if none
            is selected or the project has no classes at all.
        """
        self.validator_text.hide()

        if len(self.classes_table.project_meta.obj_classes) == 0:
            self.validator_text.set(text="Project has no classes", status="error")
            self.validator_text.show()
            return False

        selected_classes = self.classes_table.get_selected_classes()
        selected_set = set(selected_classes)

        # NOTE(review): assumes column 0 holds the class name and columns 2 and 3
        # hold annotation counts; confirm against ClassesTable's table layout.
        # dict.fromkeys de-duplicates while keeping table order, so the warning
        # message is deterministic (a set join would order names arbitrarily).
        empty_classes = list(
            dict.fromkeys(
                row[0]["data"]
                for row in self.classes_table._table_data
                if row[0]["data"] in selected_set
                and row[2]["data"] == 0
                and row[3]["data"] == 0
            )
        )

        n_classes = len(selected_classes)
        if n_classes == 0:
            self.validator_text.set(text="Please select at least one class", status="error")
        else:
            warning_text = ""
            status = "success"
            if empty_classes:
                names = ", ".join(empty_classes)
                if len(empty_classes) == 1:
                    warning_text = f". Selected class has no annotations: {names}"
                else:
                    warning_text = f". Selected classes have no annotations: {names}"
                status = "warning"

            class_text = "class" if n_classes == 1 else "classes"
            self.validator_text.set(
                text=f"Selected {n_classes} {class_text}{warning_text}", status=status
            )
        self.validator_text.show()
        return n_classes > 0
|