supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (190)
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,750 @@
1
+ import datetime
2
+ import tempfile
3
+ import threading
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Literal
6
+
7
+ import pandas as pd
8
+
9
+ from supervisely._utils import logger
10
+ from supervisely.api.api import Api
11
+ from supervisely.api.app_api import ModuleInfo
12
+ from supervisely.app.widgets.agent_selector.agent_selector import AgentSelector
13
+ from supervisely.app.widgets.button.button import Button
14
+ from supervisely.app.widgets.container.container import Container
15
+ from supervisely.app.widgets.ecosystem_model_selector.ecosystem_model_selector import (
16
+ EcosystemModelSelector,
17
+ )
18
+ from supervisely.app.widgets.experiment_selector.experiment_selector import (
19
+ ExperimentSelector,
20
+ )
21
+ from supervisely.app.widgets.fast_table.fast_table import FastTable
22
+ from supervisely.app.widgets.field.field import Field
23
+ from supervisely.app.widgets.flexbox.flexbox import Flexbox
24
+ from supervisely.app.widgets.model_info.model_info import ModelInfo
25
+ from supervisely.app.widgets.radio_tabs.radio_tabs import RadioTabs
26
+ from supervisely.app.widgets.text.text import Text
27
+ from supervisely.app.widgets.widget import Widget
28
+ from supervisely.io import env
29
+ from supervisely.nn.experiments import ExperimentInfo, get_experiment_infos
30
+ from supervisely.nn.model.model_api import ModelAPI
31
+
32
+
33
+ class DeployModel(Widget):
34
+
35
class DeployMode:
    """Interface shared by every deployment mode of the DeployModel widget.

    Concrete modes (Connect, Pretrained, Custom) override every member below;
    the base versions only raise NotImplementedError.
    """

    def deploy(self, agent_id: int = None) -> ModelAPI:
        """Start (or attach to) a model and return a ModelAPI handle for it."""
        raise NotImplementedError("This method should be implemented in subclasses.")

    def get_deploy_parameters(self) -> Dict[str, Any]:
        """Return the parameters that the current GUI selection would deploy with."""
        raise NotImplementedError("This method should be implemented in subclasses.")

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """Restore the GUI selection from a parameters dict."""
        raise NotImplementedError("This method should be implemented in subclasses.")

    @property
    def layout(self) -> Widget:
        """Widget rendered for this mode."""
        raise NotImplementedError("This property should be implemented in subclasses.")
49
+
50
class Connect(DeployMode):
    """Deployment mode that attaches to an already running serving session.

    Running sessions of the team are listed in a radio ``FastTable``. ``deploy``
    connects to the selected session instead of starting a new task.
    """

    class COLUMN:
        # Column titles of the sessions table.
        SESSION_ID = "Session ID"
        APP_NAME = "App Name"
        FRAMEWORK = "Framework"
        MODEL = "Model"

    COLUMNS = [
        str(COLUMN.SESSION_ID),
        str(COLUMN.APP_NAME),
        str(COLUMN.FRAMEWORK),
        str(COLUMN.MODEL),
    ]

    def __init__(self, deploy_model: "DeployModel"):
        self.api = deploy_model.api
        self.team_id = deploy_model.team_id
        self._cache = deploy_model._cache
        self.deploy_model = deploy_model
        self._layout = self._create_layout()
        # Fill the table right away so running sessions are visible on first render.
        self._update_sessions()

    def _create_layout(self) -> FastTable:
        """Build the sessions table with a refresh button in its header."""
        self.refresh_button = Button(
            "",
            icon="zmdi zmdi-refresh",
            button_type="text",
        )
        self.sessions_table = FastTable(
            columns=self.COLUMNS,
            page_size=10,
            is_radio=True,
            header_left_content=self.refresh_button,
        )

        @self.refresh_button.click
        def _refresh_button_clicked():
            self._update_sessions()

        return self.sessions_table

    @property
    def layout(self) -> FastTable:
        return self._layout

    def _data_from_session(self, session: Dict) -> Dict[str, Any]:
        """Convert one deployed-model record into a sessions-table row."""
        task_info = session["task_info"]
        deploy_info = session["model_info"]
        return {
            self.COLUMN.SESSION_ID: task_info["id"],
            self.COLUMN.APP_NAME: task_info["meta"]["app"]["name"],
            self.COLUMN.FRAMEWORK: self.deploy_model._framework_from_task_info(task_info),
            self.COLUMN.MODEL: deploy_info["model_name"],
        }

    def _update_sessions(self) -> None:
        """Reload running sessions into the table and toggle the Connect button.

        Errors are logged, not raised. The Connect button is enabled only when
        at least one session was loaded. On failure it is disabled, because the
        table (cleared first) may no longer list any connectable session.
        """
        self.sessions_table.loading = True
        try:
            self.sessions_table.clear()
            sessions = self.api.nn.list_deployed_models(team_id=self.team_id)
            data = [self._data_from_session(session) for session in sessions]
            df = pd.DataFrame.from_records(data=data, columns=self.COLUMNS)
            self.sessions_table.read_pandas(df)
            if data:
                self.deploy_model.connect_button.enable()
            else:
                self.deploy_model.connect_button.disable()
        except Exception as e:
            logger.error(
                f"Failed to load deployed models: {e}",
                exc_info=True,
            )
            self.deploy_model.connect_button.disable()
        finally:
            self.sessions_table.loading = False

    def deploy(self, agent_id: int = None) -> ModelAPI:
        """Connect to the selected session. ``agent_id`` is not used by this mode."""
        deploy_parameters = self.get_deploy_parameters()
        logger.info("Connecting to model with parameters:", extra=deploy_parameters)
        session_id = deploy_parameters["session_id"]
        model_api = self.api.nn.connect(task_id=session_id)
        return model_api

    def get_deploy_parameters(self) -> Dict[str, Any]:
        """Return ``{"session_id": ...}`` for the selected table row.

        :raises ValueError: If no session is selected.
        """
        selected_row = self.sessions_table.get_selected_row()
        # NOTE(review): assumes get_selected_row() returns None when nothing is
        # selected. Without this guard the failure surfaced as an AttributeError.
        if selected_row is None:
            raise ValueError("No session selected. Select a deployed model to connect to.")
        session_column = self.COLUMNS.index(str(self.COLUMN.SESSION_ID))
        return {
            "session_id": selected_row.row[session_column],
        }

    def load_from_json(self, data: Dict):
        """Refresh the table and select the row whose Session ID is ``data["session_id"]``."""
        session_id = data["session_id"]
        self._update_sessions()
        self.sessions_table.select_row_by_value(str(self.COLUMN.SESSION_ID), session_id)
