supervisely 6.73.357__py3-none-any.whl → 6.73.359__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/_utils.py +12 -0
- supervisely/api/annotation_api.py +3 -0
- supervisely/api/api.py +2 -2
- supervisely/api/app_api.py +27 -2
- supervisely/api/entity_annotation/tag_api.py +0 -1
- supervisely/api/nn/__init__.py +0 -0
- supervisely/api/nn/deploy_api.py +821 -0
- supervisely/api/nn/neural_network_api.py +248 -0
- supervisely/api/task_api.py +26 -467
- supervisely/app/fastapi/subapp.py +1 -0
- supervisely/nn/__init__.py +2 -1
- supervisely/nn/artifacts/artifacts.py +5 -5
- supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
- supervisely/nn/experiments.py +28 -5
- supervisely/nn/inference/cache.py +178 -114
- supervisely/nn/inference/gui/gui.py +18 -35
- supervisely/nn/inference/gui/serving_gui.py +3 -1
- supervisely/nn/inference/inference.py +1421 -1265
- supervisely/nn/inference/inference_request.py +412 -0
- supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
- supervisely/nn/inference/session.py +2 -2
- supervisely/nn/inference/tracking/base_tracking.py +45 -79
- supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
- supervisely/nn/inference/tracking/mask_tracking.py +274 -250
- supervisely/nn/inference/tracking/tracker_interface.py +23 -0
- supervisely/nn/inference/uploader.py +164 -0
- supervisely/nn/model/__init__.py +0 -0
- supervisely/nn/model/model_api.py +259 -0
- supervisely/nn/model/prediction.py +311 -0
- supervisely/nn/model/prediction_session.py +632 -0
- supervisely/nn/tracking/__init__.py +1 -0
- supervisely/nn/tracking/boxmot.py +114 -0
- supervisely/nn/tracking/tracking.py +24 -0
- supervisely/nn/training/train_app.py +61 -19
- supervisely/nn/utils.py +43 -3
- supervisely/task/progress.py +12 -2
- supervisely/video/video.py +107 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/METADATA +2 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/RECORD +43 -32
- supervisely/api/neural_network_api.py +0 -202
- {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/LICENSE +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/WHEEL +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,821 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from dataclasses import asdict
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
import supervisely.io.env as env
|
|
9
|
+
from supervisely._utils import get_valid_kwargs
|
|
10
|
+
from supervisely.api.api import Api
|
|
11
|
+
from supervisely.io.fs import get_file_name_with_ext
|
|
12
|
+
from supervisely.nn.experiments import ExperimentInfo
|
|
13
|
+
from supervisely.sly_logger import logger
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_runtime(runtime: Optional[str]):
    """
    Resolve a runtime name or alias to a ``RuntimeType`` value.

    Canonical ``RuntimeType`` strings are matched as-is; aliases such as
    ``"torch"``, ``"onnx"``, ``"trt"`` or ``"engine"`` are matched case-insensitively.

    :param runtime: Runtime name or alias. ``None`` means "choose automatically".
    :type runtime: Optional[str]
    :return: The matching ``RuntimeType`` value, or ``None`` when ``runtime`` is ``None``.
    :raises ValueError: if ``runtime`` is not a known runtime name or alias.
    """
    # Local import; presumably keeps module import light / avoids an import cycle — confirm.
    from supervisely.nn.utils import RuntimeType

    if runtime is None:
        return None
    aliases = {
        str(RuntimeType.PYTORCH): RuntimeType.PYTORCH,
        str(RuntimeType.ONNXRUNTIME): RuntimeType.ONNXRUNTIME,
        str(RuntimeType.TENSORRT): RuntimeType.TENSORRT,
        "pytorch": RuntimeType.PYTORCH,
        "torch": RuntimeType.PYTORCH,
        "pt": RuntimeType.PYTORCH,
        "onnxruntime": RuntimeType.ONNXRUNTIME,
        "onnx": RuntimeType.ONNXRUNTIME,
        "tensorrt": RuntimeType.TENSORRT,
        "trt": RuntimeType.TENSORRT,
        "engine": RuntimeType.TENSORRT,
    }
    if runtime in aliases:
        return aliases[runtime]
    # Resolve into a separate name so the error message can still show what the caller passed
    # (previously the input was overwritten with None before being reported).
    resolved = aliases.get(str(runtime).lower(), None)
    if resolved is None:
        raise ValueError(
            f"Runtime '{runtime}' is not supported. Supported runtimes are: {', '.join(aliases.keys())}"
        )
    return resolved
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DeployApi:
|
|
45
|
+
""" """
|
|
46
|
+
|
|
47
|
+
def __init__(self, api: "Api"):
    """
    Bind this deploy helper to an API client.

    :param api: Supervisely ``Api`` instance used for every request issued by this object.
    :type api: Api
    """
    self._api = api
|
|
49
|
+
|
|
50
|
+
def load_pretrained_model(
    self,
    session_id: int,
    model_name: str,
    device: Optional[str] = None,
    runtime: str = None,
):
    """
    Ask an already running serving App to load one of its pretrained models.

    :param session_id: Task ID of the serving App.
    :type session_id: int
    :param model_name: Name of the pretrained model to load.
    :type model_name: str
    :param device: Device string. If not provided, will be chosen automatically.
    :type device: Optional[str]
    :param runtime: Runtime name or alias (see ``get_runtime``). If not provided, will be defined automatically.
    :type runtime: Optional[str]
    :raises ValueError: if ``runtime`` is not a supported runtime.
    """
    from supervisely.nn.utils import ModelSource

    resolved_runtime = get_runtime(runtime)
    params = {
        "model_source": ModelSource.PRETRAINED,
        "device": device,
        "runtime": resolved_runtime,
    }
    self._load_model_from_api(session_id, params, model_name=model_name)
|
|
77
|
+
|
|
78
|
+
def load_custom_model(
    self,
    session_id: int,
    team_id: int,
    artifacts_dir: str,
    checkpoint_name: Optional[str] = None,
    device: Optional[str] = None,
    runtime: str = None,
):
    """
    Ask an already running serving App to load a custom (trained) model.

    Artifacts stored under ``/experiments`` are treated as Train V2 output;
    anything else is treated as Train V1 output.

    :param session_id: Task ID of the serving App.
    :type session_id: int
    :param team_id: Team ID in Supervisely.
    :type team_id: int
    :param artifacts_dir: Path to the artifacts directory in Team Files.
    :type artifacts_dir: str
    :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
        If not provided, checkpoint will be chosen automatically, depending on the app version.
    :type checkpoint_name: Optional[str]
    :param device: Device string. If not provided, will be chosen automatically.
    :type device: Optional[str]
    :param runtime: Runtime name or alias. If not provided, will be defined automatically.
    :type runtime: Optional[str]
    """
    from supervisely.nn.utils import ModelSource

    runtime = get_runtime(runtime)

    if artifacts_dir.startswith("/experiments"):
        # Train V2 logic (artifacts_dir starts with '/experiments')
        logger.debug("Deploying model from Train V2 artifacts")
        build_deploy_params = self._deploy_params_v2
    else:
        # Train V1 logic (any other artifacts_dir)
        logger.debug("Deploying model from Train V1 artifacts")
        build_deploy_params = self._deploy_params_v1

    _module_id, _app_name, deploy_params = build_deploy_params(
        team_id, artifacts_dir, checkpoint_name, device, runtime, with_module=False
    )
    deploy_params["model_source"] = ModelSource.CUSTOM
    self._load_model_from_api(session_id, deploy_params)
|
|
122
|
+
|
|
123
|
+
def load_custom_model_from_experiment_info(
    self,
    session_id: int,
    experiment_info: "ExperimentInfo",
    checkpoint_name: Optional[str] = None,
    device: Optional[str] = None,
    runtime: str = None,
):
    """
    Ask a running serving App to load a custom model described by an experiment.

    The checkpoint is taken from ``<artifacts_dir>/checkpoints/<checkpoint_name>`` and
    the config from ``<artifacts_dir>/<model_files["config"]>``.

    :param session_id: Task ID of the serving App.
    :type session_id: int
    :param experiment_info: an :class:`ExperimentInfo` object.
    :type experiment_info: ExperimentInfo
    :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
        If not provided, ``experiment_info.best_checkpoint`` is used.
    :type checkpoint_name: Optional[str]
    :param device: Device string. If not provided, will be chosen automatically.
    :type device: Optional[str]
    :param runtime: Runtime name or alias. If not provided, will be defined automatically.
    :type runtime: Optional[str]
    """
    from supervisely.nn.utils import ModelSource

    runtime = get_runtime(runtime)
    chosen_checkpoint = (
        experiment_info.best_checkpoint if checkpoint_name is None else checkpoint_name
    )
    artifacts_root = experiment_info.artifacts_dir
    checkpoint_path = Path(artifacts_root, "checkpoints", chosen_checkpoint).as_posix()
    config_path = Path(artifacts_root, experiment_info.model_files["config"]).as_posix()

    deploy_params = {
        "device": device,
        "model_source": ModelSource.CUSTOM,
        "model_files": {
            "checkpoint": checkpoint_path,
            "config": config_path,
        },
        "model_info": experiment_info.to_json(),
        "runtime": runtime,
    }
    self._load_model_from_api(session_id, deploy_params)
|
|
166
|
+
|
|
167
|
+
def _find_agent(self, team_id: int = None, public=True, gpu=True):
    """
    Find an available agent, preferring the one with the most available GPU memory.

    Regular (``"sly_agent"``) agents are ranked by the available memory of their
    first GPU device; agents that report no GPU memory info are treated as having 0.
    If there are no regular agents, the first Kubernetes agent is returned.

    :param team_id: Team ID. If not provided, will be taken from the current context.
    :type team_id: Optional[int]
    :param public: If True, can find a public agent.
    :type public: bool
    :param gpu: If True, find an agent with GPU.
    :type gpu: bool
    :return: Agent ID
    :rtype: int
    :raises ValueError: if no available agents are found, or none of them is a
        ``"sly_agent"`` or ``"kubernetes"`` agent.
    """
    if team_id is None:
        team_id = env.team_id()
    agents = self._api.agent.get_list_available(team_id, show_public=public, has_gpu=gpu)
    if not agents:
        raise ValueError("No available agents found.")

    agent_id_memory_map = {}
    kubernetes_agents = []
    for agent in agents:
        if agent.type == "sly_agent":
            # No multi-gpu support, always take the first device.
            try:
                available = agent.gpu_info["device_memory"][0]["available"]
            except (KeyError, IndexError, TypeError):
                # Missing/empty GPU info (presumably CPU-only agents, e.g. when gpu=False)
                # used to crash the whole search; keep such agents eligible instead.
                available = 0
            agent_id_memory_map[agent.id] = available
        elif agent.type == "kubernetes":
            kubernetes_agents.append(agent.id)

    if agent_id_memory_map:
        return max(agent_id_memory_map, key=agent_id_memory_map.get)
    if kubernetes_agents:
        return kubernetes_agents[0]
    # Previously fell through and returned None, which only failed later in task.start.
    raise ValueError(
        "No available agents of a supported type ('sly_agent' or 'kubernetes') found."
    )
|
|
197
|
+
|
|
198
|
+
def deploy_pretrained_model(
    self,
    framework: Union[str, int],
    model_name: str,
    device: Optional[str] = None,
    runtime: str = None,
    workspace_id: int = None,
    agent_id: Optional[int] = None,
    app: Union[str, int] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Deploy a pretrained model: start a serving app and load the model into it.

    The serving app is resolved in this order: a known Train V1 framework name
    (mapped to its serving app slug), then ``app`` (module ID or app name), then
    the first ecosystem module tagged ``serve`` and ``framework:<framework>``.

    :param framework: Framework name or Framework ID in Supervisely.
    :type framework: Union[str, int]
    :param model_name: Model name to deploy.
    :type model_name: str
    :param device: Device string. If not provided, will be chosen automatically.
    :type device: Optional[str]
    :param runtime: Runtime name or alias. If not provided, will be defined automatically.
    :type runtime: Optional[str]
    :param workspace_id: Workspace ID where the app will be deployed. If not provided, will be taken from the current context.
    :type workspace_id: Optional[int]
    :param agent_id: Agent ID. If not provided, will be found automatically in the workspace's team.
    :type agent_id: Optional[int]
    :param app: App name or App module ID in Supervisely.
    :type app: Union[str, int]
    :param kwargs: Additional parameters to start the task (e.g. ``timeout``). See Api.task.start() for more details.
    :type kwargs: Dict[str, Any]
    :return: Task Info
    :rtype: Dict[str, Any]
    :raises ValueError: if the workspace is not found, no serving app could be resolved
        for ``framework``/``app``, or ``runtime`` is not supported.
    """
    from supervisely.nn.artifacts import (
        RITM,
        RTDETR,
        Detectron2,
        MMClassification,
        MMDetection,
        MMDetection3,
        MMSegmentation,
        UNet,
        YOLOv5,
        YOLOv5v2,
        YOLOv8,
    )

    # The docstring promises a context default; previously None was looked up verbatim.
    if workspace_id is None:
        workspace_id = env.workspace_id()
    workspace_info = self._api.workspace.get_info_by_id(workspace_id)
    if workspace_info is None:
        raise ValueError(f"Workspace with ID {workspace_id} not found")
    team_id = workspace_info.team_id

    # @TODO: Fix debug logs/ Fix code
    # Constructing an artifacts object appears to emit
    # "Skip HTTPS redirect check on API init: ..." debug logs, so build each one only once.
    frameworks_v1 = {}
    for artifacts_cls in (
        RITM,
        RTDETR,
        Detectron2,
        MMClassification,
        MMDetection,
        MMDetection3,
        MMSegmentation,
        UNet,
        YOLOv5,
        YOLOv5v2,
        YOLOv8,
    ):
        artifacts = artifacts_cls(team_id)
        frameworks_v1[artifacts.framework_name] = artifacts.serve_slug

    if framework in frameworks_v1:
        module_id = self.find_serving_app_by_slug(frameworks_v1[framework])
    elif isinstance(app, int):
        module_id = app
    elif isinstance(app, str):
        module_id = self._api.app.find_module_id_by_app_name(app)
    else:
        # find_serving_app_by_framework returns None when nothing matches;
        # indexing it directly raised TypeError instead of the intended ValueError.
        module = self.find_serving_app_by_framework(framework)
        module_id = None if module is None else module["id"]
    if module_id is None:
        raise ValueError(
            f"Serving app for framework '{framework}' not found. Make sure that you used correct framework name."
        )

    runtime = get_runtime(runtime)
    if agent_id is None:
        # Look for agents in the workspace's team, not only the environment's team.
        agent_id = self._find_agent(team_id=team_id)

    task_info = self._run_serve_app(agent_id, module_id, workspace_id=workspace_id, **kwargs)
    self.load_pretrained_model(
        task_info["id"], model_name=model_name, device=device, runtime=runtime
    )
    return task_info
|
|
293
|
+
|
|
294
|
+
def _find_team_by_path(self, path: str, team_id: int = None, raise_not_found=True):
    """
    Find the team whose Team Files contain ``path`` (either a file or a directory).

    Lookup order:
    1. If ``team_id`` is given, only that team is checked.
    2. Otherwise the team from the environment (if set) is checked.
    3. Otherwise every team returned by ``Api.team.get_list()`` is scanned.

    :param path: Path in Team Files (file or directory).
    :type path: str
    :param team_id: Team to check. If not provided, the team is searched for.
    :type team_id: Optional[int]
    :param raise_not_found: If True, raise when the path is not found; otherwise return None.
    :type raise_not_found: bool
    :return: Team ID, or None if not found and ``raise_not_found`` is False.
    :rtype: Optional[int]
    :raises ValueError: if the path is not found (and ``raise_not_found`` is True),
        or if it exists in more than one team during the full scan.
    """

    def _exists_in(candidate_team_id: int) -> bool:
        # Check both files and directories: callers pass checkpoint files as well as
        # artifacts directories. The full scan previously checked files only, so it
        # could never locate an artifacts directory.
        return self._api.file.exists(candidate_team_id, path) or self._api.file.dir_exists(
            candidate_team_id, path, recursive=False
        )

    if team_id is not None:
        if _exists_in(team_id):
            return team_id
        if raise_not_found:
            raise ValueError(f"Checkpoint '{path}' not found in the provided team")
        return None

    env_team_id = env.team_id(raise_not_found=False)
    if env_team_id is not None and _exists_in(env_team_id):
        return env_team_id

    found_team_id = None
    for team in self._api.team.get_list():
        if _exists_in(team.id):
            if found_team_id is not None:
                raise ValueError("Multiple teams have the same checkpoint")
            found_team_id = team.id
    if found_team_id is None and raise_not_found:
        raise ValueError("Checkpoint not found")
    return found_team_id
|
|
323
|
+
|
|
324
|
+
def deploy_custom_model_by_checkpoint(
    self,
    checkpoint: str,
    device: Optional[str] = None,
    runtime: str = None,
    timeout: int = 100,
    team_id: int = None,
    workspace_id: int = None,
    agent_id: int = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Deploy a custom model given the path to one of its checkpoints.

    The checkpoint path is split into its artifacts directory and checkpoint name,
    and the work is delegated to :meth:`deploy_custom_model_by_artifacts_dir`.

    :param checkpoint: Path to the checkpoint in Team Files.
    :type checkpoint: str
    :param device: Device string. If not provided, will be chosen automatically.
    :type device: Optional[str]
    :param runtime: Runtime name or alias. If not provided, will be defined automatically.
    :type runtime: Optional[str]
    :param timeout: Timeout in seconds (default is 100). The maximum time to wait for the serving app to be ready.
    :type timeout: Optional[int]
    :param team_id: Team ID where the artifacts are stored. If not provided, will be taken from the current context.
    :type team_id: Optional[int]
    :param workspace_id: Workspace ID where the app will be deployed. If not provided, will be taken from the current context.
    :type workspace_id: Optional[int]
    :param agent_id: Agent ID. If not provided, will be found automatically.
    :type agent_id: Optional[int]
    :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
    :type kwargs: Dict[str, Any]
    :return: Task Info
    :rtype: Dict[str, Any]
    :raises ValueError: if validations fail.
    """
    artifacts_dir, checkpoint_name = self._get_artifacts_dir_and_checkpoint_name(checkpoint)
    forwarded = dict(
        device=device,
        runtime=runtime,
        timeout=timeout,
        team_id=team_id,
        workspace_id=workspace_id,
        agent_id=agent_id,
    )
    return self.deploy_custom_model_by_artifacts_dir(
        artifacts_dir=artifacts_dir,
        checkpoint_name=checkpoint_name,
        **forwarded,
        **kwargs,
    )
|
|
370
|
+
|
|
371
|
+
def deploy_custom_model_by_artifacts_dir(
    self,
    artifacts_dir: str,
    checkpoint_name: Optional[str] = None,
    device: Optional[str] = None,
    runtime: str = None,
    timeout: int = 100,
    team_id: int = None,
    workspace_id: int = None,
    agent_id: int = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Deploy a custom model based on the artifacts directory.

    Artifacts under ``/experiments`` are treated as Train V2 output; anything else
    is treated as Train V1 output.

    :param artifacts_dir: Path to the artifacts directory in Team Files.
    :type artifacts_dir: str
    :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
        If not provided, checkpoint will be chosen automatically, depending on the app version.
    :type checkpoint_name: Optional[str]
    :param device: Device string. If not provided, will be chosen automatically.
    :type device: Optional[str]
    :param runtime: Runtime name or alias. If not provided, will be defined automatically.
    :type runtime: Optional[str]
    :param timeout: Timeout in seconds (default is 100). The maximum time to wait for the serving app to be ready.
    :type timeout: Optional[int]
    :param team_id: Team ID where the artifacts are stored. If not provided, the team is
        searched for by ``artifacts_dir``.
    :type team_id: Optional[int]
    :param workspace_id: Workspace ID where the app will be deployed. If not provided, will be taken from the current context.
    :type workspace_id: Optional[int]
    :param agent_id: Agent ID. If not provided, will be found automatically in ``team_id``.
    :type agent_id: Optional[int]
    :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
    :type kwargs: Dict[str, Any]
    :return: Task Info
    :rtype: Dict[str, Any]
    :raises ValueError: if validations fail.
    :raises RuntimeError: if starting the serving app or loading the model fails.
    """
    from supervisely.nn.utils import ModelSource

    if not isinstance(artifacts_dir, str) or not artifacts_dir.strip():
        raise ValueError("artifacts_dir must be a non-empty string.")

    runtime = get_runtime(runtime)
    if team_id is None:
        team_id = self._find_team_by_path(artifacts_dir)
    logger.debug(
        f"Starting custom model deployment. Team: {team_id}, Artifacts Dir: '{artifacts_dir}'"
    )
    if agent_id is None:
        # Use the resolved team instead of relying on the environment's team only.
        agent_id = self._find_agent(team_id=team_id)

    # Train V1 logic (if artifacts_dir does not start with '/experiments')
    if not artifacts_dir.startswith("/experiments"):
        logger.debug("Deploying model from Train V1 artifacts")
        module_id, serve_app_name, deploy_params = self._deploy_params_v1(
            team_id, artifacts_dir, checkpoint_name, device, runtime, with_module=True
        )
    else:  # Train V2 logic (when artifacts_dir starts with '/experiments')
        logger.debug("Deploying model from Train V2 artifacts")
        module_id, serve_app_name, deploy_params = self._deploy_params_v2(
            team_id, artifacts_dir, checkpoint_name, device, runtime, with_module=True
        )
    deploy_params["model_source"] = ModelSource.CUSTOM

    logger.info(
        f"{serve_app_name} app deployment started. Checkpoint: '{checkpoint_name}'. Deploy params: '{deploy_params}'"
    )
    try:
        # `timeout` is an explicit parameter here, so it never reaches **kwargs;
        # it must be forwarded explicitly or the documented timeout is ignored.
        task_info = self._run_serve_app(
            agent_id, module_id, workspace_id=workspace_id, timeout=timeout, **kwargs
        )
        self._load_model_from_api(task_info["id"], deploy_params)
    except Exception as e:
        raise RuntimeError(f"Failed to run '{serve_app_name}': {e}") from e
    return task_info
|
|
448
|
+
|
|
449
|
+
def deploy_custom_model_from_experiment_info(
    self,
    agent_id: int,
    experiment_info: "ExperimentInfo",
    checkpoint_name: Optional[str] = None,
    device: Optional[str] = None,
    runtime: str = None,
    timeout: int = 100,
    **kwargs,
) -> Dict[str, Any]:
    """
    Deploy a custom model based on the training session.

    The serving app is derived from the framework category of the training app
    that produced ``experiment_info``.

    :param agent_id: Agent ID to run the serving app on.
    :type agent_id: int
    :param experiment_info: an :class:`ExperimentInfo` object.
    :type experiment_info: ExperimentInfo
    :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
        If not provided, the best checkpoint will be chosen.
    :type checkpoint_name: Optional[str]
    :param device: Device string. If not provided, will be chosen automatically.
    :type device: Optional[str]
    :param runtime: Runtime name or alias. If not provided, will be defined automatically.
    :type runtime: Optional[str]
    :param timeout: Timeout in seconds (default is 100). The maximum time to wait for the serving app to be ready.
    :type timeout: Optional[int]
    :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
        ``task_name`` and ``description`` default to values derived from the experiment.
    :type kwargs: Dict[str, Any]
    :return: Task Info
    :rtype: Dict[str, Any]
    :raises ValueError: if validations fail (e.g. the training task is not found).
    :raises RuntimeError: if starting the serving app or loading the model fails.
    """
    task_id = experiment_info.task_id
    train_task_info = self._api.task.get_info_by_id(task_id)
    if train_task_info is None:
        # Previously surfaced as an opaque "NoneType is not subscriptable" TypeError below.
        raise ValueError(f"Training task with ID '{task_id}' not found")
    runtime = get_runtime(runtime)

    logger.debug(f"Starting model deployment from experiment info. Task ID: '{task_id}'")
    train_module_id = train_task_info["meta"]["app"]["moduleId"]
    module = self.get_serving_app_by_train_app(module_id=train_module_id)
    serve_app_name = module["name"]
    module_id = module["id"]
    logger.debug(f"Serving app detected: '{serve_app_name}'. Module ID: '{module_id}'")

    if checkpoint_name is None:
        checkpoint_name = experiment_info.best_checkpoint

    # Task parameters
    experiment_name = experiment_info.experiment_name
    task_name = experiment_name + f" ({checkpoint_name})"
    if "task_name" not in kwargs:
        kwargs["task_name"] = task_name

    description = (
        "Serve from experiment\n"
        f"Experiment name: {experiment_name}\n"
        f"Evaluation report: {experiment_info.evaluation_report_link}\n"
    )
    # Drop trailing lines until the description fits 255 characters (presumably a
    # server-side limit); the final slice guards against a single over-long line.
    while len(description) > 255 and "\n" in description:
        description = description.rsplit("\n", 1)[0]
    description = description[:255]
    if "description" not in kwargs:
        kwargs["description"] = description

    logger.info(f"{serve_app_name} app deployment started. Checkpoint: '{checkpoint_name}'.")
    try:
        task_info = self._run_serve_app(agent_id, module_id, timeout=timeout, **kwargs)
        self.load_custom_model_from_experiment_info(
            task_info["id"], experiment_info, checkpoint_name, device, runtime
        )
    except Exception as e:
        raise RuntimeError(f"Failed to run '{serve_app_name}': {e}") from e
    return task_info
|
|
515
|
+
|
|
516
|
+
def start_serve_app(
    self, agent_id: int, app_name=None, module_id=None, **kwargs
) -> Dict[str, Any]:
    """
    Run a serving app. Exactly one of ``app_name`` or ``module_id`` must be provided.

    :param agent_id: Agent ID to run the app on.
    :type agent_id: int
    :param app_name: App name in Supervisely.
    :type app_name: Optional[str]
    :param module_id: Module ID in Supervisely.
    :type module_id: Optional[int]
    :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
    :type kwargs: Dict[str, Any]
    :return: Task Info of the started serving task.
    :rtype: Dict[str, Any]
    :raises ValueError: if neither or both of ``app_name`` and ``module_id`` are provided.
    :raises TimeoutError: if the started task is not ready for API calls in time.
    """
    if app_name is None and module_id is None:
        raise ValueError("Either app_name or module_id must be provided.")
    if app_name is not None and module_id is not None:
        raise ValueError("Only one of app_name or module_id must be provided.")
    if module_id is None:
        module_id = self._api.app.find_module_id_by_app_name(app_name)
    # Return the task info as annotated/documented (it was previously discarded).
    return self._run_serve_app(agent_id, module_id, **kwargs)
|
|
538
|
+
|
|
539
|
+
def _run_serve_app(
    self, agent_id: int, module_id, workspace_id: int = None, timeout: int = 100, **kwargs
):
    """
    Start a serving task and block until it is ready for API calls.

    Readiness is polled once per second, ``timeout`` times in total.

    :param agent_id: Agent ID to run the task on.
    :param module_id: Ecosystem module ID of the serving app.
    :param workspace_id: Workspace ID. If not provided, will be taken from the current context.
    :param timeout: Maximum time in seconds to wait for the task to become ready.
    :param kwargs: Extra arguments for ``Api.task.start``; presumably anything it does not
        accept is dropped by ``get_valid_kwargs`` — confirm.
    :return: Task Info returned by ``Api.task.start``.
    :raises TimeoutError: if the task is not ready for API calls after ``timeout`` seconds.
    """
    delay_sec = 1
    max_attempts = timeout // delay_sec

    # @TODO: Run app in team?
    workspace_id = env.workspace_id() if workspace_id is None else workspace_id
    start_kwargs = get_valid_kwargs(
        kwargs=kwargs,
        func=self._api.task.start,
        exclude=["self", "module_id", "workspace_id", "agent_id"],
    )
    task_info = self._api.task.start(
        agent_id=agent_id,
        module_id=module_id,
        workspace_id=workspace_id,
        **start_kwargs,
    )
    task_id = task_info["id"]
    is_ready = self._api.app.wait_until_ready_for_api_calls(task_id, max_attempts, delay_sec)
    if is_ready:
        return task_info
    raise TimeoutError(
        f"Task {task_id} is not ready for API calls after {timeout} seconds."
    )
|
|
564
|
+
|
|
565
|
+
def _load_model_from_api(self, task_id, deploy_params, model_name: Optional[str] = None):
    """
    Send a ``deploy_from_api`` request to a running serving task.

    :param task_id: Task ID of the serving App.
    :param deploy_params: Deploy parameters forwarded to the app as ``deploy_params``.
    :param model_name: Optional model name forwarded to the app as ``model_name``.
    """
    payload = {"deploy_params": deploy_params, "model_name": model_name}
    logger.info("Loading model")
    # raise_error=True so that a failed load propagates to the caller.
    self._api.task.send_request(task_id, "deploy_from_api", data=payload, raise_error=True)
    logger.info("Model loaded successfully")
|
|
574
|
+
|
|
575
|
+
def find_serving_app_by_framework(self, framework: str):
    """
    Return the first ecosystem module tagged with both ``serve`` and ``framework:<framework>``.

    :param framework: Framework name (or ID) used in the ``framework:`` category.
    :return: Module info of the first match, or ``None`` if nothing matches.
    """
    required_categories = ["serve", f"framework:{framework}"]
    candidates = self._api.app.get_list_ecosystem_modules(
        categories=required_categories, categories_operation="and"
    )
    return candidates[0] if len(candidates) > 0 else None
|
|
582
|
+
|
|
583
|
+
def find_serving_app_by_slug(self, slug: str) -> int:
    """
    Return the ecosystem module ID of the serving app identified by ``slug``.

    :param slug: Ecosystem app slug.
    :return: Module ID as returned by ``Api.app.get_ecosystem_module_id``.
    """
    app_api = self._api.app
    return app_api.get_ecosystem_module_id(slug)
|
|
585
|
+
|
|
586
|
+
def get_serving_app_by_train_app(self, app_name: Optional[str] = None, module_id: int = None):
    """
    Find the serving app that matches a training app's framework.

    The framework is read from the first ``framework:<name>`` category (prefix matched
    case-insensitively) of the training app's config.

    :param app_name: Training app name. Takes precedence over ``module_id`` if both are given.
    :type app_name: Optional[str]
    :param module_id: Training app module ID.
    :type module_id: Optional[int]
    :return: Module info of the matching serving app.
    :raises ValueError: if neither argument is given, the train app has no framework
        category, or no serving app exists for the framework.
    """
    if app_name is None and module_id is None:
        raise ValueError("Either app_name or module_id must be provided.")
    if app_name is not None:
        module_id = self._api.app.find_module_id_by_app_name(app_name)
    train_module_info = self._api.app.get_ecosystem_module_info(module_id)
    train_app_config = train_module_info.config
    categories = train_app_config["categories"]

    prefix = "framework:"
    framework = None
    for category in categories:
        if category.lower().startswith(prefix):
            # Slice off the prefix. str.lstrip("framework:") strips a *set of characters*,
            # mangling names like "framework:mmdetection" -> "detection", and it does not
            # remove a differently-cased prefix at all.
            framework = category[len(prefix):]
            break
    if framework is None:
        raise ValueError(
            "Unable to define serving app. Framework is not specified in the train app"
        )

    logger.debug(f"Detected framework: {framework}")
    module = self.find_serving_app_by_framework(framework)
    if module is None:
        raise ValueError(f"No serving apps found for framework {framework}")
    return module
|
|
609
|
+
|
|
610
|
+
def get_deploy_info(self, task_id: int) -> Dict[str, Any]:
    """
    Query a serving task for information about its currently deployed model.

    :param task_id: Task ID of the serving App.
    :type task_id: int
    :return: Deploy Info as returned by the app's ``get_deploy_info`` endpoint.
    :rtype: Dict[str, Any]
    """
    endpoint = "get_deploy_info"
    return self._api.task.send_request(
        task_id,
        endpoint,
        data={},
        raise_error=True,
    )
|
|
620
|
+
|
|
621
|
+
def _deploy_params_v1(
    self,
    team_id: int,
    artifacts_dir: str,
    checkpoint_name: Optional[str],
    device: str,
    runtime: str,
    with_module: bool = True,
) -> Tuple[Optional[int], Optional[str], Dict[str, Any]]:
    """
    Build deploy parameters for a model trained with a legacy (v1) training app.

    The framework class is detected from ``artifacts_dir`` via ``_get_framework_by_path``.

    :param team_id: Team ID used to instantiate the framework artifacts helper.
    :param artifacts_dir: Path to the training artifacts directory.
    :param checkpoint_name: Name of the checkpoint to deploy. If ``None``, the last
        checkpoint listed in the train info is used.
    :param device: Device to deploy the model on.
    :param runtime: Runtime name; normalized with ``get_runtime``.
    :param with_module: If ``True``, also resolve the serving app module ID and name.
    :return: ``(module_id, serve_app_name, deploy_params)``. The first two are ``None``
        when ``with_module`` is ``False``.
    :raises ValueError: if the framework is unknown or rejected for deployment, or the
        checkpoint cannot be found.
    """
    from supervisely.nn.artifacts import RITM, YOLOv5
    from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
    from supervisely.nn.utils import ModelSource

    framework_cls = self._get_framework_by_path(artifacts_dir)
    if not framework_cls:
        raise ValueError(f"Unsupported framework for artifacts_dir: '{artifacts_dir}'")

    framework: BaseTrainArtifacts = framework_cls(team_id)
    # These frameworks are explicitly rejected for deployment.
    if framework_cls is RITM or framework_cls is YOLOv5:
        raise ValueError(
            f"{framework.framework_name} framework is not supported for deployment"
        )

    runtime = get_runtime(runtime)
    logger.debug(f"Detected framework: '{framework.framework_name}'")
    module_id = None
    serve_app_name = None
    if with_module:
        module_id = self._api.app.get_ecosystem_module_id(framework.serve_slug)
        serve_app_name = framework.serve_app_name
        logger.debug(f"Module ID fetched: '{module_id}'. App name: '{serve_app_name}'")

    train_info = framework.get_info_by_artifacts_dir(artifacts_dir.rstrip("/"))
    if not hasattr(train_info, "checkpoints") or not train_info.checkpoints:
        raise ValueError("No checkpoints found in train info.")

    checkpoint = None
    if checkpoint_name is not None:
        for cp in train_info.checkpoints:
            if cp.name == checkpoint_name:
                checkpoint = cp
                break
        if checkpoint is None:
            raise ValueError(f"Checkpoint '{checkpoint_name}' not found in train info.")
    else:
        logger.info("Checkpoint name not provided. Using the last checkpoint.")
        checkpoint = train_info.checkpoints[-1]

    checkpoint_name = checkpoint.name
    deploy_params = {
        "device": device,
        "model_source": ModelSource.CUSTOM,
        "task_type": train_info.task_type,
        "checkpoint_name": checkpoint_name,
        "checkpoint_url": checkpoint.path,
    }

    if getattr(train_info, "config_path", None) is not None:
        deploy_params["config_url"] = train_info.config_path

    # Only frameworks that declare require_runtime receive the runtime parameter.
    if framework.require_runtime:
        deploy_params["runtime"] = runtime
    return module_id, serve_app_name, deploy_params
|
|
684
|
+
|
|
685
|
+
def _deploy_params_v2(
    self,
    team_id: int,
    artifacts_dir: str,
    checkpoint_name: Optional[str],
    device: str,
    runtime: str,
    with_module: bool = True,
) -> Tuple[Optional[int], Optional[str], Dict[str, Any]]:
    """
    Build deploy parameters for a model trained with a v2 (experiments-based) training app.

    :param team_id: Team ID that owns the experiment.
    :param artifacts_dir: Path to the experiment artifacts directory.
    :param checkpoint_name: File name of the checkpoint to deploy. If ``None``, the
        experiment's best checkpoint is used.
    :param device: Device to deploy the model on.
    :param runtime: Runtime name; normalized with ``get_runtime``.
    :param with_module: If ``True``, also resolve the serving app that matches the
        training app of the experiment.
    :return: ``(module_id, serve_app_name, deploy_params)``. The first two are ``None``
        when ``with_module`` is ``False``.
    :raises ValueError: if the experiment or its task cannot be found, it has no
        checkpoints, or the requested/best checkpoint cannot be determined.
    """
    from supervisely.nn.experiments import get_experiment_info_by_artifacts_dir
    from supervisely.nn.utils import ModelSource

    experiment_info = get_experiment_info_by_artifacts_dir(self._api, team_id, artifacts_dir)
    if not experiment_info:
        raise ValueError(
            f"Failed to retrieve experiment info for artifacts_dir: '{artifacts_dir}'"
        )

    runtime = get_runtime(runtime)
    experiment_task_id = experiment_info.task_id
    experiment_task_info = self._api.task.get_info_by_id(experiment_task_id)
    if experiment_task_info is None:
        raise ValueError(f"Task with ID '{experiment_task_id}' not found")

    module_id = None
    serve_app_name = None
    if with_module:
        train_module_id = experiment_task_info["meta"]["app"]["moduleId"]
        module = self.get_serving_app_by_train_app(module_id=train_module_id)
        serve_app_name = module["name"]
        module_id = module["id"]
        logger.debug(f"Serving app detected: '{serve_app_name}'. Module ID: '{module_id}'")

    if not experiment_info.checkpoints:
        raise ValueError(f"No checkpoints found in: '{artifacts_dir}'.")

    if checkpoint_name is not None:
        # Checkpoints are stored as paths; match the requested name against their file names.
        available = {get_file_name_with_ext(path) for path in experiment_info.checkpoints}
        if checkpoint_name not in available:
            raise ValueError(f"Provided checkpoint '{checkpoint_name}' not found")
    else:
        logger.info("Checkpoint name not provided. Using the best checkpoint.")
        checkpoint_name = experiment_info.best_checkpoint
        if checkpoint_name is None:
            # Without this check the path below would silently end in ".../checkpoints/None".
            raise ValueError(
                f"Best checkpoint is not specified for: '{artifacts_dir}'. "
                "Provide 'checkpoint_name' explicitly."
            )

    model_info_dict = asdict(experiment_info)
    model_info_dict["artifacts_dir"] = artifacts_dir
    deploy_params = {
        "device": device,
        "model_source": ModelSource.CUSTOM,
        "model_files": {
            "checkpoint": f"/{artifacts_dir.strip('/')}/checkpoints/{checkpoint_name}"
        },
        "model_info": model_info_dict,
        "runtime": runtime,
    }

    config = experiment_info.model_files.get("config")
    if config is not None:
        # Join with exactly one "/" so the result does not depend on whether
        # experiment_info.artifacts_dir carries a trailing slash.
        config_path = f"/{experiment_info.artifacts_dir.strip('/')}/{config.lstrip('/')}"
        deploy_params["model_files"]["config"] = config_path
        logger.debug(f"Config file added: {config_path}")
    return module_id, serve_app_name, deploy_params
|
|
751
|
+
|
|
752
|
+
def wait(
    self,
    model_id,
    target: Literal["started", "deployed"] = "started",
    timeout=5 * 60,
    poll_interval: float = 1.0,
):
    """
    Block until a serving task answers its health-check request.

    ``target="started"`` polls the ``is_alive`` endpoint; any other value polls
    ``is_ready``. A poll counts as successful when ``send_request`` (called with
    ``raise_error=True``) returns without raising.

    :param model_id: Task ID of the serving App.
    :param target: Which state to wait for: ``"started"`` or ``"deployed"``.
    :type target: Literal["started", "deployed"]
    :param timeout: Maximum time to wait, in seconds.
    :param poll_interval: Delay between polling attempts, in seconds.
    :type poll_interval: float
    :raises TimeoutError: if no poll succeeded within ``timeout`` seconds.
    """
    method = "is_alive" if target == "started" else "is_ready"
    start = time.monotonic()
    last_error = None
    while time.monotonic() - start < timeout:
        try:
            # NOTE(review): assumes send_request(..., raise_error=True) raises while the
            # endpoint is unavailable/not ready — confirm against the serving app.
            self._api.task.send_request(model_id, method, {}, raise_error=True)
            return
        except Exception as e:  # task may not be reachable yet; keep polling
            last_error = e
            logger.debug(f"Task {model_id} is not {target} yet: {e}")
        time.sleep(poll_interval)
    raise TimeoutError(
        f"Task {model_id} did not become {target} within {timeout} seconds"
    ) from last_error
|
|
758
|
+
|
|
759
|
+
def _get_artifacts_dir_and_checkpoint_name(self, model: str) -> Tuple[str, str]:
|
|
760
|
+
if not model.startswith("/"):
|
|
761
|
+
raise ValueError(f"Path must start with '/'")
|
|
762
|
+
|
|
763
|
+
if model.startswith("/experiments"):
|
|
764
|
+
try:
|
|
765
|
+
artifacts_dir, checkpoint_name = model.split("/checkpoints/")
|
|
766
|
+
return artifacts_dir, checkpoint_name
|
|
767
|
+
except:
|
|
768
|
+
raise ValueError(
|
|
769
|
+
"Bad format of checkpoint path. Expected format: '/artifacts_dir/checkpoints/checkpoint_name'"
|
|
770
|
+
)
|
|
771
|
+
|
|
772
|
+
framework_cls = self._get_framework_by_path(model)
|
|
773
|
+
if framework_cls is None:
|
|
774
|
+
raise ValueError(f"Unknown path: '{model}'")
|
|
775
|
+
|
|
776
|
+
team_id = env.team_id()
|
|
777
|
+
framework = framework_cls(team_id)
|
|
778
|
+
checkpoint_name = get_file_name_with_ext(model)
|
|
779
|
+
checkpoints_dir = model.replace(checkpoint_name, "")
|
|
780
|
+
if framework.weights_folder is not None:
|
|
781
|
+
artifacts_dir = checkpoints_dir.replace(framework.weights_folder, "")
|
|
782
|
+
else:
|
|
783
|
+
artifacts_dir = checkpoints_dir
|
|
784
|
+
return artifacts_dir, checkpoint_name
|
|
785
|
+
|
|
786
|
+
def _get_framework_by_path(self, path: str) -> Optional[type]:
    """
    Detect the legacy training-artifacts class from a Team Files path.

    The lookup key is the path component right after the root, e.g. ``detectron2``
    for ``/detectron2/...``.

    :param path: Absolute path to a checkpoint or artifacts directory.
    :type path: str
    :return: The matching artifacts class, or ``None`` if the directory is unknown.
    :raises ValueError: if the path has fewer than two components.
    """
    from supervisely.nn.artifacts import (
        RITM,
        RTDETR,
        Detectron2,
        MMClassification,
        MMDetection,
        MMDetection3,
        MMSegmentation,
        UNet,
        YOLOv5,
        YOLOv5v2,
        YOLOv8,
    )

    path_obj = Path(path)
    if len(path_obj.parts) < 2:
        raise ValueError(f"Incorrect checkpoint path: '{path}'")
    # For an absolute path parts[0] is the root ("/"), so parts[1] is the top-level directory.
    parent = path_obj.parts[1]
    frameworks = {
        "detectron2": Detectron2,
        "mmclassification": MMClassification,
        "mmdetection": MMDetection,
        "mmdetection-3": MMDetection3,
        "mmsegmentation": MMSegmentation,
        "RITM_training": RITM,
        "RT-DETR": RTDETR,
        "unet": UNet,
        "yolov5_train": YOLOv5,
        "yolov5_2.0_train": YOLOv5v2,
        "yolov8_train": YOLOv8,
    }
    return frameworks.get(parent)
|