supervisely 6.73.420__py3-none-any.whl → 6.73.421__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. supervisely/api/api.py +10 -5
  2. supervisely/api/app_api.py +71 -4
  3. supervisely/api/module_api.py +4 -0
  4. supervisely/api/nn/deploy_api.py +15 -9
  5. supervisely/api/nn/ecosystem_models_api.py +201 -0
  6. supervisely/api/nn/neural_network_api.py +12 -3
  7. supervisely/api/project_api.py +35 -6
  8. supervisely/api/task_api.py +5 -1
  9. supervisely/app/widgets/__init__.py +8 -1
  10. supervisely/app/widgets/agent_selector/template.html +1 -0
  11. supervisely/app/widgets/deploy_model/__init__.py +0 -0
  12. supervisely/app/widgets/deploy_model/deploy_model.py +729 -0
  13. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  14. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  15. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  16. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  17. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +190 -0
  18. supervisely/app/widgets/experiment_selector/experiment_selector.py +447 -264
  19. supervisely/app/widgets/fast_table/fast_table.py +402 -74
  20. supervisely/app/widgets/fast_table/script.js +364 -96
  21. supervisely/app/widgets/fast_table/style.css +24 -0
  22. supervisely/app/widgets/fast_table/template.html +43 -3
  23. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  24. supervisely/app/widgets/select/select.py +6 -4
  25. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +18 -0
  26. supervisely/app/widgets/tabs/tabs.py +22 -6
  27. supervisely/app/widgets/tabs/template.html +5 -1
  28. supervisely/nn/artifacts/__init__.py +1 -1
  29. supervisely/nn/artifacts/artifacts.py +10 -2
  30. supervisely/nn/artifacts/detectron2.py +1 -0
  31. supervisely/nn/artifacts/hrda.py +1 -0
  32. supervisely/nn/artifacts/mmclassification.py +20 -0
  33. supervisely/nn/artifacts/mmdetection.py +5 -3
  34. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  35. supervisely/nn/artifacts/ritm.py +1 -0
  36. supervisely/nn/artifacts/rtdetr.py +1 -0
  37. supervisely/nn/artifacts/unet.py +1 -0
  38. supervisely/nn/artifacts/utils.py +3 -0
  39. supervisely/nn/artifacts/yolov5.py +2 -0
  40. supervisely/nn/artifacts/yolov8.py +1 -0
  41. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  42. supervisely/nn/experiments.py +9 -0
  43. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  44. supervisely/nn/inference/inference.py +160 -94
  45. supervisely/nn/inference/predict_app/__init__.py +0 -0
  46. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  47. supervisely/nn/inference/predict_app/gui/classes_selector.py +91 -0
  48. supervisely/nn/inference/predict_app/gui/gui.py +710 -0
  49. supervisely/nn/inference/predict_app/gui/input_selector.py +165 -0
  50. supervisely/nn/inference/predict_app/gui/model_selector.py +79 -0
  51. supervisely/nn/inference/predict_app/gui/output_selector.py +139 -0
  52. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  53. supervisely/nn/inference/predict_app/gui/settings_selector.py +184 -0
  54. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  55. supervisely/nn/inference/predict_app/gui/utils.py +282 -0
  56. supervisely/nn/inference/predict_app/predict_app.py +184 -0
  57. supervisely/nn/inference/uploader.py +9 -5
  58. supervisely/nn/model/prediction.py +2 -0
  59. supervisely/nn/model/prediction_session.py +20 -3
  60. supervisely/nn/training/gui/gui.py +131 -44
  61. supervisely/nn/training/gui/model_selector.py +8 -6
  62. supervisely/nn/training/gui/train_val_splits_selector.py +122 -70
  63. supervisely/nn/training/gui/training_artifacts.py +0 -5
  64. supervisely/nn/training/train_app.py +161 -44
  65. supervisely/template/experiment/experiment.html.jinja +74 -17
  66. supervisely/template/experiment/experiment_generator.py +258 -112
  67. supervisely/template/experiment/header.html.jinja +31 -13
  68. supervisely/template/experiment/sly-style.css +7 -2
  69. {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/METADATA +3 -1
  70. {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/RECORD +74 -56
  71. supervisely/app/widgets/experiment_selector/style.css +0 -27
  72. supervisely/app/widgets/experiment_selector/template.html +0 -61
  73. {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/LICENSE +0 -0
  74. {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/WHEEL +0 -0
  75. {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/entry_points.txt +0 -0
  76. {supervisely-6.73.420.dist-info → supervisely-6.73.421.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,729 @@
1
+ import datetime
2
+ import tempfile
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Literal
5
+
6
+ import pandas as pd
7
+ import yaml
8
+
9
+ from supervisely._utils import logger
10
+ from supervisely.api.api import Api
11
+ from supervisely.api.app_api import ModuleInfo
12
+ from supervisely.app.widgets.agent_selector.agent_selector import AgentSelector
13
+ from supervisely.app.widgets.button.button import Button
14
+ from supervisely.app.widgets.container.container import Container
15
+ from supervisely.app.widgets.card.card import Card
16
+ from supervisely.app.widgets.model_info.model_info import ModelInfo
17
+ from supervisely.app.widgets.ecosystem_model_selector.ecosystem_model_selector import (
18
+ EcosystemModelSelector,
19
+ )
20
+ from supervisely.app.widgets.experiment_selector.experiment_selector import (
21
+ ExperimentSelector,
22
+ )
23
+ from supervisely.app.widgets.fast_table.fast_table import FastTable
24
+ from supervisely.app.widgets.field.field import Field
25
+ from supervisely.app.widgets.flexbox.flexbox import Flexbox
26
+ from supervisely.app.widgets.tabs.tabs import Tabs
27
+ from supervisely.app.widgets.text.text import Text
28
+ from supervisely.app.widgets.widget import Widget
29
+ from supervisely.io import env
30
+ from supervisely.nn.experiments import ExperimentInfo, get_experiment_infos
31
+ from supervisely.nn.model.model_api import ModelAPI
32
+
33
+
34
class DeployModel(Widget):
    """Widget that deploys a model or connects to an already deployed one.

    Supported modes (rendered as tabs when more than one mode is enabled):

    * ``"connect"`` - pick a running serving session of the team and connect to it;
    * ``"pretrained"`` - deploy a model selected in :class:`EcosystemModelSelector`;
    * ``"custom"`` - deploy a model from an experiment selected in :class:`ExperimentSelector`.

    After a successful deploy/connect the resulting :class:`ModelAPI` is kept in
    ``self.model_api``.
    """

    class DeployMode:
        """Interface implemented by every mode (tab) of :class:`DeployModel`."""

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Deploy (or connect to) the selected model and return its API."""
            raise NotImplementedError("This method should be implemented in subclasses.")

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return the current selection as a dict (see each subclass for its keys)."""
            raise NotImplementedError("This method should be implemented in subclasses.")

        def load_from_json(self, data: Dict[str, Any]) -> None:
            """Restore the selection, typically from a dict built by get_deploy_parameters."""
            raise NotImplementedError("This method should be implemented in subclasses.")

        @property
        def layout(self) -> Widget:
            """Widget rendered for this mode."""
            raise NotImplementedError("This property should be implemented in subclasses.")

    class Connect(DeployMode):
        """Mode that connects to an already running serving session of the team."""

        class COLUMN:
            SESSION_ID = "Session ID"
            APP_NAME = "App Name"
            FRAMEWORK = "Framework"
            MODEL = "Model"

        # NOTE: a plain list literal (not a comprehension) so the class-level COLUMN is visible.
        COLUMNS = [
            str(COLUMN.SESSION_ID),
            str(COLUMN.APP_NAME),
            str(COLUMN.FRAMEWORK),
            str(COLUMN.MODEL),
        ]

        def __init__(self, deploy_model: "DeployModel"):
            self.api = deploy_model.api
            self.team_id = deploy_model.team_id
            self._cache = deploy_model._cache
            self.deploy_model = deploy_model
            self._layout = self._create_layout()
            # Populate the table immediately; this toggles deploy_model.connect_button,
            # so the parent must create that button before instantiating this mode.
            self._update_sessions()

        def _create_layout(self) -> FastTable:
            """Build the sessions table with a refresh button in its header."""
            self.refresh_button = Button(
                "",
                icon="zmdi zmdi-refresh",
                button_type="text",
            )
            self.sessions_table = FastTable(
                columns=self.COLUMNS,
                page_size=10,
                is_radio=True,
                header_left_content=self.refresh_button,
            )

            @self.refresh_button.click
            def _refresh_button_clicked():
                self._update_sessions()

            return self.sessions_table

        @property
        def layout(self) -> FastTable:
            return self._layout

        def _data_from_session(self, session: Dict) -> Dict[str, Any]:
            """Convert one entry of ``api.nn.list_deployed_models`` into a table row."""
            task_info = session["task_info"]
            deploy_info = session["model_info"]
            return {
                self.COLUMN.SESSION_ID: task_info["id"],
                self.COLUMN.APP_NAME: task_info["meta"]["app"]["name"],
                self.COLUMN.FRAMEWORK: self.deploy_model._framework_from_task_info(task_info),
                self.COLUMN.MODEL: deploy_info["model_name"],
            }

        def _update_sessions(self) -> None:
            """Reload deployed sessions into the table; disable Connect when there are none.

            Errors are logged and leave the table empty; the loading spinner is always reset.
            """
            self.sessions_table.loading = True
            try:
                self.sessions_table.clear()
                sessions = self.api.nn.list_deployed_models(team_id=self.team_id)
                data = [self._data_from_session(session) for session in sessions]
                df = pd.DataFrame.from_records(data=data, columns=self.COLUMNS)
                self.sessions_table.read_pandas(df)
                if len(data) == 0:
                    self.deploy_model.connect_button.disable()
                else:
                    self.deploy_model.connect_button.enable()
            except Exception as e:
                logger.error(
                    f"Failed to load deployed models: {e}",
                    exc_info=True,
                )
            finally:
                self.sessions_table.loading = False

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Connect to the selected session. ``agent_id`` is ignored in this mode."""
            deploy_parameters = self.get_deploy_parameters()
            logger.info("Connecting to model with parameters:", extra=deploy_parameters)
            return self.api.nn.connect(task_id=deploy_parameters["session_id"])

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return ``{"session_id": ...}`` for the selected table row.

            :raises ValueError: if no row is selected.
            """
            selected_row = self.sessions_table.get_selected_row()
            # NOTE(review): assumes get_selected_row() returns None when nothing is selected.
            if selected_row is None:
                raise ValueError("No session selected.")
            session_id_index = self.COLUMNS.index(str(self.COLUMN.SESSION_ID))
            return {"session_id": selected_row.row[session_id_index]}

        def load_from_json(self, data: Dict):
            """Refresh the table and select the row whose session ID is ``data["session_id"]``."""
            session_id = data["session_id"]
            self._update_sessions()
            self.sessions_table.select_row_by_value(str(self.COLUMN.SESSION_ID), session_id)

    class Pretrained(DeployMode):
        """Mode that deploys a pretrained model chosen in the ecosystem model selector."""

        class COLUMN:
            # TODO: columns are the same as in EcosystemModelSelector, make a common base class
            FRAMEWORK = "Framework"
            MODEL_NAME = "Model"
            TASK_TYPE = "Task Type"
            PARAMETERS = "Parameters (M)"
            # TODO: support metrics for different tasks
            MAP = "mAP"

        COLUMNS = [
            str(COLUMN.FRAMEWORK),
            str(COLUMN.MODEL_NAME),
            str(COLUMN.TASK_TYPE),
            str(COLUMN.PARAMETERS),
            str(COLUMN.MAP),
        ]

        def __init__(self, deploy_model: "DeployModel"):
            self.api = deploy_model.api
            self.team_id = deploy_model.team_id
            self._cache = deploy_model._cache
            self.deploy_model = deploy_model
            self._model_api = None
            self._last_selected_framework = None
            self._layout = self._create_layout()

        @property
        def layout(self) -> EcosystemModelSelector:
            return self._layout

        def _create_layout(self) -> EcosystemModelSelector:
            self.model_selector = EcosystemModelSelector(api=self.api)
            return self.model_selector

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return ``{"framework": ..., "model_name": ...}`` of the selected model.

            :raises ValueError: if no model is selected.
            """
            selected_model = self.model_selector.get_selected()
            if selected_model is None:
                raise ValueError("No model selected.")
            return {
                "framework": selected_model["framework"],
                "model_name": selected_model["name"],
            }

        def load_from_json(self, data: Dict[str, Any]) -> None:
            """Select the model identified by ``data["framework"]`` and ``data["model_name"]``."""
            framework = data["framework"]
            model_name = data["model_name"]
            self.model_selector.select_framework_and_model_name(framework, model_name)

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Deploy the selected pretrained model as ``"<framework>/<model_name>"``."""
            deploy_parameters = self.get_deploy_parameters()
            logger.info("Deploying pretrained model with parameters:", extra=deploy_parameters)
            framework = deploy_parameters["framework"]
            model_name = deploy_parameters["model_name"]
            return self.api.nn.deploy(model=f"{framework}/{model_name}", agent_id=agent_id)

    class Custom(DeployMode):
        """Mode that deploys a model from a training experiment of the team."""

        def __init__(self, deploy_model: "DeployModel"):
            self.api = deploy_model.api
            self.team_id = deploy_model.team_id
            self._cache = deploy_model._cache
            self.deploy_model = deploy_model
            self._model_api = None
            self._layout = self._create_layout()

        @property
        def layout(self) -> ExperimentSelector:
            return self._layout

        def _create_layout(self) -> ExperimentSelector:
            """Collect experiments of every known framework into an ExperimentSelector."""
            experiment_infos = []
            for framework_name in self.deploy_model.get_frameworks():
                experiment_infos.extend(
                    get_experiment_infos(self.api, self.team_id, framework_name=framework_name)
                )
            self.experiment_table = ExperimentSelector(
                experiment_infos=experiment_infos,
                team_id=self.team_id,
                api=self.api,
            )

            @self.experiment_table.checkpoint_changed
            def _checkpoint_changed(row: ExperimentSelector.ModelRow, checkpoint_value: str):
                logger.debug(
                    f"Checkpoint changed for {row._experiment_info.task_id}: {checkpoint_value}"
                )

            return self.experiment_table

        def get_deploy_parameters(self) -> Dict[str, Any]:
            """Return ``{"experiment_info": <json or None>}`` for the selected experiment."""
            experiment_info = self.experiment_table.get_selected_experiment_info()
            return {
                "experiment_info": (experiment_info.to_json() if experiment_info else None),
            }

        def deploy(self, agent_id: int = None) -> ModelAPI:
            """Deploy the selected experiment's model on ``agent_id``.

            :raises ValueError: if no experiment is selected.
            """
            deploy_parameters = self.get_deploy_parameters()
            logger.info("Deploying custom model with parameters:", extra=deploy_parameters)
            experiment_info_json = deploy_parameters["experiment_info"]
            if experiment_info_json is None:
                raise ValueError("No experiment selected.")
            experiment_info = ExperimentInfo(**experiment_info_json)  # pylint: disable=not-a-mapping
            task_info = self.api.nn._deploy_api.deploy_custom_model_from_experiment_info(
                agent_id=agent_id,
                experiment_info=experiment_info,
                log_level="debug",
            )
            return ModelAPI(api=self.api, task_id=task_info["id"])

        def load_from_json(self, data: Dict):
            """Select an experiment by ``"experiment_info"`` JSON or by ``"train_task_id"``.

            :raises ValueError: if neither key is present.
            """
            if "experiment_info" in data:
                experiment_info_json = data["experiment_info"]
                experiment_info = ExperimentInfo(**experiment_info_json)  # pylint: disable=not-a-mapping
                self.experiment_table.set_selected_row_by_experiment_info(experiment_info)
            elif "train_task_id" in data:
                task_id = data["train_task_id"]
                self.experiment_table.set_selected_row_by_task_id(task_id)
            else:
                raise ValueError("Invalid data format for loading custom model.")

    class MODE:
        CONNECT = "connect"
        PRETRAINED = "pretrained"
        CUSTOM = "custom"

    MODES = [str(MODE.CONNECT), str(MODE.PRETRAINED), str(MODE.CUSTOM)]
    MODE_TO_CLASS = {
        str(MODE.CONNECT): Connect,
        str(MODE.PRETRAINED): Pretrained,
        str(MODE.CUSTOM): Custom,
    }

    # Ecosystem app categories of the form "framework:<name>" declare the app's framework.
    _FRAMEWORK_CATEGORY_PREFIX = "framework:"

    def __init__(
        self,
        api: Api = None,
        team_id: int = None,
        modes: List[Literal["connect", "pretrained", "custom"]] = None,
        widget_id: str = None,
    ):
        """
        :param api: API client; a new ``Api()`` is created when omitted.
        :param team_id: team to work in; defaults to ``env.team_id()``.
        :param modes: enabled modes (subset of ``MODES``); all modes by default.
        :param widget_id: optional widget ID.
        :raises ValueError: if ``modes`` is empty, too long or contains unknown values.
        """
        self.modes: Dict[str, DeployModel.DeployMode] = {}
        if modes is None:
            modes = self.MODES.copy()
        self._validate_modes(modes)
        if api is None:
            api = Api()
        self.api = api
        if team_id is None:
            team_id = env.team_id()
        self.team_id = team_id
        # Shared cache (modules, frameworks, inference settings) also used by the modes.
        self._cache = {}

        self.modes_labels = {
            self.MODE.CONNECT: "Connect",
            self.MODE.PRETRAINED: "Pretrained",
            self.MODE.CUSTOM: "Custom",
        }

        # GUI
        self.layout: Widget = None
        self._init_gui(modes)

        self.model_api: ModelAPI = None

        super().__init__(widget_id=widget_id)

    def _validate_modes(self, modes) -> None:
        """Raise ValueError unless ``modes`` has 1..len(MODES) entries, all from MODES."""
        if len(modes) < 1 or len(modes) > len(self.MODES):
            raise ValueError(
                f"Modes must be a list containing 1 to {len(self.MODES)} of the following: {', '.join(self.MODES)}."
            )
        for mode in modes:
            if mode not in self.MODES:
                raise ValueError(f"Invalid mode '{mode}'. Valid modes are {', '.join(self.MODES)}.")

    @classmethod
    def _frameworks_from_categories(cls, categories) -> List[str]:
        """Return framework names from ``framework:<name>`` entries of ``categories``."""
        prefix = cls._FRAMEWORK_CATEGORY_PREFIX
        return [cat[len(prefix) :] for cat in categories if cat.startswith(prefix)]

    def get_modules(self) -> List[ModuleInfo]:
        """Return image serving apps that declare a framework category (cached)."""
        modules = self._cache.setdefault("modules", [])
        if len(modules) > 0:
            return modules
        raw_modules = self.api.app.get_list_ecosystem_modules(
            categories=["serve", "images"], categories_operation="and"
        )
        modules = [
            ModuleInfo.from_json(module)
            for module in raw_modules
            if self._frameworks_from_categories(
                module.get("config", {}).get("categories", [])
            )
        ]
        self._cache["modules"] = modules
        return modules

    def get_frameworks(self) -> List[str]:
        """Return unique framework names of the serving apps, in first-seen order (cached)."""
        cached = self._cache.get("frameworks", [])
        if len(cached) > 0:
            return cached
        frameworks = []
        for module in self.get_modules():
            frameworks.extend(self._frameworks_from_categories(module.config.get("categories", [])))
        # Several serving apps may declare the same framework; without de-duplication
        # Custom mode would fetch and list the same experiments several times.
        frameworks = list(dict.fromkeys(frameworks))
        self._cache["frameworks"] = frameworks
        return frameworks

    def _init_modes(self, modes: List[str]) -> None:
        """Instantiate the mode objects for the given mode names."""
        for mode in modes:
            self.modes[mode] = self.MODE_TO_CLASS[mode](self)

    def _create_task_link(self, task_id: int) -> str:
        return f"{self.api.server_address}/apps/sessions/{task_id}"

    def _download_inference_settings(self, module_id: int, inference_settings_path: str) -> str:
        """Download an app's inference settings file from git and return its text.

        The file is written to a temporary directory that is removed afterwards.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            save_path = str(Path(tmp_dir) / "inference_settings.yaml")
            self.api.app.download_git_file(
                module_id=module_id,
                file_path=inference_settings_path,
                save_path=save_path,
            )
            return Path(save_path).read_text(encoding="utf-8")

    def _get_inference_settings_by_module(self, module: Dict) -> str:
        """Return the inference settings text of a serving app module.

        :raises ValueError: if the module config has no ``files.inference_settings``.
        """
        config = module["config"]
        inference_settings_path = config.get("files", {}).get("inference_settings", None)
        if inference_settings_path is None:
            raise ValueError(
                f"No inference settings file found for framework app {module['meta']['app']['name']}."
            )
        return self._download_inference_settings(module["id"], inference_settings_path)

    def _get_inference_settings_for_framework(self, framework: str) -> str:
        """Return (and cache) the inference settings text of the framework's serving app.

        :raises ValueError: if no serving app or inference settings file is found.
        """
        inference_settings_cache = self._cache.setdefault("inference_settings", {})
        if framework not in inference_settings_cache:
            module = self.api.nn._deploy_api.find_serving_app_by_framework(framework)
            if module is None:
                raise ValueError(f"No serving app found for framework {framework}.")
            config = module["config"]
            inference_settings_path = config.get("files", {}).get("inference_settings", None)
            if inference_settings_path is None:
                raise ValueError(f"No inference settings file found for framework {framework}.")
            inference_settings_cache[framework] = self._download_inference_settings(
                module["id"], inference_settings_path
            )
        return inference_settings_cache[framework]

    def _framework_from_task_info(self, task_info: Dict) -> str:
        """Return the framework of the app that runs the task, or ``"unknown"``."""
        module_id = task_info["meta"]["app"]["moduleId"]
        module = next((m for m in self.get_modules() if m.id == module_id), None)
        if module is None:
            module = self.api.app.get_ecosystem_module_info(module_id=module_id)
            self._cache.setdefault("modules", []).append(module)
        frameworks = self._frameworks_from_categories(module.config.get("categories", []))
        return frameworks[0] if frameworks else "unknown"

    def _init_gui(self, modes: List[str]) -> None:
        """Create all child widgets, the per-mode layouts and the button handlers."""
        self.status = Text("Deploying model...", status="info")
        self.session_text_1 = Text(
            "",
            "text",
        )
        self.session_text_2 = Text(
            "",
            "text",
            font_size=13,
        )
        # (sic) attribute name kept as-is for backward compatibility.
        self.sesson_link = Container(
            [
                self.session_text_1,
                self.session_text_2,
            ],
            gap=0,
            style="padding-left: 10px",
        )
        self.status.hide()
        self.sesson_link.hide()

        self.select_agent = AgentSelector(self.team_id)
        self.select_agent_field = Field(content=self.select_agent, title="Select Agent")

        self._create_model_info_widget()

        self.deploy_button = Button("Deploy", icon="zmdi zmdi-play")
        self.connect_button = Button("Connect", icon="zmdi zmdi-link")
        self.stop_button = Button("Stop", icon="zmdi zmdi-stop", button_type="danger")
        self.stop_button.hide()
        self.disconnect_button = Button("Disconnect", icon="zmdi zmdi-close", button_type="warning")
        self.disconnect_button.hide()
        self.deploy_stop_buttons = Flexbox(
            widgets=[self.deploy_button, self.stop_button, self.disconnect_button],
            gap=10,
        )
        self.connect_stop_buttons = Flexbox(
            widgets=[self.connect_button, self.stop_button, self.disconnect_button],
            gap=10,
        )

        # Modes are created after connect_button, which Connect mode toggles on init.
        self._init_modes(modes)
        _labels = []
        _contents = []
        for mode_name, mode in self.modes.items():
            label = self.modes_labels[mode_name]
            if mode_name == str(self.MODE.CONNECT):
                widgets = [
                    mode.layout,
                    self._model_info_card,
                    self.connect_stop_buttons,
                    self.status,
                    self.sesson_link,
                ]
            else:
                widgets = [
                    mode.layout,
                    self._model_info_card,
                    self.select_agent_field,
                    self.deploy_stop_buttons,
                    self.status,
                    self.sesson_link,
                ]
            content = Container(widgets=widgets, gap=20)
            _labels.append(label)
            _contents.append(content)

        self.tabs = Tabs(labels=_labels, contents=_contents)
        # With a single mode the tab header is pointless; render its content directly.
        if len(self.modes) == 1:
            self.layout = _contents[0]
        else:
            self.layout = self.tabs

        @self.deploy_button.click
        def _deploy_button_clicked():
            self._deploy()

        @self.stop_button.click
        def _stop_button_clicked():
            self.stop()

        @self.connect_button.click
        def _connect_button_clicked():
            self._connect()

        @self.disconnect_button.click
        def _disconnect_button_clicked():
            self.disconnect()

        @self.tabs.click
        def _active_tab_changed(tab_name: str):
            self.set_model_message_by_tab(tab_name)

    def set_model_status(
        self,
        status: Literal[
            "deployed", "stopped", "deploying", "connecting", "connected", "error", "hide"
        ],
        extra_text: str = None,
    ) -> None:
        """Show a status message (optionally with ``extra_text`` appended) or hide it."""
        if status == "hide":
            self.status.hide()
            return
        status_args = {
            "deployed": {"text": "Model deployed successfully!", "status": "success"},
            "stopped": {"text": "Model stopped", "status": "info"},
            "deploying": {"text": "Deploying model...", "status": "info"},
            "connecting": {"text": "Connecting to model...", "status": "info"},
            "connected": {"text": "Model connected successfully!", "status": "success"},
            "error": {
                "text": "Error occurred during model deployment.",
                "status": "error",
            },
        }
        args = status_args[status]
        if extra_text:
            args["text"] += f" {extra_text}"
        self.status.set(**args)
        self.status.show()

    @staticmethod
    def _format_task_date(started_at) -> str:
        """Format an ISO-8601 timestamp as ``YYYY-mm-dd HH:MM:SS`` in UTC.

        Returns ``""`` for a missing value and the raw value if it cannot be parsed.
        """
        if not started_at:
            return ""
        try:
            parsed = datetime.datetime.fromisoformat(started_at.replace("Z", "+00:00"))
        except ValueError:
            return started_at
        if parsed.tzinfo is not None:
            # Normalize any explicit offset so the "(UTC)" label shown next to it is accurate.
            parsed = parsed.astimezone(datetime.timezone.utc)
        return parsed.strftime("%Y-%m-%d %H:%M:%S")

    def set_session_info(self, task_info: Dict):
        """Show a link to the session described by ``task_info``; hide it when None."""
        if task_info is None:
            self.sesson_link.hide()
            return
        task_id = task_info["id"]
        task_link = self._create_task_link(task_id)
        task_name = task_info["meta"]["app"]["name"]
        task_date = self._format_task_date(task_info.get("startedAt"))
        self.session_text_1.text = f"<i class='zmdi zmdi-link' style='color: #7f858e'></i> <a href='{task_link}' target='_blank'>{task_name}: {task_id}</a>"
        if task_date:
            self.session_text_2.text = f"<span class='field-description text-muted' style='color: #7f858e'>{task_date} (UTC)</span>"
        else:
            self.session_text_2.text = ""
        self.sesson_link.show()

    def disable_modes(self) -> None:
        """Disable every mode layout, its tab and the agent selector."""
        for mode_name, mode in self.modes.items():
            mode.layout.disable()
            label = self.modes_labels[mode_name]
            self.tabs.disable_tab(label)
        self.select_agent.disable()

    def enable_modes(self) -> None:
        """Re-enable every mode layout, its tab and the agent selector."""
        for mode_name, mode in self.modes.items():
            mode.layout.enable()
            label = self.modes_labels[mode_name]
            self.tabs.enable_tab(label)
        self.select_agent.enable()

    def show_deploy_button(self) -> None:
        """Show Deploy/Connect and hide Stop/Disconnect."""
        self.stop_button.hide()
        self.disconnect_button.hide()
        self.connect_button.show()
        self.deploy_button.show()

    def show_stop(self) -> None:
        """Show Stop/Disconnect and hide Deploy/Connect."""
        self.connect_button.hide()
        self.deploy_button.hide()
        self.stop_button.show()
        self.disconnect_button.show()

    def _start_model(self, action: Literal["deploy", "connect"]) -> bool:
        """Deploy or connect to the model of the active mode and update the UI.

        Errors are logged and shown in the status text instead of being raised.

        :return: True on success, False on failure.
        """
        connecting = action == "connect"
        self.set_model_status("connecting" if connecting else "deploying")
        self.set_session_info(None)
        try:
            self.disable_modes()
            model_api = self.deploy()
            task_info = self.api.task.get_info_by_id(model_api.task_id)
            model_info = model_api.get_info()
            model_name = model_info["model_name"]
            framework = self._framework_from_task_info(task_info)
            done_verb = "connected" if connecting else "deployed"
            logger.info(
                f"Model {framework}: {model_name} {done_verb} with session ID {model_api.task_id}."
            )
            self.model_api = model_api
            self.set_model_status("connected" if connecting else "deployed")
            self.set_session_info(task_info)
            self.set_model_info(model_api.task_id)
            self.show_stop()
            return True
        except Exception as e:
            failed_verb = "connect to" if connecting else "deploy"
            logger.error(f"Failed to {failed_verb} model: {e}", exc_info=True)
            self.set_model_status("error", str(e))
            self.set_session_info(None)
            self.reset_model_info()
            self.show_deploy_button()
            self.enable_modes()
            return False

    def _connect(self) -> None:
        """Handler of the Connect button."""
        self._start_model("connect")

    def _deploy(self) -> None:
        """Handler of the Deploy button; refreshes the sessions table on success."""
        if self._start_model("deploy") and str(self.MODE.CONNECT) in self.modes:
            self.modes[str(self.MODE.CONNECT)]._update_sessions()

    def _get_active_mode(self) -> str:
        """Return the name of the active mode.

        With a single mode the tabs are not rendered, so that mode is returned directly.

        :raises ValueError: if the active tab label does not match an enabled mode.
        """
        if len(self.modes) == 1:
            return next(iter(self.modes))
        active_label = self.tabs.get_active_tab()
        for mode_name in self.modes:
            if self.modes_labels[mode_name] == active_label:
                return mode_name
        raise ValueError(f"Unknown active tab: {active_label!r}")

    def deploy(self) -> ModelAPI:
        """Deploy (or connect) using the active mode and the selected agent.

        The resulting ModelAPI is stored in ``self.model_api`` and returned.
        """
        mode = self._get_active_mode()
        agent_id = self.select_agent.get_value()
        self.model_api = self.modes[mode].deploy(agent_id=agent_id)
        return self.model_api

    def stop(self) -> None:
        """Shut down the current model (no-op if none) and reset the UI."""
        if self.model_api is None:
            return
        logger.info("Stopping model...")
        self.model_api.shutdown()
        self.model_api = None
        self.set_model_status("stopped")
        self.enable_modes()
        self.reset_model_info()
        self.show_deploy_button()
        if str(self.MODE.CONNECT) in self.modes:
            self.modes[str(self.MODE.CONNECT)]._update_sessions()

    def disconnect(self) -> None:
        """Forget the current model without stopping it (no-op if none) and reset the UI."""
        if self.model_api is None:
            return
        self.model_api = None
        self.set_model_status("hide")
        self.set_session_info(None)
        self.reset_model_info()
        self.show_deploy_button()
        self.enable_modes()

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """
        Load widget state from JSON data.

        :param data: Dictionary with ``"mode"``, optional ``"agent_id"`` and the
            mode-specific keys; empty/None data is ignored.
        """
        if not data:
            return
        mode = data["mode"]
        label = self.modes_labels[mode]
        self.tabs.set_active_tab(label)
        agent_id = data.get("agent_id", None)
        if agent_id is not None:
            self.select_agent.set_value(agent_id)
        self.modes[mode].load_from_json(data)

    def get_deploy_parameters(self) -> Dict[str, Any]:
        """Return ``{"mode", "agent_id"}`` merged with the active mode's parameters."""
        mode = self._get_active_mode()
        agent_id = self.select_agent.get_value()
        parameters = {"mode": mode, "agent_id": agent_id}
        parameters.update(self.modes[mode].get_deploy_parameters())
        return parameters

    def get_json_data(self) -> Dict[str, Any]:
        return {}

    def get_json_state(self) -> Dict[str, Any]:
        return {}

    def to_html(self):
        return self.layout.to_html()

    # Model Info
    def _create_model_info_widget(self):
        """Create the collapsible "Session Info" card with a placeholder message."""
        self._model_info_widget = ModelInfo()
        self._model_info_widget_field = Field(
            self._model_info_widget,
            title="Model Info",
            description="Information about the deployed model",
        )

        self._model_info_container = Container([self._model_info_widget_field])
        self._model_info_container.hide()
        self._model_info_message = Text("Connect to model to see the session information.")

        self._model_info_card = Card(
            title="Session Info",
            description="Model parameters and classes",
            collapsable=True,
            content=Container([self._model_info_container, self._model_info_message]),
        )
        self._model_info_card.collapse()

    def set_model_info(self, session_id):
        """Fill the model info for ``session_id`` and expand the card."""
        self._model_info_widget.set_model_info(session_id)

        self._model_info_message.hide()
        self._model_info_container.show()
        self._model_info_card.uncollapse()

    def reset_model_info(self):
        """Collapse the card and show the placeholder message again."""
        self._model_info_card.collapse()
        self._model_info_container.hide()
        self._model_info_message.show()

    def set_model_message_by_tab(self, tab_name: str):
        """Adapt the placeholder message to the selected tab and collapse the card."""
        if tab_name == self.modes_labels[str(self.MODE.CONNECT)]:
            self._model_info_message.set(
                "Connect to model to see the session information.", status="text"
            )
        else:
            self._model_info_message.set(
                "Deploy model to see the session information.", status="text"
            )
        self._model_info_card.collapse()

    # ------------------------------------------------------------ #