143
+
144
class Pretrained(DeployMode):
    """Deployment mode that starts a pretrained ecosystem model on an agent."""

    class COLUMN:
        # TODO: columns are the same as in EcosystemModelSelector, make a common base class
        FRAMEWORK = "Framework"
        MODEL_NAME = "Model"
        TASK_TYPE = "Task Type"
        PARAMETERS = "Parameters (M)"
        # TODO: support metrics for different tasks
        MAP = "mAP"

    COLUMNS = [
        str(COLUMN.FRAMEWORK),
        str(COLUMN.MODEL_NAME),
        str(COLUMN.TASK_TYPE),
        str(COLUMN.PARAMETERS),
        str(COLUMN.MAP),
    ]

    def __init__(self, deploy_model: "DeployModel"):
        self.api = deploy_model.api
        self.team_id = deploy_model.team_id
        self._cache = deploy_model._cache
        self.deploy_model = deploy_model
        # Not read in this class; kept for compatibility with external users.
        self._model_api = None
        self._last_selected_framework = None
        self._layout = self._create_layout()

    @property
    def layout(self) -> EcosystemModelSelector:
        return self._layout

    def _create_layout(self) -> EcosystemModelSelector:
        """Create the ecosystem model selector that forms this mode's layout."""
        self.model_selector = EcosystemModelSelector(api=self.api)
        return self.model_selector

    def get_deploy_parameters(self) -> Dict[str, Any]:
        """Return the framework and model name of the selected pretrained model.

        :raises ValueError: If no model is selected.
        """
        selected_model = self.model_selector.get_selected()
        # NOTE(review): assumes get_selected() returns None when nothing is selected.
        if selected_model is None:
            raise ValueError("No pretrained model selected. Select a model to deploy.")
        return {
            "framework": selected_model["framework"],
            "model_name": selected_model["name"],
        }

    def load_from_json(self, data: Dict[str, Any]) -> None:
        """Select the model given by ``data["framework"]`` and ``data["model_name"]``."""
        framework = data["framework"]
        model_name = data["model_name"]
        self.model_selector.select_framework_and_model_name(framework, model_name)

    def deploy(self, agent_id: int = None) -> ModelAPI:
        """Deploy ``<framework>/<model_name>`` on ``agent_id`` and return its ModelAPI."""
        deploy_parameters = self.get_deploy_parameters()
        logger.info("Deploying pretrained model with parameters:", extra=deploy_parameters)
        framework = deploy_parameters["framework"]
        model_name = deploy_parameters["model_name"]
        model_api = self.api.nn.deploy(model=f"{framework}/{model_name}", agent_id=agent_id)
        return model_api
