supervisely 6.73.420__py3-none-any.whl → 6.73.421__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/api.py +10 -5
- supervisely/api/app_api.py +71 -4
- supervisely/api/module_api.py +4 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/project_api.py +35 -6
- supervisely/api/task_api.py +5 -1
- supervisely/app/widgets/__init__.py +8 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/deploy_model/__init__.py +0 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
- supervisely/app/widgets/fast_table/fast_table.py +402 -74
- supervisely/app/widgets/fast_table/script.js +364 -96
- supervisely/app/widgets/fast_table/style.css +24 -0
- supervisely/app/widgets/fast_table/template.html +43 -3
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +160 -94
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
- supervisely/nn/inference/predict_app/gui/gui.py +710 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +282 -0
- supervisely/nn/inference/predict_app/predict_app.py +184 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/prediction.py +2 -0
- supervisely/nn/model/prediction_session.py +20 -3
- supervisely/nn/training/gui/gui.py +131 -44
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
- supervisely/nn/training/gui/training_artifacts.py +0 -5
- supervisely/nn/training/train_app.py +161 -44
- supervisely/template/experiment/experiment.html.jinja +74 -17
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/RECORD +74 -56
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,729 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import tempfile
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Dict, List, Literal
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
from supervisely._utils import logger
|
|
10
|
+
from supervisely.api.api import Api
|
|
11
|
+
from supervisely.api.app_api import ModuleInfo
|
|
12
|
+
from supervisely.app.widgets.agent_selector.agent_selector import AgentSelector
|
|
13
|
+
from supervisely.app.widgets.button.button import Button
|
|
14
|
+
from supervisely.app.widgets.container.container import Container
|
|
15
|
+
from supervisely.app.widgets.card.card import Card
|
|
16
|
+
from supervisely.app.widgets.model_info.model_info import ModelInfo
|
|
17
|
+
from supervisely.app.widgets.ecosystem_model_selector.ecosystem_model_selector import (
|
|
18
|
+
EcosystemModelSelector,
|
|
19
|
+
)
|
|
20
|
+
from supervisely.app.widgets.experiment_selector.experiment_selector import (
|
|
21
|
+
ExperimentSelector,
|
|
22
|
+
)
|
|
23
|
+
from supervisely.app.widgets.fast_table.fast_table import FastTable
|
|
24
|
+
from supervisely.app.widgets.field.field import Field
|
|
25
|
+
from supervisely.app.widgets.flexbox.flexbox import Flexbox
|
|
26
|
+
from supervisely.app.widgets.tabs.tabs import Tabs
|
|
27
|
+
from supervisely.app.widgets.text.text import Text
|
|
28
|
+
from supervisely.app.widgets.widget import Widget
|
|
29
|
+
from supervisely.io import env
|
|
30
|
+
from supervisely.nn.experiments import ExperimentInfo, get_experiment_infos
|
|
31
|
+
from supervisely.nn.model.model_api import ModelAPI
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class DeployModel(Widget):
    """Widget to deploy a model or connect to an already running serving session.

    Up to three modes are supported; each is rendered as a tab, or as a single panel
    when only one mode is enabled:

    * ``"connect"`` - select a running serving session of the team and connect to it;
    * ``"pretrained"`` - deploy a pretrained model chosen in :class:`EcosystemModelSelector`;
    * ``"custom"`` - deploy a checkpoint of an experiment chosen in :class:`ExperimentSelector`.

    After a successful deploy/connect the resulting :class:`ModelAPI` is stored in
    ``self.model_api``.
    """

    class DeployMode:
        """Interface implemented by every deploy mode (tab) of :class:`DeployModel`."""

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Deploy (or connect to) a model and return an API handle to it."""
            raise NotImplementedError("This method should be implemented in subclasses.")

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return JSON-serializable parameters describing the current selection."""
            raise NotImplementedError("This method should be implemented in subclasses.")

        def load_from_json(self, data: Dict[str, Any]) -> None:
            """Restore the selection from previously saved deploy parameters."""
            raise NotImplementedError("This method should be implemented in subclasses.")

        @property
        def layout(self) -> Widget:
            """Widget displayed for this mode."""
            raise NotImplementedError("This property should be implemented in subclasses.")

    class Connect(DeployMode):
        """Mode that connects to a model already deployed in the team."""

        class COLUMN:
            SESSION_ID = "Session ID"
            APP_NAME = "App Name"
            FRAMEWORK = "Framework"
            MODEL = "Model"

        COLUMNS = [
            str(COLUMN.SESSION_ID),
            str(COLUMN.APP_NAME),
            str(COLUMN.FRAMEWORK),
            str(COLUMN.MODEL),
        ]

        def __init__(self, deploy_model: "DeployModel"):
            self.api = deploy_model.api
            self.team_id = deploy_model.team_id
            self._cache = deploy_model._cache
            self.deploy_model = deploy_model
            self._layout = self._create_layout()
            self._update_sessions()

        def _create_layout(self) -> FastTable:
            """Build the radio sessions table with a refresh button in its header."""
            self.refresh_button = Button(
                "",
                icon="zmdi zmdi-refresh",
                button_type="text",
            )
            self.sessions_table = FastTable(
                columns=self.COLUMNS,
                page_size=10,
                is_radio=True,
                header_left_content=self.refresh_button,
            )

            @self.refresh_button.click
            def _refresh_button_clicked():
                self._update_sessions()

            return self.sessions_table

        @property
        def layout(self) -> FastTable:
            return self._layout

        def _data_from_session(self, session: Dict) -> Dict[str, Any]:
            """Convert one record of ``api.nn.list_deployed_models`` into a table row."""
            task_info = session["task_info"]
            deploy_info = session["model_info"]
            return {
                self.COLUMN.SESSION_ID: task_info["id"],
                self.COLUMN.APP_NAME: task_info["meta"]["app"]["name"],
                self.COLUMN.FRAMEWORK: self.deploy_model._framework_from_task_info(task_info),
                self.COLUMN.MODEL: deploy_info["model_name"],
            }

        def _update_sessions(self) -> None:
            """Reload the team's deployed models into the sessions table.

            The connect button is enabled only when at least one session was loaded.
            Failures are logged and disable the connect button.
            """
            self.sessions_table.loading = True
            try:
                self.sessions_table.clear()
                sessions = self.api.nn.list_deployed_models(team_id=self.team_id)
                data = [self._data_from_session(session) for session in sessions]
                df = pd.DataFrame.from_records(data=data, columns=self.COLUMNS)
                self.sessions_table.read_pandas(df)
                if data:
                    self.deploy_model.connect_button.enable()
                else:
                    self.deploy_model.connect_button.disable()
            except Exception as e:
                logger.error(
                    f"Failed to load deployed models: {e}",
                    exc_info=True,
                )
                # The table state is unreliable after a failure; don't offer to connect.
                self.deploy_model.connect_button.disable()
            finally:
                self.sessions_table.loading = False

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Connect to the selected session. ``agent_id`` is not used by this mode."""
            deploy_parameters = self.get_deploy_parameters()
            logger.info("Connecting to model with parameters:", extra=deploy_parameters)
            session_id = deploy_parameters["session_id"]
            return self.api.nn.connect(task_id=session_id)

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return ``{"session_id": ...}`` of the selected table row.

            :raises ValueError: if no row is selected.
            """
            selected_row = self.sessions_table.get_selected_row()
            if selected_row is None:
                raise ValueError("No session selected.")
            session_idx = self.COLUMNS.index(str(self.COLUMN.SESSION_ID))
            return {"session_id": selected_row.row[session_idx]}

        def load_from_json(self, data: Dict):
            """Reload sessions and select the row whose Session ID equals ``data["session_id"]``."""
            session_id = data["session_id"]
            self._update_sessions()
            self.sessions_table.select_row_by_value(str(self.COLUMN.SESSION_ID), session_id)

    class Pretrained(DeployMode):
        """Mode that deploys a pretrained model from the ecosystem."""

        class COLUMN:
            # TODO: columns are the same as in EcosystemModelSelector, make a common base class
            FRAMEWORK = "Framework"
            MODEL_NAME = "Model"
            TASK_TYPE = "Task Type"
            PARAMETERS = "Parameters (M)"
            # TODO: support metrics for different tasks
            MAP = "mAP"

        COLUMNS = [
            str(COLUMN.FRAMEWORK),
            str(COLUMN.MODEL_NAME),
            str(COLUMN.TASK_TYPE),
            str(COLUMN.PARAMETERS),
            str(COLUMN.MAP),
        ]

        def __init__(self, deploy_model: "DeployModel"):
            self.api = deploy_model.api
            self.team_id = deploy_model.team_id
            self._cache = deploy_model._cache
            self.deploy_model = deploy_model
            self._model_api = None
            self._last_selected_framework = None
            self._layout = self._create_layout()

        @property
        def layout(self) -> EcosystemModelSelector:
            return self._layout

        def _create_layout(self) -> EcosystemModelSelector:
            self.model_selector = EcosystemModelSelector(api=self.api)
            return self.model_selector

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return the selected model's ``framework`` and ``model_name``.

            :raises ValueError: if no model is selected.
            """
            selected_model = self.model_selector.get_selected()
            if selected_model is None:
                raise ValueError("No pretrained model selected.")
            return {
                "framework": selected_model["framework"],
                "model_name": selected_model["name"],
            }

        def load_from_json(self, data: Dict[str, Any]) -> None:
            framework = data["framework"]
            model_name = data["model_name"]
            self.model_selector.select_framework_and_model_name(framework, model_name)

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Deploy ``"<framework>/<model_name>"`` via ``api.nn.deploy`` on ``agent_id``."""
            deploy_parameters = self.get_deploy_parameters()
            logger.info("Deploying pretrained model with parameters:", extra=deploy_parameters)
            framework = deploy_parameters["framework"]
            model_name = deploy_parameters["model_name"]
            return self.api.nn.deploy(model=f"{framework}/{model_name}", agent_id=agent_id)

    class Custom(DeployMode):
        """Mode that deploys a checkpoint from one of the team's training experiments."""

        def __init__(self, deploy_model: "DeployModel"):
            self.api = deploy_model.api
            self.team_id = deploy_model.team_id
            self._cache = deploy_model._cache
            self.deploy_model = deploy_model
            self._model_api = None
            self._layout = self._create_layout()

        @property
        def layout(self) -> ExperimentSelector:
            return self._layout

        def _create_layout(self) -> ExperimentSelector:
            """Collect experiments of every known framework into an ExperimentSelector."""
            experiment_infos = []
            for framework_name in self.deploy_model.get_frameworks():
                experiment_infos.extend(
                    get_experiment_infos(self.api, self.team_id, framework_name=framework_name)
                )
            self.experiment_table = ExperimentSelector(
                experiment_infos=experiment_infos,
                team_id=self.team_id,
                api=self.api,
            )

            @self.experiment_table.checkpoint_changed
            def _checkpoint_changed(row: ExperimentSelector.ModelRow, checkpoint_value: str):
                logger.debug(
                    f"Checkpoint changed for {row._experiment_info.task_id}: {checkpoint_value}"
                )

            return self.experiment_table

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return the selected experiment as JSON (``None`` when nothing is selected)."""
            experiment_info = self.experiment_table.get_selected_experiment_info()
            return {
                "experiment_info": (experiment_info.to_json() if experiment_info else None),
            }

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Deploy the selected experiment's checkpoint on ``agent_id``.

            :raises ValueError: if no experiment is selected.
            """
            deploy_parameters = self.get_deploy_parameters()
            logger.info("Deploying custom model with parameters:", extra=deploy_parameters)
            experiment_info_json = deploy_parameters["experiment_info"]
            if experiment_info_json is None:
                raise ValueError("No experiment selected.")
            experiment_info = ExperimentInfo(**experiment_info_json)  # pylint: disable=not-a-mapping
            task_info = self.api.nn._deploy_api.deploy_custom_model_from_experiment_info(
                agent_id=agent_id,
                experiment_info=experiment_info,
                log_level="debug",
            )
            return ModelAPI(api=self.api, task_id=task_info["id"])

        def load_from_json(self, data: Dict):
            """Select an experiment by ``experiment_info`` JSON or by ``train_task_id``.

            :raises ValueError: if neither key is present.
            """
            if "experiment_info" in data:
                experiment_info_json = data["experiment_info"]
                experiment_info = ExperimentInfo(**experiment_info_json)  # pylint: disable=not-a-mapping
                self.experiment_table.set_selected_row_by_experiment_info(experiment_info)
            elif "train_task_id" in data:
                self.experiment_table.set_selected_row_by_task_id(data["train_task_id"])
            else:
                raise ValueError("Invalid data format for loading custom model.")

    class MODE:
        CONNECT = "connect"
        PRETRAINED = "pretrained"
        CUSTOM = "custom"

    MODES = [str(MODE.CONNECT), str(MODE.PRETRAINED), str(MODE.CUSTOM)]
    MODE_TO_CLASS = {
        str(MODE.CONNECT): Connect,
        str(MODE.PRETRAINED): Pretrained,
        str(MODE.CUSTOM): Custom,
    }

    def __init__(
        self,
        api: Api = None,
        team_id: int = None,
        modes: List[Literal["connect", "pretrained", "custom"]] = None,
        widget_id: str = None,
    ):
        """
        :param api: Supervisely API; a new ``Api()`` is created when omitted.
        :param team_id: team to work in; read from the environment when omitted.
        :param modes: enabled modes (subset of ``MODES``); all modes when omitted.
        :param widget_id: optional widget id.
        :raises ValueError: if ``modes`` is empty, too long or has unknown entries.
        """
        self.modes: Dict[str, DeployModel.DeployMode] = {}
        if modes is None:
            modes = self.MODES.copy()
        self._validate_modes(modes)
        if api is None:
            api = Api()
        self.api = api
        if team_id is None:
            team_id = env.team_id()
        self.team_id = team_id
        # Shared between the modes: "modules", "frameworks", "inference_settings".
        self._cache = {}

        self.modes_labels = {
            self.MODE.CONNECT: "Connect",
            self.MODE.PRETRAINED: "Pretrained",
            self.MODE.CUSTOM: "Custom",
        }

        # GUI
        self.layout: Widget = None
        self._init_gui(modes)

        self.model_api: ModelAPI = None

        super().__init__(widget_id=widget_id)

    def _validate_modes(self, modes) -> None:
        if len(modes) < 1 or len(modes) > len(self.MODES):
            raise ValueError(
                f"Modes must be a list containing 1 to {len(self.MODES)} of the following: {', '.join(self.MODES)}."
            )
        for mode in modes:
            if mode not in self.MODES:
                raise ValueError(f"Invalid mode '{mode}'. Valid modes are {', '.join(self.MODES)}.")

    def get_modules(self) -> List[ModuleInfo]:
        """Return image serving apps that declare a ``framework:*`` category.

        The list is fetched once and cached in ``self._cache["modules"]``.
        """
        cached = self._cache.setdefault("modules", [])
        if cached:
            return cached
        raw_modules = self.api.app.get_list_ecosystem_modules(
            categories=["serve", "images"], categories_operation="and"
        )
        modules = [
            ModuleInfo.from_json(module)
            for module in raw_modules
            if any(
                cat.startswith("framework:")
                for cat in module["config"].get("categories", [])
            )
        ]
        self._cache["modules"] = modules
        return modules

    def get_frameworks(self) -> List[str]:
        """Return unique framework names (``framework:<name>`` categories) of serving apps.

        The result is cached in ``self._cache["frameworks"]``.
        """
        cached = self._cache.get("frameworks", [])
        if cached:
            return cached
        prefix = "framework:"
        frameworks = [
            cat[len(prefix) :]
            for module in self.get_modules()
            for cat in module.config.get("categories", [])
            if cat.startswith(prefix)
        ]
        # Several apps may serve the same framework; keep each name once (first-seen order)
        # so consumers such as Custom mode do not fetch and list the same experiments twice.
        frameworks = list(dict.fromkeys(frameworks))
        self._cache["frameworks"] = frameworks
        return frameworks

    def _init_modes(self, modes: List[str]) -> None:
        for mode in modes:
            self.modes[mode] = self.MODE_TO_CLASS[mode](self)

    def _create_task_link(self, task_id: int) -> str:
        return f"{self.api.server_address}/apps/sessions/{task_id}"

    def _download_inference_settings(self, module_id: int, file_path: str) -> str:
        """Download ``file_path`` from the module's git repository and return its text.

        The file is saved into a temporary directory that is removed afterwards.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            save_path = str(Path(tmp_dir) / "inference_settings.yaml")
            self.api.app.download_git_file(
                module_id=module_id,
                file_path=file_path,
                save_path=save_path,
            )
            return Path(save_path).read_text(encoding="utf-8")

    def _get_inference_settings_by_module(self, module: Dict) -> str:
        """Return the inference settings file content declared in ``module["config"]``.

        :raises ValueError: if the module config has no ``files.inference_settings``.
        """
        config = module["config"]
        inference_settings_path = config.get("files", {}).get("inference_settings", None)
        if inference_settings_path is None:
            raise ValueError(
                f"No inference settings file found for framework app {module['meta']['app']['name']}."
            )
        return self._download_inference_settings(module["id"], inference_settings_path)

    def _get_inference_settings_for_framework(self, framework: str) -> str:
        """Return (cached) inference settings of the serving app for ``framework``.

        :raises ValueError: if no serving app or no inference settings file is found.
        """
        inference_settings_cache = self._cache.setdefault("inference_settings", {})
        if framework not in inference_settings_cache:
            module = self.api.nn._deploy_api.find_serving_app_by_framework(framework)
            if module is None:
                raise ValueError(f"No serving app found for framework {framework}.")
            config = module["config"]
            inference_settings_path = config.get("files", {}).get("inference_settings", None)
            if inference_settings_path is None:
                raise ValueError(f"No inference settings file found for framework {framework}.")
            inference_settings_cache[framework] = self._download_inference_settings(
                module["id"], inference_settings_path
            )
        return inference_settings_cache[framework]

    def _framework_from_task_info(self, task_info: Dict) -> str:
        """Return the framework name of the app that runs the task, or ``"unknown"``."""
        module_id = task_info["meta"]["app"]["moduleId"]
        module = next((m for m in self.get_modules() if m.id == module_id), None)
        if module is None:
            module = self.api.app.get_ecosystem_module_info(module_id=module_id)
            self._cache.setdefault("modules", []).append(module)
        prefix = "framework:"
        for cat in module.config.get("categories", []):
            if cat.startswith(prefix):
                return cat[len(prefix) :]
        return "unknown"

    def _init_gui(self, modes: List[str]) -> None:
        """Create shared widgets and buttons, the mode layouts and the button handlers."""
        self.status = Text("Deploying model...", status="info")
        self.session_text_1 = Text(
            "",
            "text",
        )
        self.session_text_2 = Text(
            "",
            "text",
            font_size=13,
        )
        self.sesson_link = Container(
            [
                self.session_text_1,
                self.session_text_2,
            ],
            gap=0,
            style="padding-left: 10px",
        )
        self.status.hide()
        self.sesson_link.hide()

        self.select_agent = AgentSelector(self.team_id)
        self.select_agent_field = Field(content=self.select_agent, title="Select Agent")

        self._create_model_info_widget()

        # Buttons must exist before the modes are created: Connect mode toggles
        # connect_button while loading sessions in its constructor.
        self.deploy_button = Button("Deploy", icon="zmdi zmdi-play")
        self.connect_button = Button("Connect", icon="zmdi zmdi-link")
        self.stop_button = Button("Stop", icon="zmdi zmdi-stop", button_type="danger")
        self.stop_button.hide()
        self.disconnect_button = Button("Disconnect", icon="zmdi zmdi-close", button_type="warning")
        self.disconnect_button.hide()
        self.deploy_stop_buttons = Flexbox(
            widgets=[self.deploy_button, self.stop_button, self.disconnect_button],
            gap=10,
        )
        self.connect_stop_buttons = Flexbox(
            widgets=[self.connect_button, self.stop_button, self.disconnect_button],
            gap=10,
        )

        self._init_modes(modes)
        _labels = []
        _contents = []
        for mode_name, mode in self.modes.items():
            if mode_name == str(self.MODE.CONNECT):
                # Connecting to a running session needs no agent selection.
                widgets = [
                    mode.layout,
                    self._model_info_card,
                    self.connect_stop_buttons,
                    self.status,
                    self.sesson_link,
                ]
            else:
                widgets = [
                    mode.layout,
                    self._model_info_card,
                    self.select_agent_field,
                    self.deploy_stop_buttons,
                    self.status,
                    self.sesson_link,
                ]
            _labels.append(self.modes_labels[mode_name])
            _contents.append(Container(widgets=widgets, gap=20))

        self.tabs = Tabs(labels=_labels, contents=_contents)
        self.layout = _contents[0] if len(self.modes) == 1 else self.tabs

        @self.deploy_button.click
        def _deploy_button_clicked():
            self._deploy()

        @self.stop_button.click
        def _stop_button_clicked():
            self.stop()

        @self.connect_button.click
        def _connect_button_clicked():
            self._connect()

        @self.disconnect_button.click
        def _disconnect_button_clicked():
            self.disconnect()

        @self.tabs.click
        def _active_tab_changed(tab_name: str):
            self.set_model_message_by_tab(tab_name)

    def set_model_status(
        self,
        status: Literal["deployed", "stopped", "deploying", "connecting", "connected", "error", "hide"],
        extra_text: str = None,
    ) -> None:
        """Show a status message (``"hide"`` hides it); ``extra_text`` is appended."""
        if status == "hide":
            self.status.hide()
            return
        status_args = {
            "deployed": {"text": "Model deployed successfully!", "status": "success"},
            "stopped": {"text": "Model stopped", "status": "info"},
            "deploying": {"text": "Deploying model...", "status": "info"},
            "connecting": {"text": "Connecting to model...", "status": "info"},
            "connected": {"text": "Model connected successfully!", "status": "success"},
            "error": {
                "text": "Error occurred during model deployment.",
                "status": "error",
            },
        }
        args = status_args[status]
        if extra_text:
            args["text"] += f" {extra_text}"
        self.status.set(**args)
        self.status.show()

    def set_session_info(self, task_info: Dict):
        """Show a link to the task's session page and its start time; ``None`` hides it."""
        if task_info is None:
            self.sesson_link.hide()
            return
        task_id = task_info["id"]
        task_link = self._create_task_link(task_id)
        task_name = task_info["meta"]["app"]["name"]
        self.session_text_1.text = f"<i class='zmdi zmdi-link' style='color: #7f858e'></i> <a href='{task_link}' target='_blank'>{task_name}: {task_id}</a>"
        started_at = task_info.get("startedAt")
        if started_at:
            started = datetime.datetime.fromisoformat(started_at.replace("Z", "+00:00"))
            if started.tzinfo is not None:
                # The label says UTC, so normalize any explicit offset to UTC.
                started = started.astimezone(datetime.timezone.utc)
            task_date = started.strftime("%Y-%m-%d %H:%M:%S")
            self.session_text_2.text = f"<span class='field-description text-muted' style='color: #7f858e'>{task_date} (UTC)</span>"
        else:
            self.session_text_2.text = ""
        self.sesson_link.show()

    def disable_modes(self) -> None:
        for mode_name, mode in self.modes.items():
            mode.layout.disable()
            self.tabs.disable_tab(self.modes_labels[mode_name])
        self.select_agent.disable()

    def enable_modes(self) -> None:
        for mode_name, mode in self.modes.items():
            mode.layout.enable()
            self.tabs.enable_tab(self.modes_labels[mode_name])
        self.select_agent.enable()

    def show_deploy_button(self) -> None:
        self.stop_button.hide()
        self.disconnect_button.hide()
        self.connect_button.show()
        self.deploy_button.show()

    def show_stop(self) -> None:
        self.connect_button.hide()
        self.deploy_button.hide()
        self.stop_button.show()
        self.disconnect_button.show()

    def _start_session(
        self,
        in_progress_status: str,
        success_status: str,
        done_verb: str,
        failure_message: str,
        refresh_sessions: bool,
    ) -> None:
        """Run :meth:`deploy` for the active mode and reflect the outcome in the GUI.

        :param in_progress_status: status shown while the request runs.
        :param success_status: status shown on success.
        :param done_verb: verb used in the success log line ("deployed"/"connected").
        :param failure_message: prefix of the error log line.
        :param refresh_sessions: reload the Connect-mode table after success.
        """
        self.set_model_status(in_progress_status)
        self.set_session_info(None)
        try:
            self.disable_modes()
            model_api = self.deploy()
            task_info = self.api.task.get_info_by_id(model_api.task_id)
            model_name = model_api.get_info()["model_name"]
            framework = self._framework_from_task_info(task_info)
            logger.info(
                f"Model {framework}: {model_name} {done_verb} with session ID {model_api.task_id}."
            )
            self.model_api = model_api
            self.set_model_status(success_status)
            self.set_session_info(task_info)
            self.set_model_info(model_api.task_id)
            self.show_stop()
        except Exception as e:
            logger.error(f"{failure_message}: {e}", exc_info=True)
            self.set_model_status("error", str(e))
            self.set_session_info(None)
            self.reset_model_info()
            self.show_deploy_button()
            self.enable_modes()
        else:
            if refresh_sessions and str(self.MODE.CONNECT) in self.modes:
                self.modes[str(self.MODE.CONNECT)]._update_sessions()

    def _connect(self) -> None:
        self._start_session(
            in_progress_status="connecting",
            success_status="connected",
            done_verb="connected",
            failure_message="Failed to connect to model",
            refresh_sessions=False,
        )

    def _deploy(self) -> None:
        # A newly deployed session should appear in the Connect table.
        self._start_session(
            in_progress_status="deploying",
            success_status="deployed",
            done_verb="deployed",
            failure_message="Failed to deploy model",
            refresh_sessions=True,
        )

    def _get_active_mode(self) -> str:
        """Return the name of the enabled mode whose tab is active.

        Falls back to the first enabled mode if the active tab label is not recognized.
        """
        active_label = self.tabs.get_active_tab()
        for mode_name in self.modes:
            if self.modes_labels[mode_name] == active_label:
                return mode_name
        return next(iter(self.modes))

    def deploy(self) -> ModelAPI:
        """Deploy with the active mode on the selected agent and store the result."""
        mode = self._get_active_mode()
        agent_id = self.select_agent.get_value()
        self.model_api = self.modes[mode].deploy(agent_id=agent_id)
        return self.model_api

    def stop(self) -> None:
        """Shut down the current model session (no-op when nothing is deployed).

        If shutdown fails, the error status is shown and the exception is re-raised.
        """
        if self.model_api is None:
            return
        logger.info("Stopping model...")
        try:
            self.model_api.shutdown()
        except Exception as e:
            logger.error(f"Failed to stop model: {e}", exc_info=True)
            self.set_model_status("error", str(e))
            raise
        self.model_api = None
        self.set_model_status("stopped")
        self.enable_modes()
        self.reset_model_info()
        self.show_deploy_button()
        if str(self.MODE.CONNECT) in self.modes:
            self.modes[str(self.MODE.CONNECT)]._update_sessions()

    def disconnect(self) -> None:
        """Forget the current model session without stopping it."""
        if self.model_api is None:
            return
        self.model_api = None
        self.set_model_status("hide")
        self.set_session_info(None)
        self.reset_model_info()
        self.show_deploy_button()
        self.enable_modes()

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """
        Load widget state from JSON data.

        :param data: Dictionary with widget data: ``mode``, optional ``agent_id`` and
            the mode-specific keys accepted by that mode's ``load_from_json``.
        """
        if not data:
            return
        mode = data["mode"]
        self.tabs.set_active_tab(self.modes_labels[mode])
        agent_id = data.get("agent_id", None)
        if agent_id is not None:
            self.select_agent.set_value(agent_id)
        self.modes[mode].load_from_json(data)

    def get_deploy_parameters(self) -> Dict[str, Any]:
        """Return ``mode``, ``agent_id`` and the active mode's own deploy parameters."""
        mode = self._get_active_mode()
        agent_id = self.select_agent.get_value()
        parameters = {"mode": mode, "agent_id": agent_id}
        parameters.update(self.modes[mode].get_deploy_parameters())
        return parameters

    def get_json_data(self) -> Dict[str, Any]:
        return {}

    def get_json_state(self) -> Dict[str, Any]:
        return {}

    def to_html(self):
        return self.layout.to_html()

    # Model Info
    def _create_model_info_widget(self):
        """Build the collapsible "Session Info" card (collapsed, showing a hint)."""
        self._model_info_widget = ModelInfo()
        self._model_info_widget_field = Field(
            self._model_info_widget,
            title="Model Info",
            description="Information about the deployed model",
        )

        self._model_info_container = Container([self._model_info_widget_field])
        self._model_info_container.hide()
        self._model_info_message = Text("Connect to model to see the session information.")

        self._model_info_card = Card(
            title="Session Info",
            description="Model parameters and classes",
            collapsable=True,
            content=Container([self._model_info_container, self._model_info_message]),
        )
        self._model_info_card.collapse()

    def set_model_info(self, session_id):
        """Show model info of ``session_id`` and expand the card."""
        self._model_info_widget.set_model_info(session_id)

        self._model_info_message.hide()
        self._model_info_container.show()
        self._model_info_card.uncollapse()

    def reset_model_info(self):
        """Collapse the card and show the hint message instead of model info."""
        self._model_info_card.collapse()
        self._model_info_container.hide()
        self._model_info_message.show()

    def set_model_message_by_tab(self, tab_name: str):
        """Set the hint text for the selected tab and collapse the card."""
        if tab_name == self.modes_labels[str(self.MODE.CONNECT)]:
            self._model_info_message.set(
                "Connect to model to see the session information.", status="text"
            )
        else:
            self._model_info_message.set(
                "Deploy model to see the session information.", status="text"
            )
        self._model_info_card.collapse()
|
729
|
+
# ------------------------------------------------------------ #
|