supervisely 6.73.357__py3-none-any.whl → 6.73.359__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. supervisely/_utils.py +12 -0
  2. supervisely/api/annotation_api.py +3 -0
  3. supervisely/api/api.py +2 -2
  4. supervisely/api/app_api.py +27 -2
  5. supervisely/api/entity_annotation/tag_api.py +0 -1
  6. supervisely/api/nn/__init__.py +0 -0
  7. supervisely/api/nn/deploy_api.py +821 -0
  8. supervisely/api/nn/neural_network_api.py +248 -0
  9. supervisely/api/task_api.py +26 -467
  10. supervisely/app/fastapi/subapp.py +1 -0
  11. supervisely/nn/__init__.py +2 -1
  12. supervisely/nn/artifacts/artifacts.py +5 -5
  13. supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
  14. supervisely/nn/experiments.py +28 -5
  15. supervisely/nn/inference/cache.py +178 -114
  16. supervisely/nn/inference/gui/gui.py +18 -35
  17. supervisely/nn/inference/gui/serving_gui.py +3 -1
  18. supervisely/nn/inference/inference.py +1421 -1265
  19. supervisely/nn/inference/inference_request.py +412 -0
  20. supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
  21. supervisely/nn/inference/session.py +2 -2
  22. supervisely/nn/inference/tracking/base_tracking.py +45 -79
  23. supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
  24. supervisely/nn/inference/tracking/mask_tracking.py +274 -250
  25. supervisely/nn/inference/tracking/tracker_interface.py +23 -0
  26. supervisely/nn/inference/uploader.py +164 -0
  27. supervisely/nn/model/__init__.py +0 -0
  28. supervisely/nn/model/model_api.py +259 -0
  29. supervisely/nn/model/prediction.py +311 -0
  30. supervisely/nn/model/prediction_session.py +632 -0
  31. supervisely/nn/tracking/__init__.py +1 -0
  32. supervisely/nn/tracking/boxmot.py +114 -0
  33. supervisely/nn/tracking/tracking.py +24 -0
  34. supervisely/nn/training/train_app.py +61 -19
  35. supervisely/nn/utils.py +43 -3
  36. supervisely/task/progress.py +12 -2
  37. supervisely/video/video.py +107 -1
  38. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/METADATA +2 -1
  39. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/RECORD +43 -32
  40. supervisely/api/neural_network_api.py +0 -202
  41. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/LICENSE +0 -0
  42. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/WHEEL +0 -0
  43. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/entry_points.txt +0 -0
  44. {supervisely-6.73.357.dist-info → supervisely-6.73.359.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,821 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import asdict
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Literal, Optional, Tuple, Union
7
+
8
+ import supervisely.io.env as env
9
+ from supervisely._utils import get_valid_kwargs
10
+ from supervisely.api.api import Api
11
+ from supervisely.io.fs import get_file_name_with_ext
12
+ from supervisely.nn.experiments import ExperimentInfo
13
+ from supervisely.sly_logger import logger
14
+
15
+
16
def get_runtime(runtime: str):
    """
    Normalize a user-provided runtime name to a ``RuntimeType`` value.

    Accepts the canonical ``str(RuntimeType.X)`` values as well as common
    aliases ("pytorch", "torch", "pt", "onnxruntime", "onnx", "tensorrt",
    "trt", "engine"), matched case-insensitively.

    :param runtime: Runtime name or alias. ``None`` means "not specified".
    :type runtime: Optional[str]
    :return: The matching ``RuntimeType`` value, or ``None`` if ``runtime`` is ``None``.
    :raises ValueError: if ``runtime`` does not match any known alias.
    """
    # Nothing to resolve; skip the project import entirely.
    if runtime is None:
        return None

    from supervisely.nn.utils import RuntimeType

    aliases = {
        str(RuntimeType.PYTORCH): RuntimeType.PYTORCH,
        str(RuntimeType.ONNXRUNTIME): RuntimeType.ONNXRUNTIME,
        str(RuntimeType.TENSORRT): RuntimeType.TENSORRT,
        "pytorch": RuntimeType.PYTORCH,
        "torch": RuntimeType.PYTORCH,
        "pt": RuntimeType.PYTORCH,
        "onnxruntime": RuntimeType.ONNXRUNTIME,
        "onnx": RuntimeType.ONNXRUNTIME,
        "tensorrt": RuntimeType.TENSORRT,
        "trt": RuntimeType.TENSORRT,
        "engine": RuntimeType.TENSORRT,
    }
    if runtime in aliases:
        return aliases[runtime]
    # Keep the caller's value intact so the error message shows what was passed
    # (the previous version reported "Runtime 'None'" here).
    resolved = aliases.get(str(runtime).lower())
    if resolved is None:
        raise ValueError(
            f"Runtime '{runtime}' is not supported. Supported runtimes are: {', '.join(aliases.keys())}"
        )
    return resolved
42
+
43
+
44
class DeployApi:
    """
    Start Supervisely serving apps and load pretrained or custom models into them.

    The object only keeps a reference to an :class:`Api` instance (``self._api``)
    and sends every request through its ``task``, ``app``, ``agent``,
    ``workspace``, ``team`` and ``file`` sub-APIs.

    :param api: Supervisely API object used for all requests.
    :type api: Api
    """

    def __init__(self, api: "Api"):
        self._api = api

    def load_pretrained_model(
        self,
        session_id: int,
        model_name: str,
        device: Optional[str] = None,
        runtime: str = None,
    ):
        """
        Load a pretrained model in running serving App.

        :param session_id: Task ID of the serving App.
        :type session_id: int
        :param model_name: Model name to deploy.
        :type model_name: str
        :param device: Device string. If not provided, will be chosen automatically.
        :type device: Optional[str]
        :param runtime: Runtime string, if not present will be defined automatically.
        :type runtime: Optional[str]
        :raises ValueError: if ``runtime`` is not a supported runtime alias.
        """
        from supervisely.nn.utils import ModelSource

        deploy_params = {
            "model_source": ModelSource.PRETRAINED,
            "device": device,
            "runtime": get_runtime(runtime),
        }
        self._load_model_from_api(session_id, deploy_params, model_name=model_name)

    def load_custom_model(
        self,
        session_id: int,
        team_id: int,
        artifacts_dir: str,
        checkpoint_name: Optional[str] = None,
        device: Optional[str] = None,
        runtime: str = None,
    ):
        """
        Load a custom model in running serving App.

        :param session_id: Task ID of the serving App.
        :type session_id: int
        :param team_id: Team ID in Supervisely.
        :type team_id: int
        :param artifacts_dir: Path to the artifacts directory in Team Files.
        :type artifacts_dir: str
        :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
            If not provided, checkpoint will be chosen automatically, depending on the app version.
        :type checkpoint_name: Optional[str]
        :param device: Device string. If not provided, will be chosen automatically.
        :type device: Optional[str]
        :param runtime: Runtime string, if not present will be defined automatically.
        :type runtime: Optional[str]
        :raises ValueError: if the artifacts or the checkpoint cannot be resolved.
        """
        runtime = get_runtime(runtime)
        _, _, deploy_params = self._custom_deploy_params(
            team_id, artifacts_dir, checkpoint_name, device, runtime, with_module=False
        )
        self._load_model_from_api(session_id, deploy_params)

    def load_custom_model_from_experiment_info(
        self,
        session_id: int,
        experiment_info: "ExperimentInfo",
        checkpoint_name: Optional[str] = None,
        device: Optional[str] = None,
        runtime: str = None,
    ):
        """
        Load a custom model in running serving App based on the training session.

        :param session_id: Task ID of the serving App.
        :type session_id: int
        :param experiment_info: an :class:`ExperimentInfo` object.
        :type experiment_info: ExperimentInfo
        :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
            If not provided, ``experiment_info.best_checkpoint`` is used.
        :type checkpoint_name: Optional[str]
        :param device: Device string. If not provided, will be chosen automatically.
        :type device: Optional[str]
        :param runtime: Runtime string, if not present will be defined automatically.
        :type runtime: Optional[str]
        """
        from supervisely.nn.utils import ModelSource

        runtime = get_runtime(runtime)
        if checkpoint_name is None:
            checkpoint_name = experiment_info.best_checkpoint
        model_files = {
            "checkpoint": Path(
                experiment_info.artifacts_dir, "checkpoints", checkpoint_name
            ).as_posix(),
        }
        # The config file is optional (``_deploy_params_v2`` treats it the same way);
        # indexing ``model_files["config"]`` raised KeyError for config-less experiments.
        config = experiment_info.model_files.get("config")
        if config is not None:
            model_files["config"] = Path(experiment_info.artifacts_dir, config).as_posix()
        deploy_params = {
            "device": device,
            "model_source": ModelSource.CUSTOM,
            "model_files": model_files,
            "model_info": experiment_info.to_json(),
            "runtime": runtime,
        }
        self._load_model_from_api(session_id, deploy_params)

    def _find_agent(self, team_id: int = None, public=True, gpu=True):
        """
        Find an agent in Supervisely with most available memory.

        Agents of type ``"sly_agent"`` are ranked by the available memory of their
        first GPU (no multi-GPU support); if none exist, the first
        ``"kubernetes"`` agent is returned.

        :param team_id: Team ID. If not provided, will be taken from the current context.
        :type team_id: Optional[int]
        :param public: If True, can find a public agent.
        :type public: bool
        :param gpu: If True, find an agent with GPU.
        :type gpu: bool
        :return: Agent ID
        :rtype: int
        :raises ValueError: if no available agent of a supported type is found.
        """
        if team_id is None:
            team_id = env.team_id()
        agents = self._api.agent.get_list_available(team_id, show_public=public, has_gpu=gpu)
        if not agents:
            raise ValueError("No available agents found.")

        agent_id_memory_map = {}
        kubernetes_agents = []
        for agent in agents:
            if agent.type == "sly_agent":
                # No multi-gpu support, always take the first one.
                # Missing/empty GPU info (e.g. CPU agents) ranks the agent last
                # instead of crashing the whole search.
                try:
                    available = agent.gpu_info["device_memory"][0]["available"]
                except (KeyError, IndexError, TypeError):
                    available = 0
                agent_id_memory_map[agent.id] = available
            elif agent.type == "kubernetes":
                kubernetes_agents.append(agent.id)

        if agent_id_memory_map:
            return max(agent_id_memory_map, key=agent_id_memory_map.get)
        if kubernetes_agents:
            return kubernetes_agents[0]
        # Previously fell through and returned None, failing later in task.start().
        raise ValueError(
            "No supported agents found (expected 'sly_agent' or 'kubernetes' agent types)."
        )

    def deploy_pretrained_model(
        self,
        framework: Union[str, int],
        model_name: str,
        device: Optional[str] = None,
        runtime: str = None,
        workspace_id: int = None,
        agent_id: Optional[int] = None,
        app: Union[str, int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Deploy a pretrained model.

        For Train V1 frameworks the serving app is resolved from the framework's
        serve slug; otherwise ``app`` (module ID or app name) is used, and if it
        is not given, the first serving app tagged with ``framework:<framework>``.

        :param framework: Framework name or Framework ID in Supervisely.
        :type framework: Union[str, int]
        :param model_name: Model name to deploy.
        :type model_name: str
        :param device: Device string. If not provided, will be chosen automatically.
        :type device: Optional[str]
        :param runtime: Runtime string, if not present will be defined automatically.
        :type runtime: Optional[str]
        :param workspace_id: Workspace ID where the app will be deployed. If not provided, will be taken from the current context.
        :type workspace_id: Optional[int]
        :param agent_id: Agent ID. If not provided, will be found automatically
            among the agents of the workspace's team.
        :type agent_id: Optional[int]
        :param app: App name or App module ID in Supervisely.
        :type app: Union[str, int]
        :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
        :type kwargs: Dict[str, Any]
        :return: Task Info
        :rtype: Dict[str, Any]
        :raises ValueError: if the workspace, the serving app or the runtime cannot be resolved.
        """
        from supervisely.nn.artifacts import (
            RITM,
            RTDETR,
            Detectron2,
            MMClassification,
            MMDetection,
            MMDetection3,
            MMSegmentation,
            UNet,
            YOLOv5,
            YOLOv5v2,
            YOLOv8,
        )

        # Validate the runtime before any task is started.
        runtime = get_runtime(runtime)

        # The docstring promises a context fallback; get_info_by_id(None) would fail.
        if workspace_id is None:
            workspace_id = env.workspace_id()
        workspace_info = self._api.workspace.get_info_by_id(workspace_id)
        if workspace_info is None:
            raise ValueError(f"Workspace with ID {workspace_id} not found")
        team_id = workspace_info.team_id

        # Instantiate each Train V1 artifacts class once (it used to be built twice per entry).
        frameworks_v1 = {}
        for artifacts_cls in (
            RITM,
            RTDETR,
            Detectron2,
            MMClassification,
            MMDetection,
            MMDetection3,
            MMSegmentation,
            UNet,
            YOLOv5,
            YOLOv5v2,
            YOLOv8,
        ):
            artifacts = artifacts_cls(team_id)
            frameworks_v1[artifacts.framework_name] = artifacts.serve_slug

        if framework in frameworks_v1:
            module_id = self.find_serving_app_by_slug(frameworks_v1[framework])
        elif isinstance(app, int):
            module_id = app
        elif isinstance(app, str):
            module_id = self._api.app.find_module_id_by_app_name(app)
        else:
            # find_serving_app_by_framework returns None when nothing matches;
            # subscripting it directly raised TypeError instead of the ValueError below.
            module = self.find_serving_app_by_framework(framework)
            module_id = None if module is None else module["id"]
        if module_id is None:
            raise ValueError(
                f"Serving app for framework '{framework}' not found. Make sure that you used correct framework name."
            )

        if agent_id is None:
            # The app runs in the workspace's team, so look for agents there.
            agent_id = self._find_agent(team_id)

        task_info = self._run_serve_app(agent_id, module_id, workspace_id=workspace_id, **kwargs)
        self.load_pretrained_model(
            task_info["id"], model_name=model_name, device=device, runtime=runtime
        )
        return task_info

    def _path_exists_in_team(self, team_id: int, path: str) -> bool:
        """Return True if ``path`` is an existing file or directory in the team's files."""
        return bool(
            self._api.file.exists(team_id, path)
            or self._api.file.dir_exists(team_id, path, recursive=False)
        )

    def _find_team_by_path(self, path: str, team_id: int = None, raise_not_found=True):
        """
        Find the team whose Team Files contain ``path`` (a file or a directory).

        Lookup order: the explicitly given ``team_id``; the team from the current
        context; finally every team visible to the API user.

        :param path: Team Files path of a checkpoint or an artifacts directory.
        :type path: str
        :param team_id: Team to check. If given, no other team is searched.
        :type team_id: Optional[int]
        :param raise_not_found: Raise instead of returning None when nothing is found.
        :type raise_not_found: bool
        :return: Team ID, or None if not found and ``raise_not_found`` is False.
        :rtype: Optional[int]
        :raises ValueError: if the path is not found (and ``raise_not_found``) or
            exists in several teams.
        """
        if team_id is not None:
            if self._path_exists_in_team(team_id, path):
                return team_id
            if raise_not_found:
                raise ValueError(f"Checkpoint '{path}' not found in the provided team (ID: {team_id})")
            return None

        team_id = env.team_id(raise_not_found=False)
        if team_id is not None and self._path_exists_in_team(team_id, path):
            return team_id

        found_team_id = None
        for team in self._api.team.get_list():
            # Directories must be accepted here too: callers pass artifacts dirs,
            # which file.exists() alone never matched.
            if self._path_exists_in_team(team.id, path):
                if found_team_id is not None:
                    raise ValueError("Multiple teams have the same checkpoint")
                found_team_id = team.id
        if found_team_id is None and raise_not_found:
            raise ValueError("Checkpoint not found")
        return found_team_id

    def deploy_custom_model_by_checkpoint(
        self,
        checkpoint: str,
        device: Optional[str] = None,
        runtime: str = None,
        timeout: int = 100,
        team_id: int = None,
        workspace_id: int = None,
        agent_id: int = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Deploy a custom model based on the checkpoint path.

        :param checkpoint: Path to the checkpoint in Team Files.
        :type checkpoint: str
        :param device: Device string. If not provided, will be chosen automatically.
        :type device: Optional[str]
        :param runtime: Runtime string, if not present will be defined automatically.
        :type runtime: Optional[str]
        :param timeout: Timeout in seconds (default is 100). The maximum time to wait for the serving app to be ready.
        :type timeout: Optional[int]
        :param team_id: Team ID where the artifacts are stored. If not provided, will be taken from the current context.
        :type team_id: Optional[int]
        :param workspace_id: Workspace ID where the app will be deployed. If not provided, will be taken from the current context.
        :type workspace_id: Optional[int]
        :param agent_id: Agent ID. If not provided, will be found automatically.
        :type agent_id: Optional[int]
        :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
        :type kwargs: Dict[str, Any]
        :return: Task Info
        :rtype: Dict[str, Any]
        :raises ValueError: if validations fail.
        """
        artifacts_dir, checkpoint_name = self._get_artifacts_dir_and_checkpoint_name(
            checkpoint, team_id=team_id
        )
        return self.deploy_custom_model_by_artifacts_dir(
            artifacts_dir=artifacts_dir,
            checkpoint_name=checkpoint_name,
            device=device,
            runtime=runtime,
            timeout=timeout,
            team_id=team_id,
            workspace_id=workspace_id,
            agent_id=agent_id,
            **kwargs,
        )

    def deploy_custom_model_by_artifacts_dir(
        self,
        artifacts_dir: str,
        checkpoint_name: Optional[str] = None,
        device: Optional[str] = None,
        runtime: str = None,
        timeout: int = 100,
        team_id: int = None,
        workspace_id: int = None,
        agent_id: int = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Deploy a custom model based on the artifacts directory.

        :param artifacts_dir: Path to the artifacts directory in Team Files.
        :type artifacts_dir: str
        :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
            If not provided, checkpoint will be chosen automatically, depending on the app version.
        :type checkpoint_name: Optional[str]
        :param device: Device string. If not provided, will be chosen automatically.
        :type device: Optional[str]
        :param runtime: Runtime string, if not present will be defined automatically.
        :type runtime: Optional[str]
        :param timeout: Timeout in seconds (default is 100). The maximum time to wait for the serving app to be ready.
        :type timeout: Optional[int]
        :param team_id: Team ID where the artifacts are stored. If not provided, will be taken from the current context.
        :type team_id: Optional[int]
        :param workspace_id: Workspace ID where the app will be deployed. If not provided, will be taken from the current context.
        :type workspace_id: Optional[int]
        :param agent_id: Agent ID. If not provided, will be found automatically.
        :type agent_id: Optional[int]
        :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
        :type kwargs: Dict[str, Any]
        :return: Task Info
        :rtype: Dict[str, Any]
        :raises ValueError: if validations fail.
        :raises RuntimeError: if starting the app or loading the model fails.
        """
        if not isinstance(artifacts_dir, str) or not artifacts_dir.strip():
            raise ValueError("artifacts_dir must be a non-empty string.")

        runtime = get_runtime(runtime)
        if team_id is None:
            team_id = self._find_team_by_path(artifacts_dir)
        logger.debug(
            f"Starting custom model deployment. Team: {team_id}, Artifacts Dir: '{artifacts_dir}'"
        )
        if agent_id is None:
            agent_id = self._find_agent()

        module_id, serve_app_name, deploy_params = self._custom_deploy_params(
            team_id, artifacts_dir, checkpoint_name, device, runtime, with_module=True
        )

        logger.info(
            f"{serve_app_name} app deployment started. Checkpoint: '{checkpoint_name}'. Deploy params: '{deploy_params}'"
        )
        try:
            # ``timeout`` used to be accepted but silently dropped here.
            task_info = self._run_serve_app(
                agent_id, module_id, workspace_id=workspace_id, timeout=timeout, **kwargs
            )
            self._load_model_from_api(task_info["id"], deploy_params)
        except Exception as e:
            raise RuntimeError(f"Failed to run '{serve_app_name}': {e}") from e
        return task_info

    def deploy_custom_model_from_experiment_info(
        self,
        agent_id: int,
        experiment_info: "ExperimentInfo",
        checkpoint_name: Optional[str] = None,
        device: Optional[str] = None,
        runtime: str = None,
        timeout: int = 100,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Deploy a custom model based on the training session.

        :param agent_id: Agent ID. If ``None`` is passed, will be found automatically.
        :type agent_id: int
        :param experiment_info: an :class:`ExperimentInfo` object.
        :type experiment_info: ExperimentInfo
        :param checkpoint_name: Checkpoint name (with file extension) to deploy, e.g. "best.pt".
            If not provided, the best checkpoint will be chosen.
        :type checkpoint_name: Optional[str]
        :param device: Device string. If not provided, will be chosen automatically.
        :type device: Optional[str]
        :param runtime: Runtime string, if not present will be defined automatically.
        :type runtime: Optional[str]
        :param timeout: Timeout in seconds (default is 100). The maximum time to wait for the serving app to be ready.
        :type timeout: Optional[int]
        :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
            ``task_name`` and ``description`` default to values derived from the experiment.
        :type kwargs: Dict[str, Any]
        :return: Task Info
        :rtype: Dict[str, Any]
        :raises ValueError: if validations fail.
        :raises RuntimeError: if starting the app or loading the model fails.
        """
        task_id = experiment_info.task_id
        train_task_info = self._api.task.get_info_by_id(task_id)
        if train_task_info is None:
            raise ValueError(f"Task with ID '{task_id}' not found")
        runtime = get_runtime(runtime)
        if agent_id is None:
            agent_id = self._find_agent()

        logger.debug(f"Starting model deployment from experiment info. Task ID: '{task_id}'")
        train_module_id = train_task_info["meta"]["app"]["moduleId"]
        module = self.get_serving_app_by_train_app(module_id=train_module_id)
        serve_app_name = module["name"]
        module_id = module["id"]
        logger.debug(f"Serving app detected: '{serve_app_name}'. Module ID: '{module_id}'")

        if checkpoint_name is None:
            checkpoint_name = experiment_info.best_checkpoint

        # Task parameters
        experiment_name = experiment_info.experiment_name
        kwargs.setdefault("task_name", experiment_name + f" ({checkpoint_name})")

        description = (
            "Serve from experiment\n"
            f"Experiment name: {experiment_name}\n"
            f"Evaluation report: {experiment_info.evaluation_report_link}\n"
        )
        # Drop trailing lines to fit the 255-char limit; the final slice guards
        # against a single over-long line (the old loop never terminated then).
        while len(description) > 255 and "\n" in description:
            description = description.rsplit("\n", 1)[0]
        description = description[:255]
        kwargs.setdefault("description", description)

        logger.info(f"{serve_app_name} app deployment started. Checkpoint: '{checkpoint_name}'.")
        try:
            task_info = self._run_serve_app(agent_id, module_id, timeout=timeout, **kwargs)
            self.load_custom_model_from_experiment_info(
                task_info["id"], experiment_info, checkpoint_name, device, runtime
            )
        except Exception as e:
            raise RuntimeError(f"Failed to run '{serve_app_name}': {e}") from e
        return task_info

    def start_serve_app(
        self, agent_id: int, app_name=None, module_id=None, **kwargs
    ) -> Dict[str, Any]:
        """
        Run a serving app. Either app_name or module_id must be provided.

        :param agent_id: Agent ID to run the app on.
        :type agent_id: int
        :param app_name: App name in Supervisely.
        :type app_name: Optional[str]
        :param module_id: Module ID in Supervisely.
        :type module_id: Optional[int]
        :param kwargs: Additional parameters to start the task. See Api.task.start() for more details.
        :type kwargs: Dict[str, Any]
        :return: Task Info
        :rtype: Dict[str, Any]
        :raises ValueError: if neither or both of ``app_name`` and ``module_id`` are given.
        """
        if app_name is None and module_id is None:
            raise ValueError("Either app_name or module_id must be provided.")
        if app_name is not None and module_id is not None:
            raise ValueError("Only one of app_name or module_id must be provided.")
        if module_id is None:
            module_id = self._api.app.find_module_id_by_app_name(app_name)
        # The task info was previously discarded despite the documented return value.
        return self._run_serve_app(agent_id, module_id, **kwargs)

    def _run_serve_app(
        self, agent_id: int, module_id, workspace_id: int = None, timeout: int = 100, **kwargs
    ):
        """
        Start a serving task and block until it accepts API calls.

        Unknown ``kwargs`` are filtered out against ``Api.task.start``'s signature.

        :return: Task Info returned by ``Api.task.start``.
        :raises TimeoutError: if the task is not ready after ``timeout`` seconds.
        """
        attempt_delay_sec = 1
        # At least one readiness check; int() keeps the count integral for float timeouts.
        attempts = max(1, int(timeout // attempt_delay_sec))

        # @TODO: Run app in team?
        if workspace_id is None:
            workspace_id = env.workspace_id()
        kwargs = get_valid_kwargs(
            kwargs=kwargs,
            func=self._api.task.start,
            exclude=["self", "module_id", "workspace_id", "agent_id"],
        )
        task_info = self._api.task.start(
            agent_id=agent_id, module_id=module_id, workspace_id=workspace_id, **kwargs
        )
        ready = self._api.app.wait_until_ready_for_api_calls(
            task_info["id"], attempts, attempt_delay_sec
        )
        if not ready:
            raise TimeoutError(
                f"Task {task_info['id']} is not ready for API calls after {timeout} seconds."
            )
        return task_info

    def _load_model_from_api(self, task_id, deploy_params, model_name: Optional[str] = None):
        """Send ``deploy_params`` (and optional ``model_name``) to the task's ``deploy_from_api`` endpoint."""
        logger.info("Loading model")
        self._api.task.send_request(
            task_id,
            "deploy_from_api",
            data={"deploy_params": deploy_params, "model_name": model_name},
            raise_error=True,
        )
        logger.info("Model loaded successfully")

    def find_serving_app_by_framework(self, framework: str):
        """
        Return the first ecosystem module tagged both ``serve`` and ``framework:<framework>``.

        :return: Module info, or None if no module matches.
        """
        modules = self._api.app.get_list_ecosystem_modules(
            categories=["serve", f"framework:{framework}"], categories_operation="and"
        )
        if not modules:
            return None
        return modules[0]

    def find_serving_app_by_slug(self, slug: str) -> int:
        """Return the ecosystem module ID for the given app ``slug``."""
        return self._api.app.get_ecosystem_module_id(slug)

    @staticmethod
    def _framework_from_categories(categories) -> Optional[str]:
        """Return the value of the first ``framework:<name>`` category (prefix case-insensitive), or None."""
        prefix = "framework:"
        for category in categories:
            if category.lower().startswith(prefix):
                # Slice the prefix off. ``str.lstrip(prefix)`` strips a *set of
                # characters* and turned e.g. "mmdetection" into "detection".
                return category[len(prefix):]
        return None

    def get_serving_app_by_train_app(self, app_name: Optional[str] = None, module_id: int = None):
        """
        Find the serving app that matches a training app's ``framework:`` category.

        :param app_name: Training app name. Takes precedence over ``module_id``.
        :type app_name: Optional[str]
        :param module_id: Training app module ID.
        :type module_id: Optional[int]
        :return: Serving module info.
        :raises ValueError: if neither argument is given, the train app has no
            framework category, or no serving app exists for the framework.
        """
        if app_name is None and module_id is None:
            raise ValueError("Either app_name or module_id must be provided.")
        if app_name is not None:
            module_id = self._api.app.find_module_id_by_app_name(app_name)
        train_module_info = self._api.app.get_ecosystem_module_info(module_id)
        categories = train_module_info.config.get("categories", [])
        framework = self._framework_from_categories(categories)
        if framework is None:
            raise ValueError(
                "Unable to define serving app. Framework is not specified in the train app"
            )

        logger.debug(f"Detected framework: {framework}")
        module = self.find_serving_app_by_framework(framework)
        if module is None:
            raise ValueError(f"No serving apps found for framework {framework}")
        return module

    def get_deploy_info(self, task_id: int) -> Dict[str, Any]:
        """
        Get deploy info of a serving task.

        :param task_id: Task ID of the serving App.
        :type task_id: int
        :return: Deploy Info
        :rtype: Dict[str, Any]
        """
        return self._api.task.send_request(task_id, "get_deploy_info", data={}, raise_error=True)

    def _custom_deploy_params(
        self,
        team_id: int,
        artifacts_dir: str,
        checkpoint_name: Optional[str],
        device: Optional[str],
        runtime,
        with_module: bool,
    ) -> Tuple[Optional[int], Optional[str], Dict[str, Any]]:
        """Build custom-model deploy params, dispatching on the Train V1 / V2 artifacts layout."""
        from supervisely.nn.utils import ModelSource

        if artifacts_dir.startswith("/experiments"):
            # Train V2 logic (when artifacts_dir starts with '/experiments')
            logger.debug("Deploying model from Train V2 artifacts")
            result = self._deploy_params_v2(
                team_id, artifacts_dir, checkpoint_name, device, runtime, with_module=with_module
            )
        else:
            # Train V1 logic (if artifacts_dir does not start with '/experiments')
            logger.debug("Deploying model from Train V1 artifacts")
            result = self._deploy_params_v1(
                team_id, artifacts_dir, checkpoint_name, device, runtime, with_module=with_module
            )
        module_id, serve_app_name, deploy_params = result
        deploy_params["model_source"] = ModelSource.CUSTOM
        return module_id, serve_app_name, deploy_params

    def _deploy_params_v1(
        self,
        team_id: int,
        artifacts_dir: str,
        checkpoint_name: str,
        device: str,
        runtime: str,
        with_module: bool = True,
    ) -> Tuple[Optional[int], Optional[str], Dict[str, Any]]:
        """
        Build deploy params for Train V1 artifacts.

        :return: ``(module_id, serve_app_name, deploy_params)``; the first two are
            None unless ``with_module`` is True.
        :raises ValueError: for unsupported frameworks or missing checkpoints.
        """
        from supervisely.nn.artifacts import RITM, YOLOv5
        from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts
        from supervisely.nn.utils import ModelSource

        framework_cls = self._get_framework_by_path(artifacts_dir)
        if not framework_cls:
            raise ValueError(f"Unsupported framework for artifacts_dir: '{artifacts_dir}'")

        framework: BaseTrainArtifacts = framework_cls(team_id)
        if framework_cls is RITM or framework_cls is YOLOv5:
            raise ValueError(
                f"{framework.framework_name} framework is not supported for deployment"
            )

        runtime = get_runtime(runtime)
        logger.debug(f"Detected framework: '{framework.framework_name}'")
        module_id = None
        serve_app_name = None
        if with_module:
            module_id = self._api.app.get_ecosystem_module_id(framework.serve_slug)
            serve_app_name = framework.serve_app_name
            logger.debug(f"Module ID fetched: '{module_id}'. App name: '{serve_app_name}'")

        train_info = framework.get_info_by_artifacts_dir(artifacts_dir.rstrip("/"))
        if not hasattr(train_info, "checkpoints") or not train_info.checkpoints:
            raise ValueError("No checkpoints found in train info.")

        if checkpoint_name is not None:
            checkpoint = next(
                (cp for cp in train_info.checkpoints if cp.name == checkpoint_name), None
            )
            if checkpoint is None:
                raise ValueError(f"Checkpoint '{checkpoint_name}' not found in train info.")
        else:
            logger.info("Checkpoint name not provided. Using the last checkpoint.")
            checkpoint = train_info.checkpoints[-1]

        deploy_params = {
            "device": device,
            "model_source": ModelSource.CUSTOM,
            "task_type": train_info.task_type,
            "checkpoint_name": checkpoint.name,
            "checkpoint_url": checkpoint.path,
        }
        if getattr(train_info, "config_path", None) is not None:
            deploy_params["config_url"] = train_info.config_path
        if framework.require_runtime:
            deploy_params["runtime"] = runtime
        return module_id, serve_app_name, deploy_params

    def _deploy_params_v2(
        self,
        team_id: int,
        artifacts_dir: str,
        checkpoint_name: str,
        device: str,
        runtime: str,
        with_module: bool = True,
    ) -> Tuple[Optional[int], Optional[str], Dict[str, Any]]:
        """
        Build deploy params for Train V2 (``/experiments``) artifacts.

        :return: ``(module_id, serve_app_name, deploy_params)``; the first two are
            None unless ``with_module`` is True.
        :raises ValueError: if the experiment, its task or the checkpoint cannot be found.
        """
        from supervisely.nn.experiments import get_experiment_info_by_artifacts_dir
        from supervisely.nn.utils import ModelSource

        experiment_info = get_experiment_info_by_artifacts_dir(self._api, team_id, artifacts_dir)
        if not experiment_info:
            raise ValueError(
                f"Failed to retrieve experiment info for artifacts_dir: '{artifacts_dir}'"
            )

        runtime = get_runtime(runtime)
        experiment_task_id = experiment_info.task_id
        experiment_task_info = self._api.task.get_info_by_id(experiment_task_id)
        if experiment_task_info is None:
            raise ValueError(f"Task with ID '{experiment_task_id}' not found")

        module_id = None
        serve_app_name = None
        if with_module:
            train_module_id = experiment_task_info["meta"]["app"]["moduleId"]
            module = self.get_serving_app_by_train_app(module_id=train_module_id)
            serve_app_name = module["name"]
            module_id = module["id"]
            logger.debug(f"Serving app detected: '{serve_app_name}'. Module ID: '{module_id}'")

        if len(experiment_info.checkpoints) == 0:
            raise ValueError(f"No checkpoints found in: '{artifacts_dir}'.")

        if checkpoint_name is not None:
            if not any(
                get_file_name_with_ext(path) == checkpoint_name
                for path in experiment_info.checkpoints
            ):
                raise ValueError(f"Provided checkpoint '{checkpoint_name}' not found")
            checkpoint = checkpoint_name
        else:
            logger.info("Checkpoint name not provided. Using the best checkpoint.")
            checkpoint = experiment_info.best_checkpoint
            if not checkpoint:
                # Avoid building a '.../checkpoints/None' path.
                logger.info("Best checkpoint is not set. Using the last checkpoint.")
                checkpoint = get_file_name_with_ext(experiment_info.checkpoints[-1])

        model_info_dict = asdict(experiment_info)
        model_info_dict["artifacts_dir"] = artifacts_dir
        deploy_params = {
            "device": device,
            "model_source": ModelSource.CUSTOM,
            "model_files": {
                "checkpoint": f"/{artifacts_dir.strip('/')}/checkpoints/{checkpoint}"
            },
            "model_info": model_info_dict,
            "runtime": runtime,
        }

        config = experiment_info.model_files.get("config")
        if config is not None:
            # Join like load_custom_model_from_experiment_info does; plain string
            # concatenation broke when artifacts_dir had no trailing slash.
            config_path = Path(experiment_info.artifacts_dir, config).as_posix()
            deploy_params["model_files"]["config"] = config_path
            logger.debug(f"Config file added: {config_path}")
        return module_id, serve_app_name, deploy_params

    def wait(self, model_id, target: Literal["started", "deployed"] = "started", timeout=5 * 60):
        """
        Poll a serving task until it reaches ``target`` or ``timeout`` seconds pass.

        ``"started"`` polls the task's ``is_alive`` endpoint, ``"deployed"`` polls
        ``is_ready``, once per second. A poll succeeds when the request (sent with
        ``raise_error=True``) returns without raising.

        :param model_id: Task ID of the serving App.
        :param target: State to wait for.
        :param timeout: Maximum time to wait, in seconds.
        :return: True once the target state is reached, False on timeout.
        :rtype: bool
        """
        method = "is_alive" if target == "started" else "is_ready"
        start = time.monotonic()
        while time.monotonic() - start < timeout:
            try:
                self._api.task.send_request(model_id, method, {}, raise_error=True)
            except Exception:
                # NOTE(review): assumes a not-yet-ready task makes the request fail
                # (connection error / non-2xx). Confirm against the serving app if
                # readiness is reported only in the response payload.
                time.sleep(1)
            else:
                return True
        return False

    def _get_artifacts_dir_and_checkpoint_name(
        self, model: str, team_id: Optional[int] = None
    ) -> Tuple[str, str]:
        """
        Split a Team Files checkpoint path into ``(artifacts_dir, checkpoint_name)``.

        Train V2 paths must look like ``/experiments/.../checkpoints/<name>``.
        Train V1 paths are resolved through the framework's ``weights_folder``.

        :param model: Absolute checkpoint path in Team Files.
        :param team_id: Team used to instantiate the V1 framework; defaults to the context team.
        :raises ValueError: if the path is relative, malformed or of an unknown framework.
        """
        if not model.startswith("/"):
            raise ValueError("Path must start with '/'")

        if model.startswith("/experiments"):
            try:
                artifacts_dir, checkpoint_name = model.split("/checkpoints/")
            except ValueError:
                raise ValueError(
                    "Bad format of checkpoint path. Expected format: '/artifacts_dir/checkpoints/checkpoint_name'"
                ) from None
            return artifacts_dir, checkpoint_name

        framework_cls = self._get_framework_by_path(model)
        if framework_cls is None:
            raise ValueError(f"Unknown path: '{model}'")

        if team_id is None:
            team_id = env.team_id()
        framework = framework_cls(team_id)
        checkpoint_name = get_file_name_with_ext(model)
        # Cut only the trailing file name; str.replace also removed any other
        # occurrence of the same text elsewhere in the path.
        checkpoints_dir = model[: len(model) - len(checkpoint_name)]
        weights_folder = framework.weights_folder
        if weights_folder:
            # Remove only the right-most occurrence, i.e. the weights folder itself.
            head, sep, tail = checkpoints_dir.rpartition(weights_folder)
            artifacts_dir = head + tail if sep else checkpoints_dir
        else:
            artifacts_dir = checkpoints_dir
        return artifacts_dir, checkpoint_name

    def _get_framework_by_path(self, path: str):
        """
        Map a Train V1 Team Files path to its artifacts class by its root folder.

        :return: The artifacts class, or None if the root folder is not recognized.
        :raises ValueError: if the path has fewer than two components.
        """
        from supervisely.nn.artifacts import (
            RITM,
            RTDETR,
            Detectron2,
            MMClassification,
            MMDetection,
            MMDetection3,
            MMSegmentation,
            UNet,
            YOLOv5,
            YOLOv5v2,
            YOLOv8,
        )

        parts = Path(path).parts
        if len(parts) < 2:
            raise ValueError(f"Incorrect checkpoint path: '{path}'")
        frameworks_by_root_dir = {
            "detectron2": Detectron2,
            "mmclassification": MMClassification,
            "mmdetection": MMDetection,
            "mmdetection-3": MMDetection3,
            "mmsegmentation": MMSegmentation,
            "RITM_training": RITM,
            "RT-DETR": RTDETR,
            "unet": UNet,
            "yolov5_train": YOLOv5,
            "yolov5_2.0_train": YOLOv5v2,
            "yolov8_train": YOLOv8,
        }
        return frameworks_by_root_dir.get(parts[1])