198
+
199
class Custom(DeployMode):
    """Deployment mode that serves a model from the team's own experiments."""

    def __init__(self, deploy_model: "DeployModel"):
        self.api = deploy_model.api
        self.team_id = deploy_model.team_id
        self._cache = deploy_model._cache
        self.deploy_model = deploy_model
        # Not read in this class; kept for compatibility with external users.
        self._model_api = None
        self._layout = self._create_layout()

    @property
    def layout(self) -> ExperimentSelector:
        return self._layout

    def _create_layout(self) -> ExperimentSelector:
        """Create the experiment selector and start loading experiments in the background."""
        self.experiment_table = ExperimentSelector(
            api=self.api,
            team_id=self.team_id,
        )

        @self.experiment_table.checkpoint_changed
        def _checkpoint_changed(row: ExperimentSelector.ModelRow, checkpoint_value: str):
            logger.debug(
                f"Checkpoint changed for {row._experiment_info.task_id}: {checkpoint_value}"
            )

        # Listing experiments queries every framework, so do it off the caller's thread.
        threading.Thread(target=self._refresh_experiments_in_background, daemon=True).start()

        return self.experiment_table

    def _refresh_experiments_in_background(self) -> None:
        """Thread target: run refresh_experiments and log failures instead of losing them."""
        try:
            self.refresh_experiments()
        except Exception as e:
            logger.error(f"Failed to load experiments: {e}", exc_info=True)

    def refresh_experiments(self):
        """Reload the experiments of every known framework into the table.

        The loading indicator is always cleared, even when a request fails.
        """
        self.experiment_table.loading = True
        try:
            experiment_infos = []
            # dict.fromkeys drops repeated framework names (keeping order) so the
            # same experiments are not fetched and listed twice.
            for framework_name in dict.fromkeys(self.deploy_model.get_frameworks()):
                experiment_infos.extend(
                    get_experiment_infos(self.api, self.team_id, framework_name=framework_name)
                )
            self.experiment_table.set_experiment_infos(experiment_infos)
        finally:
            self.experiment_table.loading = False

    def get_deploy_parameters(self) -> Dict[str, Any]:
        """Return ``{"experiment_info": <json or None>}`` for the selected experiment."""
        experiment_info = self.experiment_table.get_selected_experiment_info()
        return {
            "experiment_info": (experiment_info.to_json() if experiment_info else None),
        }

    def deploy(self, agent_id: int = None) -> ModelAPI:
        """Deploy the selected experiment on ``agent_id`` and return its ModelAPI.

        :raises ValueError: If no experiment is selected.
        """
        deploy_parameters = self.get_deploy_parameters()
        logger.info("Deploying custom model with parameters:", extra=deploy_parameters)
        experiment_info_json = deploy_parameters["experiment_info"]
        # get_deploy_parameters yields None when nothing is selected; without this
        # check the failure surfaced as an opaque TypeError from ExperimentInfo(**None).
        if experiment_info_json is None:
            raise ValueError("No experiment selected. Select an experiment to deploy.")
        experiment_info = ExperimentInfo(**experiment_info_json)  # pylint: disable=not-a-mapping
        task_info = self.api.nn._deploy_api.deploy_custom_model_from_experiment_info(
            agent_id=agent_id,
            experiment_info=experiment_info,
            log_level="debug",
        )
        model_api = ModelAPI(api=self.api, task_id=task_info["id"])
        return model_api

    def load_from_json(self, data: Dict):
        """Select an experiment by ``experiment_info`` JSON or by ``train_task_id``.

        :raises ValueError: If ``data`` has neither key.
        """
        if "experiment_info" in data:
            experiment_info_json = data["experiment_info"]
            experiment_info = ExperimentInfo(**experiment_info_json)  # pylint: disable=not-a-mapping
            self.experiment_table.set_selected_row_by_experiment_info(experiment_info)
        elif "train_task_id" in data:
            task_id = data["train_task_id"]
            self.experiment_table.set_selected_row_by_task_id(task_id)
        else:
            raise ValueError("Invalid data format for loading custom model.")
267
+
268
class MODE:
    # Identifiers accepted in the ``modes`` argument of DeployModel.
    CONNECT = "connect"
    PRETRAINED = "pretrained"
    CUSTOM = "custom"

# Every mode, in default tab order.
MODES = [MODE.CONNECT, MODE.PRETRAINED, MODE.CUSTOM]
# Mode identifier -> DeployMode subclass that implements it.
MODE_TO_CLASS = {
    MODE.CONNECT: Connect,
    MODE.CUSTOM: Custom,
    MODE.PRETRAINED: Pretrained,
}
279
+
280
def __init__(
    self,
    api: Api = None,
    team_id: int = None,
    modes: List[Literal["connect", "pretrained", "custom"]] = None,
    widget_id: str = None,
):
    """Create the deploy widget.

    :param api: API client; a new ``Api()`` is created when omitted.
    :param team_id: Team whose agents, sessions and experiments are used;
        defaults to ``env.team_id()``.
    :param modes: Subset of ``"connect"``, ``"pretrained"`` and ``"custom"`` to show,
        in tab order; defaults to all modes.
    :param widget_id: Passed through to ``Widget``.
    :raises ValueError: If ``modes`` fails validation.
    """
    self.modes: Dict[str, DeployModel.DeployMode] = {}
    requested_modes = self.MODES.copy() if modes is None else modes
    self._validate_modes(requested_modes)
    self.api = Api() if api is None else api
    self.team_id = env.team_id() if team_id is None else team_id
    # Shared by this widget and every mode it creates (each mode reuses this dict).
    self._cache = {}

    tab_texts = {
        self.MODE.CONNECT: ("Connect", "Connect to an already deployed model"),
        self.MODE.PRETRAINED: ("Pretrained", "Deploy a pretrained model from the ecosystem"),
        self.MODE.CUSTOM: ("Custom", "Deploy a custom model from your experiments"),
    }
    self.modes_labels = {key: texts[0] for key, texts in tab_texts.items()}
    self.modes_descriptions = {key: texts[1] for key, texts in tab_texts.items()}

    # GUI: _init_gui assigns the real layout.
    self.layout: Widget = None
    self._init_gui(requested_modes)

    self.model_api: ModelAPI = None

    super().__init__(widget_id=widget_id)
