supervisely 6.73.242-py3-none-any.whl → 6.73.244-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- supervisely/__init__.py +1 -1
- supervisely/_utils.py +18 -0
- supervisely/app/widgets/__init__.py +1 -0
- supervisely/app/widgets/card/card.py +3 -0
- supervisely/app/widgets/classes_table/classes_table.py +15 -1
- supervisely/app/widgets/custom_models_selector/custom_models_selector.py +25 -7
- supervisely/app/widgets/custom_models_selector/template.html +1 -1
- supervisely/app/widgets/experiment_selector/__init__.py +0 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +500 -0
- supervisely/app/widgets/experiment_selector/style.css +27 -0
- supervisely/app/widgets/experiment_selector/template.html +82 -0
- supervisely/app/widgets/pretrained_models_selector/pretrained_models_selector.py +25 -3
- supervisely/app/widgets/random_splits_table/random_splits_table.py +41 -17
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +12 -5
- supervisely/app/widgets/train_val_splits/train_val_splits.py +99 -10
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/__init__.py +3 -1
- supervisely/nn/artifacts/artifacts.py +10 -0
- supervisely/nn/artifacts/detectron2.py +2 -0
- supervisely/nn/artifacts/hrda.py +3 -0
- supervisely/nn/artifacts/mmclassification.py +2 -0
- supervisely/nn/artifacts/mmdetection.py +6 -3
- supervisely/nn/artifacts/mmsegmentation.py +2 -0
- supervisely/nn/artifacts/ritm.py +3 -1
- supervisely/nn/artifacts/rtdetr.py +2 -0
- supervisely/nn/artifacts/unet.py +2 -0
- supervisely/nn/artifacts/yolov5.py +3 -0
- supervisely/nn/artifacts/yolov8.py +7 -1
- supervisely/nn/experiments.py +113 -0
- supervisely/nn/inference/gui/__init__.py +3 -1
- supervisely/nn/inference/gui/gui.py +31 -232
- supervisely/nn/inference/gui/serving_gui.py +223 -0
- supervisely/nn/inference/gui/serving_gui_template.py +240 -0
- supervisely/nn/inference/inference.py +225 -24
- supervisely/nn/training/__init__.py +0 -0
- supervisely/nn/training/gui/__init__.py +1 -0
- supervisely/nn/training/gui/classes_selector.py +100 -0
- supervisely/nn/training/gui/gui.py +539 -0
- supervisely/nn/training/gui/hyperparameters_selector.py +117 -0
- supervisely/nn/training/gui/input_selector.py +70 -0
- supervisely/nn/training/gui/model_selector.py +95 -0
- supervisely/nn/training/gui/train_val_splits_selector.py +200 -0
- supervisely/nn/training/gui/training_logs.py +93 -0
- supervisely/nn/training/gui/training_process.py +114 -0
- supervisely/nn/training/gui/utils.py +128 -0
- supervisely/nn/training/loggers/__init__.py +0 -0
- supervisely/nn/training/loggers/base_train_logger.py +58 -0
- supervisely/nn/training/loggers/tensorboard_logger.py +46 -0
- supervisely/nn/training/train_app.py +2038 -0
- supervisely/nn/utils.py +5 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/METADATA +3 -1
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/RECORD +57 -35
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/LICENSE +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/WHEEL +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.242.dist-info → supervisely-6.73.244.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/gui/serving_gui.py (new file)

@@ -0,0 +1,223 @@
from functools import wraps
from typing import Callable, Dict, List, Optional, Union

import yaml

import supervisely.app.widgets as Widgets
from supervisely.sly_logger import logger
from supervisely.task.progress import Progress


class ServingGUI:
    def __init__(self) -> None:
        self._device_select = Widgets.SelectCudaDevice(include_cpu_option=True)
        self._device_field = Widgets.Field(self._device_select, title="Device")
        self._serve_button = Widgets.Button("SERVE")
        self._success_label = Widgets.DoneLabel()
        self._success_label.hide()
        self._download_progress = Widgets.Progress("Downloading model...", hide_on_finish=True)
        self._download_progress.hide()
        self._change_model_button = Widgets.Button(
            "STOP AND CHOOSE ANOTHER MODEL", button_type="danger"
        )
        self._change_model_button.hide()

        self.serve_container = Widgets.Container(
            [
                self._device_field,
                self._download_progress,
                self._success_label,
                self._serve_button,
                self._change_model_button,
            ],
        )
        self.serve_model_card = Widgets.Card(
            title="Serve Model",
            description="Download and deploy the model on the selected device.",
            content=self.serve_container,
        )

        self._model_inference_settings_widget = Widgets.Editor(
            readonly=True, restore_default_button=False
        )
        self._model_inference_settings_container = Widgets.Field(
            self._model_inference_settings_widget,
            title="Inference settings",
            description="Model allows user to configure the following parameters on prediction phase",
        )

        self._model_info_widget = Widgets.ModelInfo()
        self._model_info_widget_container = Widgets.Field(
            self._model_info_widget,
            title="Session Info",
            description="Basic information about the deployed model",
        )

        self._model_classes_widget = Widgets.ClassesTable(selectable=False)
        self._model_classes_plug = Widgets.Text("No classes provided")
        self._model_classes_widget_container = Widgets.Field(
            content=Widgets.Container([self._model_classes_widget, self._model_classes_plug]),
            title="Model classes",
            description="List of classes model predicts",
        )

        self._model_full_info = Widgets.Container(
            [
                self._model_info_widget_container,
                self._model_inference_settings_container,
                self._model_classes_widget_container,
            ]
        )
        self._model_full_info.hide()
        self._before_deploy_msg = Widgets.Text("Deploy model to see the information.")

        self._model_full_info_card = Widgets.Card(
            title="Full model info",
            description="Inference settings, session parameters and model classes",
            collapsable=True,
            content=Widgets.Container(
                [
                    self._model_full_info,
                    self._before_deploy_msg,
                ]
            ),
        )

        self._model_full_info_card.collapse()
        self._additional_ui_content = []
        self.get_ui = self.__add_content_and_model_info_to_default_ui(
            self._model_full_info_card
        )  # pylint: disable=method-hidden

        self.on_change_model_callbacks: List[Callable] = [ServingGUI._hide_info_after_change]
        self.on_serve_callbacks: List[Callable] = []

        @self.serve_button.click
        def serve_model():
            self.deploy_with_current_params()

        @self._change_model_button.click
        def change_model():
            for cb in self.on_change_model_callbacks:
                cb(self)
            self.change_model()

    @property
    def serve_button(self) -> Widgets.Button:
        return self._serve_button

    @property
    def download_progress(self) -> Widgets.Progress:
        return self._download_progress

    @property
    def device(self) -> str:
        return self._device_select.get_device()

    def get_device(self) -> str:
        return self._device_select.get_device()

    def deploy_with_current_params(self):
        for cb in self.on_serve_callbacks:
            cb(self)
        self.set_deployed()

    def change_model(self):
        self._success_label.text = ""
        self._success_label.hide()
        self._serve_button.show()
        self._device_select._select.enable()
        self._device_select.enable()
        self._change_model_button.hide()
        Progress("model deployment canceled", 1).iter_done_report()

    def _hide_info_after_change(self):
        self._model_full_info_card.collapse()
        self._model_full_info.hide()
        self._before_deploy_msg.show()

    def set_deployed(self, device: str = None):
        if device is not None:
            self._device_select.set_device(device)
        self._success_label.text = f"Model has been successfully loaded on {self._device_select.get_device().upper()} device"
        self._success_label.show()
        self._serve_button.hide()
        self._device_select._select.disable()
        self._device_select.disable()
        self._change_model_button.show()
        Progress("Model deployed", 1).iter_done_report()

    def show_deployed_model_info(self, inference):
        self.set_inference_settings(inference)
        self.set_project_meta(inference)
        self.set_model_info(inference)
        self._before_deploy_msg.hide()
        self._model_full_info.show()
        self._model_full_info_card.uncollapse()

    def set_inference_settings(self, inference):
        if len(inference.custom_inference_settings_dict.keys()) == 0:
            inference_settings_str = "# inference settings dict is empty"
        else:
            inference_settings_str = yaml.dump(inference.custom_inference_settings_dict)
        self._model_inference_settings_widget.set_text(inference_settings_str, "yaml")
        self._model_inference_settings_widget.show()

    def set_project_meta(self, inference):
        if self._get_classes_from_inference(inference) is None:
            logger.warn("Skip loading project meta.")
            self._model_classes_widget.hide()
            self._model_classes_plug.show()
            return

        self._model_classes_widget.set_project_meta(inference.model_meta)
        self._model_classes_plug.hide()
        self._model_classes_widget.show()

    def set_model_info(self, inference):
        info = inference.get_human_readable_info(replace_none_with="Not provided")
        self._model_info_widget.set_model_info(inference.task_id, info)

    def _get_classes_from_inference(self, inference) -> Optional[List[str]]:
        classes = None
        try:
            classes = inference.get_classes()
        except NotImplementedError:
            logger.warn(f"get_classes() function not implemented for {type(inference)} object.")
        except AttributeError:
            logger.warn("Probably, get_classes() function not working without model deploy.")
        except Exception as exc:
            logger.warn("Skip getting classes info due to exception")
            logger.exception(exc)

        if classes is None or len(classes) == 0:
            logger.warn(f"get_classes() function return {classes}; skip classes processing.")
            return None
        return classes

    def get_ui(self) -> Widgets.Widget:  # pylint: disable=method-hidden
        return Widgets.Container([self.serve_model_card])

    def add_content_to_default_ui(
        self, widgets: Union[Widgets.Widget, List[Widgets.Widget]]
    ) -> None:
        if isinstance(widgets, List):
            self._additional_ui_content.extend(widgets)
        else:
            self._additional_ui_content.append(widgets)

    def __add_content_and_model_info_to_default_ui(
        self,
        model_info_widget: Widgets.Widget,
    ) -> Callable:
        def decorator(get_ui):
            @wraps(get_ui)
            def wrapper(*args, **kwargs):
                ui = get_ui(*args, **kwargs)
                content = [ui, *self._additional_ui_content, model_info_widget]
                ui_with_info = Widgets.Container(content)
                return ui_with_info

            return wrapper

        return decorator(self.get_ui)
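The new ServingGUI class wires the SERVE and "STOP AND CHOOSE ANOTHER MODEL" buttons to two public callback lists (on_serve_callbacks, on_change_model_callbacks) rather than hard-coding any deployment logic. Below is a minimal usage sketch, assuming it runs inside a Supervisely app where widgets can be rendered; the load_model and unload_model functions are hypothetical placeholders and are not part of this release.

```python
# Hypothetical callback wiring for the new ServingGUI class (illustration only).
from supervisely.nn.inference.gui.serving_gui import ServingGUI

gui = ServingGUI()

def load_model(g: ServingGUI) -> None:
    # Called by deploy_with_current_params() when the user presses SERVE,
    # just before g.set_deployed() updates the widgets.
    device = g.get_device()  # e.g. "cuda:0" or "cpu"
    print(f"deploying model on {device}")

def unload_model(g: ServingGUI) -> None:
    # Called by the change-model button handler before g.change_model()
    # restores the pre-deployment widget state.
    print("model unloaded")

gui.on_serve_callbacks.append(load_model)
gui.on_change_model_callbacks.append(unload_model)

layout = gui.get_ui()  # Container with the serve card and the collapsible "Full model info" card
```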
supervisely/nn/inference/gui/serving_gui_template.py (new file)

@@ -0,0 +1,240 @@
from os.path import join
from typing import Any, Dict, List, Optional, Union

import yaml

import supervisely.io.env as sly_env
import supervisely.io.fs as sly_fs
import supervisely.io.json as sly_json
from supervisely import Api
from supervisely.app.widgets import (
    Card,
    Container,
    Field,
    RadioTabs,
    SelectString,
    Widget,
)
from supervisely.app.widgets.experiment_selector.experiment_selector import (
    ExperimentSelector,
)
from supervisely.app.widgets.pretrained_models_selector.pretrained_models_selector import (
    PretrainedModelsSelector,
)
from supervisely.nn.experiments import get_experiment_infos
from supervisely.nn.inference.gui.serving_gui import ServingGUI
from supervisely.nn.utils import ModelSource, RuntimeType


class ServingGUITemplate(ServingGUI):
    def __init__(
        self,
        framework_name: str,
        models: Optional[str] = None,
        app_options: Optional[str] = None,
    ):
        if not isinstance(framework_name, str):
            raise ValueError("'framework_name' must be a string")
        super().__init__()

        self.api = Api.from_env()
        self.team_id = sly_env.team_id()

        self.framework_name = framework_name
        self.models = self._load_models(models) if models else []
        self.app_options = self._load_app_options(app_options) if app_options else {}

        base_widgets = self._initialize_layout()
        extra_widgets = self._initialize_extra_widgets()

        self.widgets = base_widgets + extra_widgets
        self.card = self._get_card()

    def _get_card(self) -> Card:
        return Card(
            title="Select Model",
            description="Select the model to deploy and press the 'Serve' button.",
            content=Container(widgets=self.widgets, gap=10),
            overflow="unset",
        )

    def _initialize_layout(self) -> List[Widget]:
        # Pretrained models
        use_pretrained_models = self.app_options.get("pretrained_models", True)
        use_custom_models = self.app_options.get("custom_models", True)

        if not use_pretrained_models and not use_custom_models:
            raise ValueError(
                "At least one of 'pretrained_models' or 'custom_models' must be enabled."
            )

        if use_pretrained_models and self.models is not None:
            self.pretrained_models_table = PretrainedModelsSelector(self.models)
        else:
            self.pretrained_models_table = None

        # Custom models
        if use_custom_models:
            experiments = get_experiment_infos(self.api, self.team_id, self.framework_name)
            self.experiment_selector = ExperimentSelector(self.team_id, experiments)
        else:
            self.experiment_selector = None

        # Tabs
        tabs = []
        if self.pretrained_models_table is not None:
            tabs.append(
                (
                    ModelSource.PRETRAINED,
                    "Publicly available models",
                    self.pretrained_models_table,
                )
            )
        if self.experiment_selector is not None:
            tabs.append(
                (
                    ModelSource.CUSTOM,
                    "Models trained in Supervisely",
                    self.experiment_selector,
                )
            )
        if tabs:
            titles, descriptions, content = zip(*tabs)
            self.model_source_tabs = RadioTabs(
                titles=titles,
                descriptions=descriptions,
                contents=content,
            )
        else:
            self.model_source_tabs = None

        # Runtime
        default_runtime = RuntimeType.PYTORCH
        available_runtimes = {
            value.lower(): value
            for name, value in vars(RuntimeType).items()
            if not name.startswith("__")  # exclude private attributes
        }
        supported_runtimes_input = self.app_options.get("supported_runtimes", [default_runtime])
        supported_runtimes = [
            available_runtimes[runtime.lower()]
            for runtime in supported_runtimes_input
            if runtime.lower() in available_runtimes
        ]

        if len(supported_runtimes) > 1:
            self.runtime_select = SelectString(supported_runtimes)
            runtime_field = Field(self.runtime_select, "Runtime", "Select a runtime for inference.")
        else:
            self.runtime_select = None
            runtime_field = None

        # Layout
        card_widgets = [self.model_source_tabs]
        if runtime_field is not None:
            card_widgets.append(runtime_field)
        return card_widgets

    def _initialize_extra_widgets(self) -> List[Widget]:
        return []

    @property
    def model_source(self) -> str:
        return self.model_source_tabs.get_active_tab()

    @property
    def model_info(self) -> Dict[str, Any]:
        return self._get_selected_row()

    @property
    def model_name(self) -> Optional[str]:
        if self.model_source == ModelSource.PRETRAINED:
            model_meta = self.model_info.get("meta", {})
            return model_meta.get("model_name")
        else:
            return self.model_info.get("model_name")

    @property
    def model_files(self) -> List[str]:
        if self.model_source == ModelSource.PRETRAINED:
            model_meta = self.model_info.get("meta", {})
            return model_meta.get("model_files", {})
        else:
            return self.experiment_selector.get_model_files()

    @property
    def runtime(self) -> str:
        if self.runtime_select is not None:
            return self.runtime_select.get_value()
        return RuntimeType.PYTORCH

    def get_params_from_gui(self) -> Dict[str, Any]:
        return {
            "model_source": self.model_source,
            "model_files": self.model_files,
            "model_info": self.model_info,
            "device": self.device,
            "runtime": self.runtime,
        }

    # Loaders
    def _load_models(self, models: str) -> List[Dict[str, Any]]:
        """
        Loads models from the provided file or list of model configurations.
        """
        if isinstance(models, str):
            if sly_fs.file_exists(models) and sly_fs.get_file_ext(models) == ".json":
                models = sly_json.load_json_file(models)
            else:
                raise ValueError("File not found or invalid file format.")
        else:
            raise ValueError(
                "Invalid models file. Please provide a valid '.json' file with list of model configurations."
            )

        if not isinstance(models, list):
            raise ValueError("models parameters must be a list of dicts")
        for item in models:
            if not isinstance(item, dict):
                raise ValueError(f"Each item in models must be a dict.")
            model_meta = item.get("meta")
            if model_meta is None:
                raise ValueError(
                    "Model metadata not found. Please update provided models parameter to include key 'meta'."
                )
            model_files = model_meta.get("model_files")
            if model_files is None:
                raise ValueError(
                    "Model files not found in model metadata. "
                    "Please update provided models oarameter to include key 'model_files' in 'meta' key."
                )
        return models

    def _load_app_options(self, app_options: str = None) -> Dict[str, Any]:
        """
        Loads the app_options parameter to ensure it is in the correct format.
        """
        if app_options is None:
            return {}

        if isinstance(app_options, str):
            if sly_fs.file_exists(app_options) and sly_fs.get_file_ext(app_options) in [
                ".yaml",
                ".yml",
            ]:
                with open(app_options, "r") as file:
                    app_options = yaml.safe_load(file)
            else:
                raise ValueError(
                    "Invalid app_options file provided. Please provide a valid '.yaml' or '.yml' file with app_options."
                )
        if not isinstance(app_options, dict):
            raise ValueError("app_options must be a dict")
        return app_options

    def _get_selected_row(self) -> Dict[str, Any]:
        if self.model_source == ModelSource.PRETRAINED and self.pretrained_models_table:
            return self.pretrained_models_table.get_selected_row()
        elif self.model_source == ModelSource.CUSTOM and self.experiment_selector:
            return self.experiment_selector.get_selected_experiment_info()
        return {}
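ServingGUITemplate builds the "Select Model" card from a framework name, an optional models catalog (a '.json' file holding a list of dicts, each with a "meta" key that in turn contains "model_files"), and optional app options (a '.yaml'/'.yml' file whose "pretrained_models", "custom_models" and "supported_runtimes" keys are read in _initialize_layout), and exposes the user's choice through get_params_from_gui(). A hedged instantiation sketch follows, assuming the usual Supervisely environment variables for Api.from_env() and team_id() are set; the framework name and file paths are made-up placeholders, not values shipped in this release.

```python
# Hypothetical usage of ServingGUITemplate (illustration only; names and paths are placeholders).
from supervisely.nn.inference.gui.serving_gui_template import ServingGUITemplate

# Example app_options.yaml content (keys as read by _initialize_layout):
#   pretrained_models: true
#   custom_models: true
#   supported_runtimes: ["PyTorch"]  # values must match RuntimeType entries

gui = ServingGUITemplate(
    framework_name="MyFramework",    # placeholder framework name
    models="models.json",            # placeholder path to the models catalog
    app_options="app_options.yaml",  # placeholder path to the options file
)

# After the user picks a model and presses SERVE, read the selection back:
params = gui.get_params_from_gui()
# -> {"model_source": ..., "model_files": ..., "model_info": ..., "device": ..., "runtime": ...}
```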