317
+
318
def _validate_modes(self, modes) -> None:
    """Check that ``modes`` is a non-empty selection of distinct known modes.

    :param modes: Sequence of mode identifiers, each one of ``self.MODES``.
    :raises ValueError: If ``modes`` is empty or longer than ``self.MODES``,
        contains an unknown identifier, or lists the same mode twice.
    """
    valid_modes = ", ".join(self.MODES)
    if len(modes) < 1 or len(modes) > len(self.MODES):
        raise ValueError(
            f"Modes must be a list containing 1 to {len(self.MODES)} of the following: {valid_modes}."
        )
    seen = set()
    for mode in modes:
        if mode not in self.MODES:
            raise ValueError(f"Invalid mode '{mode}'. Valid modes are {valid_modes}.")
        # Modes are stored in a dict keyed by name, so a repeat would be silently dropped.
        if mode in seen:
            raise ValueError(f"Duplicate mode '{mode}'. Each mode may be listed only once.")
        seen.add(mode)
326
+
327
def get_modules(self) -> List[ModuleInfo]:
    """Return image serving apps from the ecosystem that declare a framework.

    Queries ecosystem modules that have both the ``serve`` and ``images``
    categories, and keeps those with at least one ``framework:<name>`` category.
    The result is cached in ``self._cache["modules"]``. An empty result is not
    treated as cached, so the next call queries the API again.
    """
    cached_modules = self._cache.get("modules")
    if cached_modules:
        return cached_modules
    raw_modules = self.api.app.get_list_ecosystem_modules(
        categories=["serve", "images"], categories_operation="and"
    )
    modules = [
        ModuleInfo.from_json(module)
        for module in raw_modules
        # .get() skips a module without config/categories instead of aborting
        # the whole listing with a KeyError.
        if any(
            cat.startswith("framework:")
            for cat in (module.get("config") or {}).get("categories", [])
        )
    ]
    self._cache["modules"] = modules
    return modules
342
+
343
def get_frameworks(self) -> List[str]:
    """Return the unique framework names declared by the serving apps.

    Names are taken from the ``framework:<name>`` categories of the modules
    returned by ``get_modules()``. They are de-duplicated in first-seen order,
    because several apps may serve the same framework and callers query
    experiments once per name. The list is cached in ``self._cache["frameworks"]``.
    """
    cached_frameworks = self._cache.get("frameworks")
    if cached_frameworks:
        return cached_frameworks

    prefix = "framework:"
    names = (
        cat[len(prefix) :]
        for module in self.get_modules()
        for cat in module.config.get("categories", [])
        if cat.startswith(prefix)
    )
    frameworks = list(dict.fromkeys(names))
    self._cache["frameworks"] = frameworks
    return frameworks
356
+
357
def _init_modes(self, modes: List[str]) -> None:
    """Instantiate the DeployMode subclass for every requested mode.

    :param modes: Mode identifiers (keys of ``self.MODE_TO_CLASS``), in tab order.

    Instances are stored in ``self.modes`` keyed by identifier, in the given
    order. Each receives this widget as its ``deploy_model``.
    """
    for mode in modes:
        self.modes[mode] = self.MODE_TO_CLASS[mode](self)
360
+
361
def _create_task_link(self, task_id: int) -> str:
    """Build the URL of the app-session page for ``task_id`` on the current server."""
    base_url = self.api.server_address
    return "{}/apps/sessions/{}".format(base_url, task_id)
363
+
364
def _get_inference_settings_by_module(self, module: Dict) -> str:
    """Download and return the inference-settings file of a serving app.

    :param module: Ecosystem module dict. ``config.files.inference_settings``
        must give the path of the settings file in the app repository, and
        ``id`` is passed to the download call.
    :returns: The file contents as text.
    :raises ValueError: If the module declares no inference settings file.
    """
    config = module.get("config") or {}
    inference_settings_path = config.get("files", {}).get("inference_settings", None)
    if inference_settings_path is None:
        # NOTE(review): module dicts may not carry meta.app.name (TODO confirm
        # schema); fall back to other identifiers so the intended ValueError is
        # raised rather than a KeyError.
        meta_app = (module.get("meta") or {}).get("app") or {}
        app_name = meta_app.get("name") or module.get("name") or module.get("id")
        raise ValueError(f"No inference settings file found for framework app {app_name}.")
    # A private temporary directory avoids the race of tempfile.mktemp and is
    # removed (with the downloaded file) when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = str(Path(tmp_dir) / "inference_settings.yaml")
        self.api.app.download_git_file(
            module_id=module["id"],
            file_path=inference_settings_path,
            save_path=save_path,
        )
        inference_settings = Path(save_path).read_text()
    return inference_settings
379
+
380
def _get_inference_settings_for_framework(self, framework: str) -> str:
    """Return the inference-settings file of the serving app for ``framework``.

    The text is downloaded once per framework and cached in
    ``self._cache["inference_settings"]``.

    :raises ValueError: If no serving app exists for ``framework`` or the app
        declares no inference settings file.
    """
    inference_settings_cache = self._cache.setdefault("inference_settings", {})
    if framework in inference_settings_cache:
        return inference_settings_cache[framework]

    module = self.api.nn._deploy_api.find_serving_app_by_framework(framework)
    if module is None:
        raise ValueError(f"No serving app found for framework {framework}.")
    config = module["config"]
    inference_settings_path = config.get("files", {}).get("inference_settings", None)
    if inference_settings_path is None:
        raise ValueError(f"No inference settings file found for framework {framework}.")
    # A private temporary directory avoids the race of tempfile.mktemp and is
    # removed (with the downloaded file) when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = str(Path(tmp_dir) / "inference_settings.yaml")
        self.api.app.download_git_file(
            module_id=module["id"],
            file_path=inference_settings_path,
            save_path=save_path,
        )
        inference_settings = Path(save_path).read_text()
    inference_settings_cache[framework] = inference_settings
    return inference_settings
399
+
400
def _framework_from_task_info(self, task_info: Dict) -> str:
    """Return the framework name of the app that runs a task.

    :param task_info: Task info dict; ``meta.app.moduleId`` identifies the app.
    :returns: ``<name>`` from the app's first ``framework:<name>`` category, or
        ``"unknown"`` if it has none.
    """
    module_id = task_info["meta"]["app"]["moduleId"]
    module = next((m for m in self.get_modules() if m.id == module_id), None)
    if module is None:
        # Apps not returned by get_modules() are fetched once and kept in a
        # separate cache, so they do not leak into get_modules()/get_frameworks().
        extra_modules = self._cache.setdefault("extra_modules", {})
        module = extra_modules.get(module_id)
        if module is None:
            module = self.api.app.get_ecosystem_module_info(module_id=module_id)
            extra_modules[module_id] = module
    prefix = "framework:"
    for cat in module.config.get("categories", []):
        if cat.startswith(prefix):
            return cat[len(prefix) :]
    return "unknown"
413
+
414
def _init_gui(self, modes: List[str]) -> None:
    """Build every widget of the DeployModel and wire the button callbacks.

    Creates the shared status/session widgets, agent selector, model info
    widget and action buttons. Instantiates one DeployMode per entry in
    ``modes`` and lays each mode out as one tab of ``self.tabs``. When only one
    mode is requested, ``self.layout`` is that mode's content container
    instead of the tabs widget.

    :param modes: Validated mode identifiers, in tab order.
    """
    # Status line for deploy/connect progress and results; hidden until set_model_status shows it.
    self.status = Text("Deploying model...", status="info")
    # Two-line session block (link to the session + start time) filled by set_session_info.
    self.session_text_1 = Text(
        "",
        "text",
    )
    self.session_text_2 = Text(
        "",
        "text",
        font_size=13,
    )
    # NOTE: the attribute name is misspelled ("sesson"); other methods use it as-is.
    self.sesson_link = Container(
        [
            self.session_text_1,
            self.session_text_2,
        ],
        gap=0,
        style="padding-left: 10px",
    )
    self.status.hide()
    self.sesson_link.hide()

    self.select_agent = AgentSelector(self.team_id)
    self.select_agent_field = Field(content=self.select_agent, title="Select Agent")

    # Defined outside this view; presumably creates self._model_info_container used below.
    self._create_model_info_widget()

    self.deploy_button = Button("Deploy", icon="zmdi zmdi-play")
    self.connect_button = Button("Connect", icon="zmdi zmdi-link")
    self.stop_button = Button("Stop", icon="zmdi zmdi-stop", button_type="danger")
    self.stop_button.hide()
    self.disconnect_button = Button("Disconnect", icon="zmdi zmdi-close", button_type="warning")
    self.disconnect_button.hide()
    # The same Stop/Disconnect button instances are placed in both rows below.
    # NOTE(review): confirm the widget framework supports one widget in two containers.
    self.deploy_stop_buttons = Flexbox(
        widgets=[self.deploy_button, self.stop_button, self.disconnect_button],
        gap=10,
    )
    self.connect_stop_buttons = Flexbox(
        widgets=[self.connect_button, self.stop_button, self.disconnect_button],
        gap=10,
    )

    # Must run after connect_button exists: the Connect mode enables/disables it on init.
    self._init_modes(modes)
    _labels = []
    _descriptions = []
    _contents = []
    # Session link + model info, shared by all tabs and hidden until a model is active.
    self.statuses_widgets = Container(
        widgets=[
            self.sesson_link,
            self._model_info_container,
        ],
        gap=20,
    )
    self.statuses_widgets.hide()
    for mode_name, mode in self.modes.items():
        label = self.modes_labels[mode_name]
        description = self.modes_descriptions[mode_name]
        if mode_name == str(self.MODE.CONNECT):
            # Connecting reuses a running session, so no agent selector is shown.
            widgets = [
                mode.layout,
                self.status,
                self.statuses_widgets,
                self.connect_stop_buttons,
            ]
        else:
            widgets = [
                mode.layout,
                self.select_agent_field,
                self.status,
                self.statuses_widgets,
                self.deploy_stop_buttons,
            ]

        content = Container(widgets=widgets, gap=20)
        _labels.append(label)
        _descriptions.append(description)
        _contents.append(content)

    # Tabs are created even for a single mode: disable_modes/enable_modes use them.
    self.tabs = RadioTabs(titles=_labels, descriptions=_descriptions, contents=_contents)
    if len(self.modes) == 1:
        self.layout = _contents[0]
    else:
        self.layout = self.tabs

    @self.deploy_button.click
    def _deploy_button_clicked():
        self._deploy()

    @self.stop_button.click
    def _stop_button_clicked():
        self.stop()

    @self.connect_button.click
    def _connect_button_clicked():
        self._connect()

    @self.disconnect_button.click
    def _disconnect_button_clicked():
        self.disconnect()

    @self.tabs.value_changed
    def _active_tab_changed(tab_name: str):
        self.set_model_message_by_tab(tab_name)
517
+
518
def set_model_status(
    self,
    status: Literal[
        "deployed", "stopped", "deploying", "connecting", "connected", "error", "hide"
    ],
    extra_text: str = None,
) -> None:
    """Update the status line shown under the active mode.

    :param status: Which message to show; ``"hide"`` hides the status line instead.
    :param extra_text: Optional text appended (after a space) to the message,
        e.g. an error description.
    :raises KeyError: If ``status`` is not one of the supported values.
    """
    if status == "hide":
        self.status.hide()
        return
    # Rebuilt on every call, so appending extra_text below never mutates shared state.
    status_args = {
        "deployed": {"text": "Model deployed successfully!", "status": "success"},
        "stopped": {"text": "Model stopped", "status": "info"},
        "deploying": {"text": "Deploying model...", "status": "info"},
        "connecting": {"text": "Connecting to model...", "status": "info"},
        "connected": {"text": "Model connected successfully!", "status": "success"},
        "error": {
            "text": "Error occurred during model deployment.",
            "status": "error",
        },
    }
    args = status_args[status]
    if extra_text:
        args["text"] += f" {extra_text}"
    self.status.set(**args)
    self.status.show()
542
+
543
def set_session_info(self, task_info: Dict):
    """Render the session block for ``task_info``, or hide it when ``task_info`` is None.

    Shows ``<app name>: <task id>`` linking to the session page, and the
    task's ``startedAt`` formatted as ``YYYY-mm-dd HH:MM:SS``.
    """
    if task_info is not None:
        session_id = task_info["id"]
        session_url = self._create_task_link(session_id)
        # A trailing "Z" is rewritten to an explicit offset so fromisoformat accepts it.
        started_at = datetime.datetime.fromisoformat(
            task_info["startedAt"].replace("Z", "+00:00")
        ).strftime("%Y-%m-%d %H:%M:%S")
        app_name = task_info["meta"]["app"]["name"]
        self.session_text_1.text = f"<i class='zmdi zmdi-link' style='color: #7f858e'></i> <a href='{session_url}' target='_blank'>{app_name}: {session_id}</a>"
        self.session_text_2.text = f"<span class='field-description text-muted' style='color: #7f858e'>{started_at} (UTC)</span>"
        self.sesson_link.show()
    else:
        self.sesson_link.hide()
556
+
557
def disable_modes(self) -> None:
    """Lock every mode's widget, its tab and the agent selector."""
    for mode_name in self.modes:
        self.modes[mode_name].layout.disable()
        self.tabs.disable_tab(self.modes_labels[mode_name])
    self.select_agent.disable()
563
+
564
def enable_modes(self) -> None:
    """Unlock the UI: enable each mode's layout and its tab, then the agent selector."""
    for key in self.modes:
        self.modes[key].layout.enable()
        self.tabs.enable_tab(self.modes_labels[key])
    self.select_agent.enable()
+
571
def show_deploy_button(self) -> None:
    """Switch the action buttons to the idle state: hide Stop/Disconnect, show Connect/Deploy."""
    for button in (self.stop_button, self.disconnect_button):
        button.hide()
    for button in (self.connect_button, self.deploy_button):
        button.show()
+
577
def show_stop(self) -> None:
    """Switch the action buttons to the active state: hide Connect/Deploy, show Stop/Disconnect."""
    for button in (self.connect_button, self.deploy_button):
        button.hide()
    for button in (self.stop_button, self.disconnect_button):
        button.show()
+
583
def _connect(self) -> None:
    """Run the active mode's deployment and report it with connect-style statuses.

    Shows "connecting" while ``self.deploy()`` runs with the mode tabs locked.
    On success, the returned API becomes ``self.model_api``. The status widgets,
    session link and model info are shown, and the Stop/Disconnect buttons
    replace Connect/Deploy.

    On any exception, the error is logged and shown in the status banner, and
    the UI is restored to its idle state. This includes hiding the status
    widgets, which may already have been made visible. The stored
    ``model_api`` is also dropped so widget state matches the idle UI.
    """
    self.set_model_status("connecting")
    self.set_session_info(None)
    try:
        self.disable_modes()
        model_api = self.deploy()
        task_info = self.api.task.get_info_by_id(model_api.task_id)
        model_info = model_api.get_info()
        model_name = model_info["model_name"]
        framework = self._framework_from_task_info(task_info)
        logger.info(
            f"Connected to model {framework}: {model_name} with session ID {model_api.task_id}."
        )
        self.model_api = model_api
        self.statuses_widgets.show()
        self.set_model_status("connected")
        self.set_session_info(task_info)
        self.set_model_info(model_api.task_id)
        self.show_stop()
    except Exception as e:
        logger.error(f"Failed to connect to model: {e}", exc_info=True)
        # self.deploy() may already have assigned self.model_api before a later
        # step failed; the UI below reports "not connected", so drop it.
        self.model_api = None
        self.set_model_status("error", str(e))
        self.set_session_info(None)
        self.enable_modes()
        self.reset_model_info()
        self.show_deploy_button()
        # statuses_widgets may have been shown before the failure.
        self.statuses_widgets.hide()
+
610
def _deploy(self) -> None:
    """Deploy a model with the active mode tab and reflect the outcome in the UI.

    The mode tabs and agent selector are locked while deploying. On success,
    the API returned by ``self.deploy()`` becomes ``self.model_api``, the
    session/model info and Stop/Disconnect buttons are shown, and, if a CONNECT
    mode is registered in ``self.modes``, its ``_update_sessions()`` is called.
    Any exception is logged, shown in the status banner, and the UI is returned
    to its idle, deployable state.
    """
    self.set_model_status("deploying")
    self.set_session_info(None)
    try:
        self.disable_modes()
        api_handle = self.deploy()
        session_info = self.api.task.get_info_by_id(api_handle.task_id)
        deployed_name = api_handle.get_info()["model_name"]
        framework_name = self._framework_from_task_info(session_info)
        logger.info(
            f"Model {framework_name}: {deployed_name} deployed with session ID {api_handle.task_id}."
        )
        self.model_api = api_handle
        self.set_model_status("deployed")
        self.set_session_info(session_info)
        self.set_model_info(api_handle.task_id)
        self.show_stop()
        self.statuses_widgets.show()
    except Exception as err:
        logger.error(f"Failed to deploy model: {err}", exc_info=True)
        self.set_model_status("error", str(err))
        self.set_session_info(None)
        self.reset_model_info()
        self.show_deploy_button()
        self.statuses_widgets.hide()
        self.enable_modes()
        return

    # Reached only when deployment succeeded.
    connect_key = str(self.MODE.CONNECT)
    if connect_key in self.modes:
        self.modes[connect_key]._update_sessions()
+
641
def deploy(self) -> "ModelAPI":
    """Deploy (or connect to) a model using the mode of the active tab.

    The active tab label is mapped back to its mode key through
    ``self.modes_labels``. That mode's ``deploy`` is then called with the agent
    selected in ``self.select_agent``. The result is stored in
    ``self.model_api`` and returned.

    :raises ValueError: if the active tab label does not belong to any mode.
        A plain ``for``/``break`` lookup would instead fall through to the last
        mode and deploy the wrong one.
    """
    mode_label = self.tabs.get_active_tab()
    mode = next(
        (key for key, label in self.modes_labels.items() if label == mode_label),
        None,
    )
    if mode is None:
        raise ValueError(f"Unknown deploy mode tab: {mode_label!r}")
    agent_id = self.select_agent.get_value()
    self.model_api = self.modes[mode].deploy(agent_id=agent_id)
    return self.model_api
+
651
def stop(self) -> None:
    """Shut down the attached model session and return the UI to its idle state.

    Does nothing when no model is attached. After ``shutdown()`` the status
    shows "stopped", the mode tabs are unlocked, model info and status widgets
    are hidden, and the CONNECT mode (if registered in ``self.modes``) gets
    ``_update_sessions()``.
    """
    active_api = self.model_api
    if active_api is None:
        return
    logger.info("Stopping model...")
    active_api.shutdown()
    self.model_api = None
    self.set_model_status("stopped")
    self.enable_modes()
    self.reset_model_info()
    self.show_deploy_button()
    self.statuses_widgets.hide()
    # NOTE(review): unlike disconnect(), the session link is not hidden here;
    # confirm that keeping the stopped session visible is intended.
    connect_key = str(self.MODE.CONNECT)
    if connect_key in self.modes:
        self.modes[connect_key]._update_sessions()
+
665
def disconnect(self) -> None:
    """Detach from the current model without shutting it down.

    Only the local ``model_api`` handle is dropped; no ``shutdown()`` is sent.
    The status banner, session link, model info and status widgets are hidden,
    and the deploy controls are shown and unlocked again. No-op when no model
    is attached.
    """
    if self.model_api is not None:
        self.model_api = None
        self.set_model_status("hide")
        self.set_session_info(None)
        self.reset_model_info()
        self.show_deploy_button()
        self.statuses_widgets.hide()
        self.enable_modes()
+
676
def load_from_json(self, data: Dict[str, Any]) -> None:
    """
    Restore the widget selection from a JSON-like dict.

    Activates the tab of ``data["mode"]``, selects ``data["agent_id"]`` when it
    is present and not ``None``, then passes the whole dict to that mode's own
    ``load_from_json``. Empty or falsy ``data`` is ignored.

    :param data: Dictionary with widget data.
    """
    if data:
        mode = data["mode"]
        self.tabs.set_active_tab(self.modes_labels[mode])
        agent = data.get("agent_id")
        if agent is not None:
            self.select_agent.set_value(agent)
        self.modes[mode].load_from_json(data)
+
691
def get_deploy_parameters(self) -> Dict[str, Any]:
    """Collect the parameters describing the current deployment choice.

    Returns ``{"mode": <active mode key>, "agent_id": <selected agent>}``
    merged with the active mode's own ``get_deploy_parameters()``. Keys from
    the mode take precedence on conflict. This is the shape that
    ``load_from_json`` reads back (``mode`` and optional ``agent_id``).

    :raises ValueError: if the active tab label does not belong to any mode.
        A plain ``for``/``break`` lookup would instead fall through to the last
        mode and report the wrong one.
    """
    mode_label = self.tabs.get_active_tab()
    mode = next(
        (key for key, label in self.modes_labels.items() if label == mode_label),
        None,
    )
    if mode is None:
        raise ValueError(f"Unknown deploy mode tab: {mode_label!r}")
    agent_id = self.select_agent.get_value()
    parameters = {"mode": mode, "agent_id": agent_id}
    parameters.update(self.modes[mode].get_deploy_parameters())
    return parameters
+
702
def get_json_data(self) -> Dict[str, Any]:
    """Return the widget's JSON data; this widget keeps none, so it is always empty."""
    return dict()
+
705
def get_json_state(self) -> Dict[str, Any]:
    """Return the widget's JSON state; this widget keeps none, so it is always empty."""
    return dict()
+
708
def to_html(self):
    """Render the widget by delegating to its root ``layout`` widget."""
    root = self.layout
    return root.to_html()
+
711
+ # Model Info
712
def _create_model_info_widget(self):
    """Build the "Model Info" panel.

    The panel is a ``ModelInfo`` widget wrapped in a titled ``Field``, plus a
    placeholder ``Text``, grouped in a gap-less ``Container``. Both the field
    and the container start hidden; the placeholder stays visible inside the
    container.
    """
    info_widget = ModelInfo()
    info_field = Field(
        info_widget,
        title="Model Info",
        description="Information about the deployed model",
    )
    placeholder = Text("Connect to model to see the session information.")
    panel = Container([info_field, placeholder], gap=0)

    self._model_info_widget = info_widget
    self._model_info_widget_field = info_field
    self._model_info_message = placeholder
    self._model_info_container = panel

    info_field.hide()
    panel.hide()
+
727
def set_model_info(self, session_id):
    """Point the model-info widget at ``session_id`` and reveal the info panel.

    The placeholder message is hidden; the info field and its container are shown.
    """
    self._model_info_widget.set_session_id(session_id)
    self._model_info_message.hide()
    for widget in (self._model_info_widget_field, self._model_info_container):
        widget.show()
+
734
def reset_model_info(self):
    """Hide the info panel and its field, and bring back the placeholder message."""
    for widget in (self._model_info_container, self._model_info_widget_field):
        widget.hide()
    self._model_info_message.show()
+
739
def set_model_message_by_tab(self, tab_name: str):
    """Update the placeholder message for the newly selected mode tab.

    The CONNECT tab asks the user to connect; every other tab asks the user to
    deploy. The model-info field is hidden in both cases.
    """
    connect_label = self.modes_labels[str(self.MODE.CONNECT)]
    message = (
        "Connect to model to see the session information."
        if tab_name == connect_label
        else "Deploy model to see the session information."
    )
    self._model_info_message.set(message, status="text")
    self._model_info_widget_field.hide()
+
750
+ # ------------------------------------------------------------ #