supervisely 6.73.302__py3-none-any.whl → 6.73.304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/api/task_api.py +293 -5
- supervisely/nn/artifacts/artifacts.py +133 -62
- supervisely/nn/artifacts/detectron2.py +6 -0
- supervisely/nn/artifacts/hrda.py +4 -0
- supervisely/nn/artifacts/mmclassification.py +4 -0
- supervisely/nn/artifacts/mmdetection.py +9 -1
- supervisely/nn/artifacts/mmsegmentation.py +4 -0
- supervisely/nn/artifacts/ritm.py +4 -0
- supervisely/nn/artifacts/rtdetr.py +4 -0
- supervisely/nn/artifacts/unet.py +4 -0
- supervisely/nn/artifacts/yolov5.py +7 -0
- supervisely/nn/artifacts/yolov8.py +5 -1
- supervisely/nn/experiments.py +85 -2
- supervisely/nn/inference/inference.py +11 -4
- supervisely/nn/training/train_app.py +1 -1
- {supervisely-6.73.302.dist-info → supervisely-6.73.304.dist-info}/METADATA +1 -1
- {supervisely-6.73.302.dist-info → supervisely-6.73.304.dist-info}/RECORD +21 -21
- {supervisely-6.73.302.dist-info → supervisely-6.73.304.dist-info}/LICENSE +0 -0
- {supervisely-6.73.302.dist-info → supervisely-6.73.304.dist-info}/WHEEL +0 -0
- {supervisely-6.73.302.dist-info → supervisely-6.73.304.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.302.dist-info → supervisely-6.73.304.dist-info}/top_level.txt +0 -0
supervisely/api/task_api.py
CHANGED
|
@@ -5,6 +5,7 @@ import json
|
|
|
5
5
|
import os
|
|
6
6
|
import time
|
|
7
7
|
from collections import OrderedDict, defaultdict
|
|
8
|
+
from pathlib import Path
|
|
8
9
|
|
|
9
10
|
# docs
|
|
10
11
|
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union
|
|
@@ -12,6 +13,7 @@ from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Uni
|
|
|
12
13
|
from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
|
|
13
14
|
from tqdm import tqdm
|
|
14
15
|
|
|
16
|
+
from supervisely import logger
|
|
15
17
|
from supervisely._utils import batched, take_with_default
|
|
16
18
|
from supervisely.api.module_api import (
|
|
17
19
|
ApiField,
|
|
@@ -20,7 +22,12 @@ from supervisely.api.module_api import (
|
|
|
20
22
|
WaitingTimeExceeded,
|
|
21
23
|
)
|
|
22
24
|
from supervisely.collection.str_enum import StrEnum
|
|
23
|
-
from supervisely.io.fs import
|
|
25
|
+
from supervisely.io.fs import (
|
|
26
|
+
ensure_base_path,
|
|
27
|
+
get_file_hash,
|
|
28
|
+
get_file_name,
|
|
29
|
+
get_file_name_with_ext,
|
|
30
|
+
)
|
|
24
31
|
|
|
25
32
|
|
|
26
33
|
class TaskFinishedWithError(Exception):
|
|
@@ -301,7 +308,10 @@ class TaskApi(ModuleApiBase, ModuleWithStatus):
|
|
|
301
308
|
)
|
|
302
309
|
|
|
303
310
|
def upload_dtl_archive(
|
|
304
|
-
self,
|
|
311
|
+
self,
|
|
312
|
+
task_id: int,
|
|
313
|
+
archive_path: str,
|
|
314
|
+
progress_cb: Optional[Union[tqdm, Callable]] = None,
|
|
305
315
|
):
|
|
306
316
|
"""upload_dtl_archive"""
|
|
307
317
|
encoder = MultipartEncoder(
|
|
@@ -822,11 +832,19 @@ class TaskApi(ModuleApiBase, ModuleWithStatus):
|
|
|
822
832
|
return resp.json()
|
|
823
833
|
|
|
824
834
|
def set_output_report(
|
|
825
|
-
self,
|
|
835
|
+
self,
|
|
836
|
+
task_id: int,
|
|
837
|
+
file_id: int,
|
|
838
|
+
file_name: str,
|
|
839
|
+
description: Optional[str] = "Report",
|
|
826
840
|
) -> Dict:
|
|
827
841
|
"""set_output_report"""
|
|
828
842
|
return self._set_custom_output(
|
|
829
|
-
task_id,
|
|
843
|
+
task_id,
|
|
844
|
+
file_id,
|
|
845
|
+
file_name,
|
|
846
|
+
description=description,
|
|
847
|
+
icon="zmdi zmdi-receipt",
|
|
830
848
|
)
|
|
831
849
|
|
|
832
850
|
def _set_custom_output(
|
|
@@ -942,7 +960,11 @@ class TaskApi(ModuleApiBase, ModuleWithStatus):
|
|
|
942
960
|
)
|
|
943
961
|
|
|
944
962
|
def update_meta(
|
|
945
|
-
self,
|
|
963
|
+
self,
|
|
964
|
+
id: int,
|
|
965
|
+
data: dict,
|
|
966
|
+
agent_storage_folder: str = None,
|
|
967
|
+
relative_app_dir: str = None,
|
|
946
968
|
):
|
|
947
969
|
"""
|
|
948
970
|
Update given task metadata
|
|
@@ -1197,3 +1219,269 @@ class TaskApi(ModuleApiBase, ModuleWithStatus):
|
|
|
1197
1219
|
"tasks.output.set", {ApiField.TASK_ID: task_id, ApiField.OUTPUT: output}
|
|
1198
1220
|
)
|
|
1199
1221
|
return resp.json()
|
|
1222
|
+
|
|
1223
|
+
def deploy_model_from_api(self, task_id, deploy_params):
|
|
1224
|
+
self.send_request(
|
|
1225
|
+
task_id,
|
|
1226
|
+
"deploy_from_api",
|
|
1227
|
+
data={"deploy_params": deploy_params},
|
|
1228
|
+
raise_error=True,
|
|
1229
|
+
)
|
|
1230
|
+
|
|
1231
|
+
def deploy_model_app(
|
|
1232
|
+
self,
|
|
1233
|
+
module_id: int,
|
|
1234
|
+
workspace_id: int,
|
|
1235
|
+
agent_id: Optional[int] = None,
|
|
1236
|
+
description: Optional[str] = "application description",
|
|
1237
|
+
params: Dict[str, Any] = None,
|
|
1238
|
+
log_level: Optional[Literal["info", "debug", "warning", "error"]] = "info",
|
|
1239
|
+
users_ids: Optional[List[int]] = None,
|
|
1240
|
+
app_version: Optional[str] = "",
|
|
1241
|
+
is_branch: Optional[bool] = False,
|
|
1242
|
+
task_name: Optional[str] = "pythonSpawned",
|
|
1243
|
+
restart_policy: Optional[Literal["never", "on_error"]] = "never",
|
|
1244
|
+
proxy_keep_url: Optional[bool] = False,
|
|
1245
|
+
redirect_requests: Optional[Dict[str, int]] = {},
|
|
1246
|
+
limit_by_workspace: bool = False,
|
|
1247
|
+
deploy_params: Dict[str, Any] = None,
|
|
1248
|
+
timeout: int = 100,
|
|
1249
|
+
):
|
|
1250
|
+
if deploy_params is None:
|
|
1251
|
+
deploy_params = {}
|
|
1252
|
+
task_info = self.start(
|
|
1253
|
+
agent_id=agent_id,
|
|
1254
|
+
workspace_id=workspace_id,
|
|
1255
|
+
module_id=module_id,
|
|
1256
|
+
description=description,
|
|
1257
|
+
params=params,
|
|
1258
|
+
log_level=log_level,
|
|
1259
|
+
users_ids=users_ids,
|
|
1260
|
+
app_version=app_version,
|
|
1261
|
+
is_branch=is_branch,
|
|
1262
|
+
task_name=task_name,
|
|
1263
|
+
restart_policy=restart_policy,
|
|
1264
|
+
proxy_keep_url=proxy_keep_url,
|
|
1265
|
+
redirect_requests=redirect_requests,
|
|
1266
|
+
limit_by_workspace=limit_by_workspace,
|
|
1267
|
+
)
|
|
1268
|
+
|
|
1269
|
+
attempt_delay_sec = 10
|
|
1270
|
+
attempts = (timeout + attempt_delay_sec) // attempt_delay_sec
|
|
1271
|
+
ready = self._api.app.wait_until_ready_for_api_calls(
|
|
1272
|
+
task_info["id"], attempts, attempt_delay_sec
|
|
1273
|
+
)
|
|
1274
|
+
if not ready:
|
|
1275
|
+
raise TimeoutError(
|
|
1276
|
+
f"Task {task_info['id']} is not ready for API calls after {timeout} seconds."
|
|
1277
|
+
)
|
|
1278
|
+
logger.info("Deploying model from API")
|
|
1279
|
+
self.deploy_model_from_api(task_info["id"], deploy_params=deploy_params)
|
|
1280
|
+
return task_info
|
|
1281
|
+
|
|
1282
|
+
def deploy_custom_model(
|
|
1283
|
+
self,
|
|
1284
|
+
team_id: int,
|
|
1285
|
+
workspace_id: int,
|
|
1286
|
+
artifacts_dir: str,
|
|
1287
|
+
checkpoint_name: str = None,
|
|
1288
|
+
agent_id: int = None,
|
|
1289
|
+
device: str = "cuda",
|
|
1290
|
+
) -> int:
|
|
1291
|
+
"""
|
|
1292
|
+
Deploy a custom model based on the artifacts directory.
|
|
1293
|
+
|
|
1294
|
+
:param workspace_id: Workspace ID in Supervisely.
|
|
1295
|
+
:type workspace_id: int
|
|
1296
|
+
:param artifacts_dir: Path to the artifacts directory.
|
|
1297
|
+
:type artifacts_dir: str
|
|
1298
|
+
:param checkpoint_name: Checkpoint name (with extension) to deploy.
|
|
1299
|
+
:type checkpoint_name: Optional[str]
|
|
1300
|
+
:param agent_id: Agent ID in Supervisely.
|
|
1301
|
+
:type agent_id: Optional[int]
|
|
1302
|
+
:param device: Device string (default is "cuda").
|
|
1303
|
+
:type device: str
|
|
1304
|
+
:raises ValueError: if validations fail.
|
|
1305
|
+
"""
|
|
1306
|
+
from dataclasses import asdict
|
|
1307
|
+
|
|
1308
|
+
from supervisely.nn.artifacts import (
|
|
1309
|
+
RITM,
|
|
1310
|
+
RTDETR,
|
|
1311
|
+
Detectron2,
|
|
1312
|
+
MMClassification,
|
|
1313
|
+
MMDetection,
|
|
1314
|
+
MMDetection3,
|
|
1315
|
+
MMSegmentation,
|
|
1316
|
+
UNet,
|
|
1317
|
+
YOLOv5,
|
|
1318
|
+
YOLOv5v2,
|
|
1319
|
+
YOLOv8,
|
|
1320
|
+
)
|
|
1321
|
+
from supervisely.nn.experiments import get_experiment_info_by_artifacts_dir
|
|
1322
|
+
from supervisely.nn.utils import ModelSource, RuntimeType
|
|
1323
|
+
|
|
1324
|
+
if not isinstance(workspace_id, int) or workspace_id <= 0:
|
|
1325
|
+
raise ValueError(f"workspace_id must be a positive integer. Received: {workspace_id}")
|
|
1326
|
+
if not isinstance(artifacts_dir, str) or not artifacts_dir.strip():
|
|
1327
|
+
raise ValueError("artifacts_dir must be a non-empty string.")
|
|
1328
|
+
|
|
1329
|
+
workspace_info = self._api.workspace.get_info_by_id(workspace_id)
|
|
1330
|
+
if workspace_info is None:
|
|
1331
|
+
raise ValueError(f"Workspace with ID '{workspace_id}' not found.")
|
|
1332
|
+
|
|
1333
|
+
team_id = workspace_info.team_id
|
|
1334
|
+
logger.debug(
|
|
1335
|
+
f"Starting model deployment. Team: {team_id}, Workspace: {workspace_id}, Artifacts Dir: '{artifacts_dir}'"
|
|
1336
|
+
)
|
|
1337
|
+
|
|
1338
|
+
# Train V1 logic (if artifacts_dir does not start with '/experiments')
|
|
1339
|
+
if not artifacts_dir.startswith("/experiments"):
|
|
1340
|
+
logger.debug("Deploying model from Train V1 artifacts")
|
|
1341
|
+
frameworks = {
|
|
1342
|
+
"/detectron2": Detectron2,
|
|
1343
|
+
"/mmclassification": MMClassification,
|
|
1344
|
+
"/mmdetection": MMDetection,
|
|
1345
|
+
"/mmdetection-3": MMDetection3,
|
|
1346
|
+
"/mmsegmentation": MMSegmentation,
|
|
1347
|
+
"/RITM_training": RITM,
|
|
1348
|
+
"/RT-DETR": RTDETR,
|
|
1349
|
+
"/unet": UNet,
|
|
1350
|
+
"/yolov5_train": YOLOv5,
|
|
1351
|
+
"/yolov5_2.0_train": YOLOv5v2,
|
|
1352
|
+
"/yolov8_train": YOLOv8,
|
|
1353
|
+
}
|
|
1354
|
+
|
|
1355
|
+
framework_cls = next(
|
|
1356
|
+
(cls for prefix, cls in frameworks.items() if artifacts_dir.startswith(prefix)),
|
|
1357
|
+
None,
|
|
1358
|
+
)
|
|
1359
|
+
if not framework_cls:
|
|
1360
|
+
raise ValueError(f"Unsupported framework for artifacts_dir: '{artifacts_dir}'")
|
|
1361
|
+
|
|
1362
|
+
framework = framework_cls(team_id)
|
|
1363
|
+
if framework_cls is RITM or framework_cls is YOLOv5:
|
|
1364
|
+
raise ValueError(
|
|
1365
|
+
f"{framework.framework_name} framework is not supported for deployment"
|
|
1366
|
+
)
|
|
1367
|
+
|
|
1368
|
+
logger.debug(f"Detected framework: '{framework.framework_name}'")
|
|
1369
|
+
|
|
1370
|
+
module_id = self._api.app.get_ecosystem_module_id(framework.serve_slug)
|
|
1371
|
+
serve_app_name = framework.serve_app_name
|
|
1372
|
+
logger.debug(f"Module ID fetched:' {module_id}'. App name: '{serve_app_name}'")
|
|
1373
|
+
|
|
1374
|
+
train_info = framework.get_info_by_artifacts_dir(artifacts_dir.rstrip("/"))
|
|
1375
|
+
if not hasattr(train_info, "checkpoints") or not train_info.checkpoints:
|
|
1376
|
+
raise ValueError("No checkpoints found in train info.")
|
|
1377
|
+
|
|
1378
|
+
checkpoint = None
|
|
1379
|
+
if checkpoint_name is not None:
|
|
1380
|
+
for cp in train_info.checkpoints:
|
|
1381
|
+
if cp.name == checkpoint_name:
|
|
1382
|
+
checkpoint = cp
|
|
1383
|
+
break
|
|
1384
|
+
if checkpoint is None:
|
|
1385
|
+
raise ValueError(f"Checkpoint '{checkpoint_name}' not found in train info.")
|
|
1386
|
+
else:
|
|
1387
|
+
logger.debug("Checkpoint name not provided. Using the last checkpoint.")
|
|
1388
|
+
checkpoint = train_info.checkpoints[-1]
|
|
1389
|
+
|
|
1390
|
+
checkpoint_name = checkpoint.name
|
|
1391
|
+
deploy_params = {
|
|
1392
|
+
"device": device,
|
|
1393
|
+
"model_source": ModelSource.CUSTOM,
|
|
1394
|
+
"task_type": train_info.task_type,
|
|
1395
|
+
"checkpoint_name": checkpoint_name,
|
|
1396
|
+
"checkpoint_url": checkpoint.path,
|
|
1397
|
+
}
|
|
1398
|
+
|
|
1399
|
+
if getattr(train_info, "config_path", None) is not None:
|
|
1400
|
+
deploy_params["config_url"] = train_info.config_path
|
|
1401
|
+
|
|
1402
|
+
if framework.require_runtime:
|
|
1403
|
+
deploy_params["runtime"] = RuntimeType.PYTORCH
|
|
1404
|
+
|
|
1405
|
+
else: # Train V2 logic (when artifacts_dir starts with '/experiments')
|
|
1406
|
+
logger.debug("Deploying model from Train V2 artifacts")
|
|
1407
|
+
|
|
1408
|
+
def get_framework_from_artifacts_dir(artifacts_dir: str) -> str:
|
|
1409
|
+
clean_path = artifacts_dir.rstrip("/")
|
|
1410
|
+
parts = clean_path.split("/")
|
|
1411
|
+
if not parts or "_" not in parts[-1]:
|
|
1412
|
+
raise ValueError(f"Invalid artifacts_dir format: '{artifacts_dir}'")
|
|
1413
|
+
return parts[-1].split("_", 1)[1]
|
|
1414
|
+
|
|
1415
|
+
# TODO: temporary solution, need to add Serve App Name into config.json
|
|
1416
|
+
framework_name = get_framework_from_artifacts_dir(artifacts_dir)
|
|
1417
|
+
logger.debug(f"Detected framework: {framework_name}")
|
|
1418
|
+
|
|
1419
|
+
modules = self._api.app.get_list_all_pages(
|
|
1420
|
+
method="ecosystem.list",
|
|
1421
|
+
data={"filter": [], "search": framework_name, "categories": ["serve"]},
|
|
1422
|
+
convert_json_info_cb=lambda x: x,
|
|
1423
|
+
)
|
|
1424
|
+
if not modules:
|
|
1425
|
+
raise ValueError(f"No serve apps found for framework: '{framework_name}'")
|
|
1426
|
+
|
|
1427
|
+
module = modules[0]
|
|
1428
|
+
module_id = module["id"]
|
|
1429
|
+
serve_app_name = module["name"]
|
|
1430
|
+
logger.debug(f"Serving app delected: '{serve_app_name}'. Module ID: '{module_id}'")
|
|
1431
|
+
|
|
1432
|
+
experiment_info = get_experiment_info_by_artifacts_dir(
|
|
1433
|
+
self._api, team_id, artifacts_dir
|
|
1434
|
+
)
|
|
1435
|
+
if not experiment_info:
|
|
1436
|
+
raise ValueError(
|
|
1437
|
+
f"Failed to retrieve experiment info for artifacts_dir: '{artifacts_dir}'"
|
|
1438
|
+
)
|
|
1439
|
+
|
|
1440
|
+
if len(experiment_info.checkpoints) == 0:
|
|
1441
|
+
raise ValueError(f"No checkpoints found in: '{artifacts_dir}'.")
|
|
1442
|
+
|
|
1443
|
+
checkpoint = None
|
|
1444
|
+
if checkpoint_name is not None:
|
|
1445
|
+
for checkpoint_path in experiment_info.checkpoints:
|
|
1446
|
+
if get_file_name_with_ext(checkpoint_path) == checkpoint_name:
|
|
1447
|
+
checkpoint = get_file_name_with_ext(checkpoint_path)
|
|
1448
|
+
break
|
|
1449
|
+
if checkpoint is None:
|
|
1450
|
+
raise ValueError(
|
|
1451
|
+
f"Provided checkpoint '{checkpoint_name}' not found. Using the best checkpoint."
|
|
1452
|
+
)
|
|
1453
|
+
else:
|
|
1454
|
+
logger.debug("Checkpoint name not provided. Using the best checkpoint.")
|
|
1455
|
+
checkpoint = experiment_info.best_checkpoint
|
|
1456
|
+
|
|
1457
|
+
checkpoint_name = checkpoint
|
|
1458
|
+
deploy_params = {
|
|
1459
|
+
"device": device,
|
|
1460
|
+
"model_source": ModelSource.CUSTOM,
|
|
1461
|
+
"model_files": {
|
|
1462
|
+
"checkpoint": f"{experiment_info.artifacts_dir}checkpoints/{checkpoint_name}"
|
|
1463
|
+
},
|
|
1464
|
+
"model_info": asdict(experiment_info),
|
|
1465
|
+
"runtime": RuntimeType.PYTORCH,
|
|
1466
|
+
}
|
|
1467
|
+
# TODO: add support for **kwargs
|
|
1468
|
+
|
|
1469
|
+
config = experiment_info.model_files.get("config")
|
|
1470
|
+
if config is not None:
|
|
1471
|
+
deploy_params["model_files"]["config"] = f"{experiment_info.artifacts_dir}{config}"
|
|
1472
|
+
logger.debug(f"Config file added: {experiment_info.artifacts_dir}{config}")
|
|
1473
|
+
|
|
1474
|
+
logger.info(
|
|
1475
|
+
f"{serve_app_name} app deployment started. Checkpoint: '{checkpoint_name}'. Deploy params: '{deploy_params}'"
|
|
1476
|
+
)
|
|
1477
|
+
task_info = self.deploy_model_app(
|
|
1478
|
+
module_id,
|
|
1479
|
+
workspace_id,
|
|
1480
|
+
agent_id,
|
|
1481
|
+
description=f"Deployed via deploy_custom_model",
|
|
1482
|
+
task_name=f"{serve_app_name} ({checkpoint_name})",
|
|
1483
|
+
deploy_params=deploy_params,
|
|
1484
|
+
)
|
|
1485
|
+
if task_info is None:
|
|
1486
|
+
raise RuntimeError(f"Failed to run '{serve_app_name}'.")
|
|
1487
|
+
return task_info["id"]
|
|
@@ -3,11 +3,12 @@ import string
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
4
|
from collections import defaultdict
|
|
5
5
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
6
|
+
from dataclasses import fields
|
|
6
7
|
from datetime import datetime
|
|
7
8
|
from json import JSONDecodeError
|
|
8
9
|
from os.path import dirname, join
|
|
9
10
|
from time import time
|
|
10
|
-
from typing import Any, Dict, List, Literal, NamedTuple
|
|
11
|
+
from typing import Any, Dict, List, Literal, NamedTuple, Union
|
|
11
12
|
|
|
12
13
|
import requests
|
|
13
14
|
|
|
@@ -55,6 +56,9 @@ class BaseTrainArtifacts:
|
|
|
55
56
|
self._metadata_file_name: str = "train_info.json"
|
|
56
57
|
|
|
57
58
|
self._app_name: str = None
|
|
59
|
+
self._slug = None
|
|
60
|
+
self._serve_app_name = None
|
|
61
|
+
self._serve_slug = None
|
|
58
62
|
self._framework_name: str = None
|
|
59
63
|
self._framework_folder: str = None
|
|
60
64
|
self._weights_folder: str = None
|
|
@@ -63,6 +67,7 @@ class BaseTrainArtifacts:
|
|
|
63
67
|
self._config_file: str = None
|
|
64
68
|
self._pattern: str = None
|
|
65
69
|
self._available_task_types: List[str] = []
|
|
70
|
+
self._require_runtime = False
|
|
66
71
|
|
|
67
72
|
@property
|
|
68
73
|
def team_id(self) -> int:
|
|
@@ -94,6 +99,36 @@ class BaseTrainArtifacts:
|
|
|
94
99
|
"""
|
|
95
100
|
return self._app_name
|
|
96
101
|
|
|
102
|
+
@property
|
|
103
|
+
def slug(self):
|
|
104
|
+
"""
|
|
105
|
+
Train app slug.
|
|
106
|
+
|
|
107
|
+
:return: Train app slug.
|
|
108
|
+
:rtype: str
|
|
109
|
+
"""
|
|
110
|
+
return self._slug
|
|
111
|
+
|
|
112
|
+
@property
|
|
113
|
+
def serve_app_name(self):
|
|
114
|
+
"""
|
|
115
|
+
Serve application name.
|
|
116
|
+
|
|
117
|
+
:return: The serve application name.
|
|
118
|
+
:rtype: str
|
|
119
|
+
"""
|
|
120
|
+
return self._serve_app_name
|
|
121
|
+
|
|
122
|
+
@property
|
|
123
|
+
def serve_slug(self):
|
|
124
|
+
"""
|
|
125
|
+
Serve app slug.
|
|
126
|
+
|
|
127
|
+
:return: Serve app slug.
|
|
128
|
+
:rtype: str
|
|
129
|
+
"""
|
|
130
|
+
return self._serve_slug
|
|
131
|
+
|
|
97
132
|
@property
|
|
98
133
|
def framework_name(self):
|
|
99
134
|
"""
|
|
@@ -164,6 +199,16 @@ class BaseTrainArtifacts:
|
|
|
164
199
|
"""
|
|
165
200
|
return self._pattern
|
|
166
201
|
|
|
202
|
+
@property
|
|
203
|
+
def require_runtime(self):
|
|
204
|
+
"""
|
|
205
|
+
Whether providing runtime is required for the framework.
|
|
206
|
+
|
|
207
|
+
:return: True if runtime is required, False otherwise.
|
|
208
|
+
:rtype: bool
|
|
209
|
+
"""
|
|
210
|
+
return self._require_runtime
|
|
211
|
+
|
|
167
212
|
def is_valid_artifacts_path(self, path):
|
|
168
213
|
"""
|
|
169
214
|
Check if the provided path is valid and follows specified session path pattern.
|
|
@@ -531,68 +576,68 @@ class BaseTrainArtifacts:
|
|
|
531
576
|
logger.debug(f"Listing time: '{format(end_time - start_time, '.6f')}' sec")
|
|
532
577
|
return train_infos
|
|
533
578
|
|
|
534
|
-
def
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
579
|
+
def convert_train_to_experiment_info(
|
|
580
|
+
self, train_info: TrainInfo
|
|
581
|
+
) -> Union[ExperimentInfo, None]:
|
|
582
|
+
try:
|
|
583
|
+
checkpoints = []
|
|
584
|
+
for chk in train_info.checkpoints:
|
|
585
|
+
if self.weights_folder:
|
|
586
|
+
checkpoints.append(join(self.weights_folder, chk.name))
|
|
587
|
+
else:
|
|
588
|
+
checkpoints.append(chk.name)
|
|
589
|
+
|
|
590
|
+
best_checkpoint = next(
|
|
591
|
+
(chk.name for chk in train_info.checkpoints if "best" in chk.name), None
|
|
592
|
+
)
|
|
593
|
+
if not best_checkpoint and checkpoints:
|
|
594
|
+
best_checkpoint = get_file_name_with_ext(checkpoints[-1])
|
|
595
|
+
|
|
596
|
+
task_info = self._api.task.get_info_by_id(train_info.task_id)
|
|
597
|
+
workspace_id = task_info["workspaceId"]
|
|
598
|
+
|
|
599
|
+
project = self._api.project.get_info_by_name(workspace_id, train_info.project_name)
|
|
600
|
+
project_id = project.id if project else None
|
|
601
|
+
|
|
602
|
+
model_files = {}
|
|
603
|
+
if train_info.config_path:
|
|
604
|
+
model_files["config"] = self.get_config_path(train_info.artifacts_folder).replace(
|
|
605
|
+
train_info.artifacts_folder, ""
|
|
549
606
|
)
|
|
550
|
-
if not best_checkpoint and checkpoints:
|
|
551
|
-
best_checkpoint = get_file_name_with_ext(checkpoints[-1])
|
|
552
|
-
|
|
553
|
-
task_info = api.task.get_info_by_id(train_info.task_id)
|
|
554
|
-
workspace_id = task_info["workspaceId"]
|
|
555
|
-
|
|
556
|
-
project = api.project.get_info_by_name(workspace_id, train_info.project_name)
|
|
557
|
-
project_id = project.id if project else None
|
|
558
|
-
|
|
559
|
-
model_files = {}
|
|
560
|
-
if train_info.config_path:
|
|
561
|
-
model_files["config"] = self.get_config_path(
|
|
562
|
-
train_info.artifacts_folder
|
|
563
|
-
).replace(train_info.artifacts_folder, "")
|
|
564
|
-
|
|
565
|
-
input_datetime = task_info["startedAt"]
|
|
566
|
-
parsed_datetime = datetime.strptime(input_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
|
|
567
|
-
date_time = parsed_datetime.strftime("%Y-%m-%d %H:%M:%S")
|
|
568
|
-
|
|
569
|
-
experiment_info_data = {
|
|
570
|
-
"experiment_name": f"Unknown {self.framework_name} experiment",
|
|
571
|
-
"framework_name": self.framework_name,
|
|
572
|
-
"model_name": f"Unknown {self.framework_name} model",
|
|
573
|
-
"task_type": train_info.task_type,
|
|
574
|
-
"project_id": project_id,
|
|
575
|
-
"task_id": train_info.task_id,
|
|
576
|
-
"model_files": model_files,
|
|
577
|
-
"checkpoints": checkpoints,
|
|
578
|
-
"best_checkpoint": best_checkpoint,
|
|
579
|
-
"artifacts_dir": train_info.artifacts_folder,
|
|
580
|
-
"datetime": date_time,
|
|
581
|
-
}
|
|
582
|
-
|
|
583
|
-
experiment_info_fields = {
|
|
584
|
-
field.name
|
|
585
|
-
for field in ExperimentInfo.__dataclass_fields__.values() # pylint: disable=no-member
|
|
586
|
-
}
|
|
587
|
-
for field in experiment_info_fields:
|
|
588
|
-
if field not in experiment_info_data:
|
|
589
|
-
experiment_info_data[field] = None
|
|
590
|
-
|
|
591
|
-
return ExperimentInfo(**experiment_info_data)
|
|
592
|
-
except Exception as e:
|
|
593
|
-
logger.debug(f"Failed to build experiment info: {e}")
|
|
594
|
-
return None
|
|
595
607
|
|
|
608
|
+
input_datetime = task_info["startedAt"]
|
|
609
|
+
parsed_datetime = datetime.strptime(input_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
|
|
610
|
+
date_time = parsed_datetime.strftime("%Y-%m-%d %H:%M:%S")
|
|
611
|
+
|
|
612
|
+
experiment_info_data = {
|
|
613
|
+
"experiment_name": f"Unknown {self.framework_name} experiment",
|
|
614
|
+
"framework_name": self.framework_name,
|
|
615
|
+
"model_name": f"Unknown {self.framework_name} model",
|
|
616
|
+
"task_type": train_info.task_type,
|
|
617
|
+
"project_id": project_id,
|
|
618
|
+
"task_id": train_info.task_id,
|
|
619
|
+
"model_files": model_files,
|
|
620
|
+
"checkpoints": checkpoints,
|
|
621
|
+
"best_checkpoint": best_checkpoint,
|
|
622
|
+
"artifacts_dir": train_info.artifacts_folder,
|
|
623
|
+
"datetime": date_time,
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
experiment_info_fields = {
|
|
627
|
+
field.name
|
|
628
|
+
for field in ExperimentInfo.__dataclass_fields__.values() # pylint: disable=no-member
|
|
629
|
+
}
|
|
630
|
+
for field in experiment_info_fields:
|
|
631
|
+
if field not in experiment_info_data:
|
|
632
|
+
experiment_info_data[field] = None
|
|
633
|
+
return ExperimentInfo(**experiment_info_data)
|
|
634
|
+
except Exception as e:
|
|
635
|
+
logger.debug(f"Failed to build experiment info: {e}")
|
|
636
|
+
return None
|
|
637
|
+
|
|
638
|
+
def get_list_experiment_info(
|
|
639
|
+
self, sort: Literal["desc", "asc"] = "desc"
|
|
640
|
+
) -> List[ExperimentInfo]:
|
|
596
641
|
train_infos = self.get_list(sort)
|
|
597
642
|
|
|
598
643
|
# Sync version
|
|
@@ -607,7 +652,7 @@ class BaseTrainArtifacts:
|
|
|
607
652
|
with ThreadPoolExecutor() as executor:
|
|
608
653
|
experiment_infos = list(
|
|
609
654
|
executor.map(
|
|
610
|
-
lambda t:
|
|
655
|
+
lambda t: self.convert_train_to_experiment_info(t),
|
|
611
656
|
train_infos,
|
|
612
657
|
)
|
|
613
658
|
)
|
|
@@ -621,3 +666,29 @@ class BaseTrainArtifacts:
|
|
|
621
666
|
:rtype: List[str]
|
|
622
667
|
"""
|
|
623
668
|
return self._available_task_types
|
|
669
|
+
|
|
670
|
+
def get_info_by_artifacts_dir(
|
|
671
|
+
self,
|
|
672
|
+
artifacts_dir: str,
|
|
673
|
+
return_type: Literal["train_info", "experiment_info"] = "train_info",
|
|
674
|
+
) -> Union[TrainInfo, ExperimentInfo, None]:
|
|
675
|
+
"""
|
|
676
|
+
Get training info by artifacts directory.
|
|
677
|
+
|
|
678
|
+
:param artifacts_dir: The artifacts directory.
|
|
679
|
+
:type artifacts_dir: str
|
|
680
|
+
:param return_type: The return type, either "train_info" or "experiment_info". Default is "experiment_info".
|
|
681
|
+
:type return_type: Literal["train_info", "experiment_info"]
|
|
682
|
+
:return: The training info.
|
|
683
|
+
:rtype: TrainInfo
|
|
684
|
+
"""
|
|
685
|
+
for train_info in self.get_list():
|
|
686
|
+
if train_info.artifacts_folder == artifacts_dir:
|
|
687
|
+
if return_type == "train_info":
|
|
688
|
+
return train_info
|
|
689
|
+
else:
|
|
690
|
+
return self.convert_train_to_experiment_info(train_info)
|
|
691
|
+
|
|
692
|
+
# load_custom_checkpoint
|
|
693
|
+
# inference
|
|
694
|
+
# fix docstrings :param: x -> :param x:
|
|
@@ -10,6 +10,11 @@ class Detectron2(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train Detectron2"
|
|
13
|
+
self._slug = "supervisely-ecosystem/detectron2/supervisely/train"
|
|
14
|
+
self._serve_app_name = "Serve Detectron2"
|
|
15
|
+
self._serve_slug = (
|
|
16
|
+
"supervisely-ecosystem/detectron2/supervisely/instance_segmentation/serve"
|
|
17
|
+
)
|
|
13
18
|
self._framework_name = "Detectron2"
|
|
14
19
|
self._framework_folder = "/detectron2"
|
|
15
20
|
self._weights_folder = "checkpoints"
|
|
@@ -19,6 +24,7 @@ class Detectron2(BaseTrainArtifacts):
|
|
|
19
24
|
self._config_file = "model_config.yaml"
|
|
20
25
|
self._pattern = re_compile(r"^/detectron2/\d+_[^/]+/?$")
|
|
21
26
|
self._available_task_types: List[str] = ["instance segmentation"]
|
|
27
|
+
self._require_runtime = False
|
|
22
28
|
|
|
23
29
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
24
30
|
parts = artifacts_folder.split("/")
|
supervisely/nn/artifacts/hrda.py
CHANGED
|
@@ -10,12 +10,16 @@ class HRDA(BaseTrainArtifacts):
|
|
|
10
10
|
raise NotImplementedError
|
|
11
11
|
# super().__init__(team_id)
|
|
12
12
|
# self._app_name = "Train HRDA"
|
|
13
|
+
# self._serve_app_name = None
|
|
14
|
+
# self._slug = None
|
|
15
|
+
# self._serve_slug = None
|
|
13
16
|
# self._framework_folder = "/HRDA"
|
|
14
17
|
# self._weights_folder = None
|
|
15
18
|
# self._task_type = "semantic segmentation"
|
|
16
19
|
# self._weights_ext = ".pth"
|
|
17
20
|
# self._config_file = "config.py"
|
|
18
21
|
# self._available_task_types: List[str] = ["semantic segmentation"]
|
|
22
|
+
# self._require_runtime = False
|
|
19
23
|
|
|
20
24
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
21
25
|
raise NotImplementedError
|
|
@@ -10,6 +10,9 @@ class MMClassification(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train MMClassification"
|
|
13
|
+
self._slug = "supervisely-ecosystem/mmclassification/supervisely/train"
|
|
14
|
+
self._serve_app_name = "Serve MMClassification"
|
|
15
|
+
self._serve_slug = "supervisely-ecosystem/mmclassification/supervisely/serve"
|
|
13
16
|
self._framework_name = "MMClassification"
|
|
14
17
|
self._framework_folder = "/mmclassification"
|
|
15
18
|
self._weights_folder = "checkpoints"
|
|
@@ -17,6 +20,7 @@ class MMClassification(BaseTrainArtifacts):
|
|
|
17
20
|
self._weights_ext = ".pth"
|
|
18
21
|
self._pattern = re_compile(r"^/mmclassification/\d+_[^/]+/?$")
|
|
19
22
|
self._available_task_types: List[str] = ["classification"]
|
|
23
|
+
self._require_runtime = False
|
|
20
24
|
|
|
21
25
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
22
26
|
parts = artifacts_folder.split("/")
|
|
@@ -13,6 +13,9 @@ class MMDetection(BaseTrainArtifacts):
|
|
|
13
13
|
super().__init__(team_id)
|
|
14
14
|
|
|
15
15
|
self._app_name = "Train MMDetection"
|
|
16
|
+
self._slug = "supervisely-ecosystem/mmdetection/train"
|
|
17
|
+
self._serve_app_name = "Serve MMDetection"
|
|
18
|
+
self._serve_slug = "supervisely-ecosystem/mmdetection/serve"
|
|
16
19
|
self._framework_name = "MMDetection"
|
|
17
20
|
self._framework_folder = "/mmdetection"
|
|
18
21
|
self._weights_folder = "checkpoints/data"
|
|
@@ -22,6 +25,7 @@ class MMDetection(BaseTrainArtifacts):
|
|
|
22
25
|
self._config_file = "config.py"
|
|
23
26
|
self._pattern = re_compile(r"^/mmdetection/\d+_[^/]+/?$")
|
|
24
27
|
self._available_task_types: List[str] = ["object detection", "instance segmentation"]
|
|
28
|
+
self._require_runtime = False
|
|
25
29
|
|
|
26
30
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
27
31
|
parts = artifacts_folder.split("/")
|
|
@@ -59,6 +63,9 @@ class MMDetection3(BaseTrainArtifacts):
|
|
|
59
63
|
super().__init__(team_id)
|
|
60
64
|
|
|
61
65
|
self._app_name = "Train MMDetection 3.0"
|
|
66
|
+
self._slug = "Serve MMDetection 3.0"
|
|
67
|
+
self._serve_app_name = "supervisely-ecosystem/train-mmdetection-v3"
|
|
68
|
+
self._serve_slug = "supervisely-ecosystem/serve-mmdetection-v3"
|
|
62
69
|
self._framework_name = "MMDetection 3.0"
|
|
63
70
|
self._framework_folder = "/mmdetection-3"
|
|
64
71
|
self._weights_folder = None
|
|
@@ -67,7 +74,8 @@ class MMDetection3(BaseTrainArtifacts):
|
|
|
67
74
|
self._config_file = "config.py"
|
|
68
75
|
self._pattern = re_compile(r"^/mmdetection-3/\d+_[^/]+/?$")
|
|
69
76
|
self._available_task_types: List[str] = ["object detection", "instance segmentation"]
|
|
70
|
-
|
|
77
|
+
self._require_runtime = False
|
|
78
|
+
|
|
71
79
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
72
80
|
parts = artifacts_folder.split("/")
|
|
73
81
|
if len(parts) < 3:
|
|
@@ -10,6 +10,9 @@ class MMSegmentation(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train MMSegmentation"
|
|
13
|
+
self._slug = "supervisely-ecosystem/mmsegmentation/train"
|
|
14
|
+
self._serve_app_name = "Serve MMSegmentation"
|
|
15
|
+
self._serve_slug = "supervisely-ecosystem/mmsegmentation/serve"
|
|
13
16
|
self._framework_name = "MMSegmentation"
|
|
14
17
|
self._framework_folder = "/mmsegmentation"
|
|
15
18
|
self._weights_folder = "checkpoints/data"
|
|
@@ -18,6 +21,7 @@ class MMSegmentation(BaseTrainArtifacts):
|
|
|
18
21
|
self._config_file = "config.py"
|
|
19
22
|
self._pattern = re_compile(r"^/mmsegmentation/\d+_[^/]+/?$")
|
|
20
23
|
self._available_task_types: List[str] = ["instance segmentation"]
|
|
24
|
+
self._require_runtime = False
|
|
21
25
|
|
|
22
26
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
23
27
|
return artifacts_folder.split("/")[2].split("_")[0]
|
supervisely/nn/artifacts/ritm.py
CHANGED
|
@@ -10,6 +10,9 @@ class RITM(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train RITM"
|
|
13
|
+
self._slug = "supervisely-ecosystem/ritm-training/supervisely/train"
|
|
14
|
+
self._serve_app_name = None
|
|
15
|
+
self._serve_slug = None
|
|
13
16
|
self._framework_name = "RITM"
|
|
14
17
|
self._framework_folder = "/RITM_training"
|
|
15
18
|
self._weights_folder = "checkpoints"
|
|
@@ -18,6 +21,7 @@ class RITM(BaseTrainArtifacts):
|
|
|
18
21
|
self._weights_ext = ".pth"
|
|
19
22
|
self._pattern = re_compile(r"^/RITM_training/\d+_[^/]+/?$")
|
|
20
23
|
self._available_task_types: List[str] = ["interactive segmentation"]
|
|
24
|
+
self._require_runtime = False
|
|
21
25
|
|
|
22
26
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
23
27
|
parts = artifacts_folder.split("/")
|
|
@@ -10,6 +10,9 @@ class RTDETR(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train RT-DETR"
|
|
13
|
+
self._slug = "supervisely-ecosystem/rt-detr/supervisely_integration/train"
|
|
14
|
+
self._serve_app_name = "Serve RT-DETR"
|
|
15
|
+
self._serve_slug = "supervisely-ecosystem/rt-detr/supervisely_integration/serve"
|
|
13
16
|
self._framework_name = "RT-DETR"
|
|
14
17
|
self._framework_folder = "/RT-DETR"
|
|
15
18
|
self._weights_folder = "weights"
|
|
@@ -18,6 +21,7 @@ class RTDETR(BaseTrainArtifacts):
|
|
|
18
21
|
self._config_file = "config.yml"
|
|
19
22
|
self._pattern = re_compile(r"^/RT-DETR/[^/]+/\d+/?$")
|
|
20
23
|
self._available_task_types: List[str] = ["object detection"]
|
|
24
|
+
self._require_runtime = False
|
|
21
25
|
|
|
22
26
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
23
27
|
return artifacts_folder.split("/")[-1]
|
supervisely/nn/artifacts/unet.py
CHANGED
|
@@ -10,6 +10,9 @@ class UNet(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train UNet"
|
|
13
|
+
self._slug = "supervisely-ecosystem/unet/supervisely/train"
|
|
14
|
+
self._serve_app_name = "Serve UNet"
|
|
15
|
+
self._serve_slug = "supervisely-ecosystem/unet/supervisely/serve"
|
|
13
16
|
self._framework_name = "UNet"
|
|
14
17
|
self._framework_folder = "/unet"
|
|
15
18
|
self._weights_folder = "checkpoints"
|
|
@@ -18,6 +21,7 @@ class UNet(BaseTrainArtifacts):
|
|
|
18
21
|
self._config_file = "train_args.json"
|
|
19
22
|
self._pattern = re_compile(r"^/unet/\d+_[^/]+/?$")
|
|
20
23
|
self._available_task_types: List[str] = ["semantic segmentation"]
|
|
24
|
+
self._require_runtime = False
|
|
21
25
|
|
|
22
26
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
23
27
|
parts = artifacts_folder.split("/")
|
|
@@ -10,6 +10,9 @@ class YOLOv5(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train YOLOv5"
|
|
13
|
+
self._slug = "supervisely-ecosystem/yolov5/supervisely/train"
|
|
14
|
+
self._serve_app_name = "Serve YOLOv5"
|
|
15
|
+
self._serve_slug = "supervisely-ecosystem/yolov5/supervisely/serve"
|
|
13
16
|
self._framework_name = "YOLOv5"
|
|
14
17
|
self._framework_folder = "/yolov5_train"
|
|
15
18
|
self._weights_folder = "weights"
|
|
@@ -18,6 +21,7 @@ class YOLOv5(BaseTrainArtifacts):
|
|
|
18
21
|
self._config_file = None
|
|
19
22
|
self._pattern = re_compile(r"^/yolov5_train/[^/]+/\d+/?$")
|
|
20
23
|
self._available_task_types: List[str] = ["object detection"]
|
|
24
|
+
self._require_runtime = False
|
|
21
25
|
|
|
22
26
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
23
27
|
return artifacts_folder.split("/")[-1]
|
|
@@ -40,6 +44,9 @@ class YOLOv5v2(YOLOv5):
|
|
|
40
44
|
super().__init__(team_id)
|
|
41
45
|
|
|
42
46
|
self._app_name = "Train YOLOv5 2.0"
|
|
47
|
+
self._slug = "supervisely-ecosystem/yolov5_2.0/train"
|
|
48
|
+
self._serve_app_name = "Serve YOLOv5 2.0"
|
|
49
|
+
self._serve_slug = "supervisely-ecosystem/yolov5_2.0/serve"
|
|
43
50
|
self._framework_name = "YOLOv5 2.0"
|
|
44
51
|
self._framework_folder = "/yolov5_2.0_train"
|
|
45
52
|
self._weights_folder = "weights"
|
|
@@ -10,7 +10,10 @@ class YOLOv8(BaseTrainArtifacts):
|
|
|
10
10
|
super().__init__(team_id)
|
|
11
11
|
|
|
12
12
|
self._app_name = "Train YOLOv8 | v9 | v10 | v11"
|
|
13
|
-
self.
|
|
13
|
+
self._slug = "supervisely-ecosystem/yolov8/train"
|
|
14
|
+
self._serve_app_name = "Serve YOLOv8 | v9 | v10 | v11"
|
|
15
|
+
self._serve_slug = "supervisely-ecosystem/yolov8/serve"
|
|
16
|
+
self._framework_name = "YOLOv8"
|
|
14
17
|
self._framework_folder = "/yolov8_train"
|
|
15
18
|
self._weights_folder = "weights"
|
|
16
19
|
self._task_type = None
|
|
@@ -24,6 +27,7 @@ class YOLOv8(BaseTrainArtifacts):
|
|
|
24
27
|
"instance segmentation",
|
|
25
28
|
"pose estimation",
|
|
26
29
|
]
|
|
30
|
+
self._require_runtime = True
|
|
27
31
|
|
|
28
32
|
def get_task_id(self, artifacts_folder: str) -> str:
|
|
29
33
|
parts = artifacts_folder.split("/")
|
supervisely/nn/experiments.py
CHANGED
|
@@ -2,13 +2,15 @@ from concurrent.futures import ThreadPoolExecutor
|
|
|
2
2
|
from dataclasses import dataclass, fields
|
|
3
3
|
from json import JSONDecodeError
|
|
4
4
|
from os.path import dirname, join
|
|
5
|
-
from typing import List, Optional
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
6
|
|
|
7
7
|
import requests
|
|
8
8
|
|
|
9
9
|
from supervisely import logger
|
|
10
10
|
from supervisely.api.api import Api, ApiField
|
|
11
11
|
|
|
12
|
+
EXPERIMENT_INFO_FILENAME = "experiment_info.json"
|
|
13
|
+
|
|
12
14
|
|
|
13
15
|
@dataclass
|
|
14
16
|
class ExperimentInfo:
|
|
@@ -78,7 +80,7 @@ def get_experiment_infos(api: Api, team_id: int, framework_name: str) -> List[Ex
|
|
|
78
80
|
|
|
79
81
|
api = sly.Api.from_env()
|
|
80
82
|
team_id = sly.env.team_id()
|
|
81
|
-
framework_name = "
|
|
83
|
+
framework_name = "RT-DETRv2"
|
|
82
84
|
experiment_infos = sly.nn.training.experiments.get_experiment_infos(api, team_id, framework_name)
|
|
83
85
|
"""
|
|
84
86
|
metadata_name = "experiment_info.json"
|
|
@@ -140,3 +142,84 @@ def get_experiment_infos(api: Api, team_id: int, framework_name: str) -> List[Ex
|
|
|
140
142
|
|
|
141
143
|
experiment_infos = [info for info in experiment_infos if info is not None]
|
|
142
144
|
return experiment_infos
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _fetch_experiment_data(api, team_id: int, experiment_path: str) -> Union[ExperimentInfo, None]:
|
|
148
|
+
"""
|
|
149
|
+
Fetch experiment data from the specified path in Supervisely Team Files
|
|
150
|
+
|
|
151
|
+
:param api: Supervisely API client
|
|
152
|
+
:type api: Api
|
|
153
|
+
:param team_id: Team ID
|
|
154
|
+
:type team_id: int
|
|
155
|
+
:param experiment_path: Path to the experiment data
|
|
156
|
+
:type experiment_path: str
|
|
157
|
+
:return: ExperimentInfo object
|
|
158
|
+
:rtype: Union[ExperimentInfo, None]
|
|
159
|
+
"""
|
|
160
|
+
try:
|
|
161
|
+
response = api.post(
|
|
162
|
+
"file-storage.download",
|
|
163
|
+
{ApiField.TEAM_ID: team_id, ApiField.PATH: experiment_path},
|
|
164
|
+
stream=True,
|
|
165
|
+
)
|
|
166
|
+
response.raise_for_status()
|
|
167
|
+
response_json = response.json()
|
|
168
|
+
required_fields = {
|
|
169
|
+
field.name for field in fields(ExperimentInfo) if field.default is not None
|
|
170
|
+
}
|
|
171
|
+
optional_fields = {field.name for field in fields(ExperimentInfo) if field.default is None}
|
|
172
|
+
|
|
173
|
+
missing_optional_fields = optional_fields - response_json.keys()
|
|
174
|
+
if missing_optional_fields:
|
|
175
|
+
logger.debug(
|
|
176
|
+
f"Missing optional fields: {missing_optional_fields} for '{experiment_path}'"
|
|
177
|
+
)
|
|
178
|
+
for field in missing_optional_fields:
|
|
179
|
+
response_json[field] = None
|
|
180
|
+
|
|
181
|
+
missing_required_fields = required_fields - response_json.keys()
|
|
182
|
+
if missing_required_fields:
|
|
183
|
+
logger.debug(
|
|
184
|
+
f"Missing required fields: {missing_required_fields} for '{experiment_path}'. Skipping."
|
|
185
|
+
)
|
|
186
|
+
return None
|
|
187
|
+
return ExperimentInfo(**{k: v for k, v in response_json.items() if k in required_fields})
|
|
188
|
+
except requests.exceptions.RequestException as e:
|
|
189
|
+
logger.debug(f"Request failed for '{experiment_path}': {e}")
|
|
190
|
+
except JSONDecodeError as e:
|
|
191
|
+
logger.debug(f"JSON decode failed for '{experiment_path}': {e}")
|
|
192
|
+
except TypeError as e:
|
|
193
|
+
logger.error(f"TypeError for '{experiment_path}': {e}")
|
|
194
|
+
return None
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def get_experiment_info_by_artifacts_dir(
|
|
198
|
+
api: Api, team_id: int, artifacts_dir: str
|
|
199
|
+
) -> Union[ExperimentInfo, None]:
|
|
200
|
+
"""
|
|
201
|
+
Get experiment info by artifacts directory
|
|
202
|
+
|
|
203
|
+
:param api: Supervisely API client
|
|
204
|
+
:type api: Api
|
|
205
|
+
:param team_id: Team ID
|
|
206
|
+
:type team_id: int
|
|
207
|
+
:param artifacts_dir: Path to the directory with artifacts
|
|
208
|
+
:type artifacts_dir: str
|
|
209
|
+
:return: ExperimentInfo object
|
|
210
|
+
:rtype: Optional[ExperimentInfo]
|
|
211
|
+
:Usage example:
|
|
212
|
+
|
|
213
|
+
.. code-block:: python
|
|
214
|
+
|
|
215
|
+
import supervisely as sly
|
|
216
|
+
|
|
217
|
+
api = sly.Api.from_env()
|
|
218
|
+
team_id = sly.env.team_id()
|
|
219
|
+
artifacts_dir = "/experiments/27_Lemons (Rectangle)/265_RT-DETRv2/"
|
|
220
|
+
experiment_info = sly.nn.training.experiments.get_experiment_info_by_artifacts_dir(api, team_id, artifacts_dir)
|
|
221
|
+
"""
|
|
222
|
+
if not artifacts_dir.startswith("/experiments"):
|
|
223
|
+
raise ValueError("Artifacts directory should start with '/experiments'")
|
|
224
|
+
experiment_path = join(artifacts_dir, EXPERIMENT_INFO_FILENAME)
|
|
225
|
+
return _fetch_experiment_data(api, team_id, experiment_path)
|
|
@@ -670,16 +670,20 @@ class Inference:
|
|
|
670
670
|
self.update_gui(self._model_served)
|
|
671
671
|
self.gui.show_deployed_model_info(self)
|
|
672
672
|
|
|
673
|
-
def load_custom_checkpoint(
|
|
673
|
+
def load_custom_checkpoint(
|
|
674
|
+
self, model_files: dict, model_meta: dict, device: str = "cuda", **kwargs
|
|
675
|
+
):
|
|
674
676
|
"""
|
|
675
677
|
Loads local custom model checkpoint.
|
|
676
678
|
|
|
677
|
-
:param: model_files: dict with paths to model files
|
|
679
|
+
:param: model_files: dict with local paths to model files
|
|
678
680
|
:type: model_files: dict
|
|
679
681
|
:param: model_meta: dict with model meta
|
|
680
682
|
:type: model_meta: dict
|
|
681
683
|
:param: device: device to load model on
|
|
682
684
|
:type: device: str
|
|
685
|
+
:param: kwargs: additional parameters will be passed to load_model method.
|
|
686
|
+
:type: kwargs: dict
|
|
683
687
|
:return: None
|
|
684
688
|
:rtype: None
|
|
685
689
|
|
|
@@ -717,6 +721,9 @@ class Inference:
|
|
|
717
721
|
"device": device,
|
|
718
722
|
"runtime": RuntimeType.PYTORCH,
|
|
719
723
|
}
|
|
724
|
+
deploy_params.update(kwargs)
|
|
725
|
+
|
|
726
|
+
# TODO: add support for **kwargs (user arguments)
|
|
720
727
|
self._set_model_meta_custom_model({"model_meta": model_meta})
|
|
721
728
|
self._load_model(deploy_params)
|
|
722
729
|
|
|
@@ -1013,7 +1020,7 @@ class Inference:
|
|
|
1013
1020
|
self,
|
|
1014
1021
|
source: Union[str, int, np.ndarray, List[str], List[int], List[np.ndarray]],
|
|
1015
1022
|
settings: dict = None,
|
|
1016
|
-
) -> Union[Annotation, List[Annotation]
|
|
1023
|
+
) -> Union[Annotation, List[Annotation]]:
|
|
1017
1024
|
"""
|
|
1018
1025
|
Inference method for images. Provide image path or numpy array of image.
|
|
1019
1026
|
|
|
@@ -1022,7 +1029,7 @@ class Inference:
|
|
|
1022
1029
|
:param: settings: inference settings
|
|
1023
1030
|
:type: settings: dict
|
|
1024
1031
|
:return: annotation or list of annotations
|
|
1025
|
-
:rtype: Union[Annotation, List[Annotation]
|
|
1032
|
+
:rtype: Union[Annotation, List[Annotation]]
|
|
1026
1033
|
|
|
1027
1034
|
:Usage Example:
|
|
1028
1035
|
|
|
@@ -697,7 +697,7 @@ class TrainApp:
|
|
|
697
697
|
if model_files is None:
|
|
698
698
|
raise ValueError(
|
|
699
699
|
"Model files not found in model metadata. "
|
|
700
|
-
"Please update provided models
|
|
700
|
+
"Please update provided models parameter to include key 'model_files' in 'meta' key."
|
|
701
701
|
)
|
|
702
702
|
return models
|
|
703
703
|
|
|
@@ -42,7 +42,7 @@ supervisely/api/remote_storage_api.py,sha256=qTuPhPsstgEjRm1g-ZInddik8BNC_38YvBB
|
|
|
42
42
|
supervisely/api/report_api.py,sha256=Om7CGulUbQ4BuJ16eDtz7luLe0JQNqab-LoLpUXu7YE,7123
|
|
43
43
|
supervisely/api/role_api.py,sha256=aBL4mxtn08LDPXQuS153-lQFN6N2kcwiz8MbescZ8Gk,3044
|
|
44
44
|
supervisely/api/storage_api.py,sha256=FPGYf3Rn3LBoe38RBNdoiURs306oshzvKOEOQ56XAbs,13030
|
|
45
|
-
supervisely/api/task_api.py,sha256=
|
|
45
|
+
supervisely/api/task_api.py,sha256=tRhHEvPQQa68nCrENNlgSCruVYDWYLdvvLD-1MGcicw,53632
|
|
46
46
|
supervisely/api/team_api.py,sha256=bEoz3mrykvliLhKnzEy52vzdd_H8VBJCpxF-Bnek9Q8,19467
|
|
47
47
|
supervisely/api/user_api.py,sha256=4S97yIc6AMTZCa0N57lzETnpIE8CeqClvCb6kjUkgfc,24940
|
|
48
48
|
supervisely/api/video_annotation_tool_api.py,sha256=3A9-U8WJzrTShP_n9T8U01M9FzGYdeS51CCBTzUnooo,6686
|
|
@@ -730,23 +730,23 @@ supervisely/metric/pixel_accuracy.py,sha256=qjtxInOTkGDwPeLUnjBdzOrVRT3V6kGGOWjB
|
|
|
730
730
|
supervisely/metric/precision_recall_metric.py,sha256=4AQCkcB84mpYQS94yJ-wkG1LBuXlQf3X_tI9f67vtR8,3426
|
|
731
731
|
supervisely/metric/projects_applier.py,sha256=ORtgLQHYtNi4KYsSGaGPPWiZPexTJF9IWqX_RuLRxPk,3415
|
|
732
732
|
supervisely/nn/__init__.py,sha256=w2gZ6pCreaTYyhZV8PZrYctqmCu_7sbJ-WGegs7mouw,570
|
|
733
|
-
supervisely/nn/experiments.py,sha256=
|
|
733
|
+
supervisely/nn/experiments.py,sha256=0RFYp-LZTRny9tsyeVV58GKtPWngxTw54abfrk3052g,8742
|
|
734
734
|
supervisely/nn/prediction_dto.py,sha256=8QQE6h_feOf3bjWtyG_PoU8FIQrr4g8PoMOyoscmqJM,1697
|
|
735
735
|
supervisely/nn/task_type.py,sha256=UJvSJ4L3I08j_e6sU6Ptu7kS5p1H09rfhfoDUSZ2iys,522
|
|
736
736
|
supervisely/nn/utils.py,sha256=-Xjv5KLu8CTtyi7acqsIX1E0dDwKZPED4D6b4Z_Ln3k,1451
|
|
737
737
|
supervisely/nn/artifacts/__init__.py,sha256=m7KYTMzEJnoV9wcU_0xzgLuPz69Dqp9va0fP32tohV4,576
|
|
738
|
-
supervisely/nn/artifacts/artifacts.py,sha256=
|
|
739
|
-
supervisely/nn/artifacts/detectron2.py,sha256=
|
|
740
|
-
supervisely/nn/artifacts/hrda.py,sha256=
|
|
741
|
-
supervisely/nn/artifacts/mmclassification.py,sha256=
|
|
742
|
-
supervisely/nn/artifacts/mmdetection.py,sha256=
|
|
743
|
-
supervisely/nn/artifacts/mmsegmentation.py,sha256=
|
|
744
|
-
supervisely/nn/artifacts/ritm.py,sha256=
|
|
745
|
-
supervisely/nn/artifacts/rtdetr.py,sha256=
|
|
746
|
-
supervisely/nn/artifacts/unet.py,sha256=
|
|
738
|
+
supervisely/nn/artifacts/artifacts.py,sha256=Ol6Tt3CHGbGm_7rR3iClokOArUj6z4ky92YKozpewRM,22859
|
|
739
|
+
supervisely/nn/artifacts/detectron2.py,sha256=g2F47GS1LryWng1zMAXW5ZLnz0fcRuYAY3sX6LcuHUs,1961
|
|
740
|
+
supervisely/nn/artifacts/hrda.py,sha256=m671dTq10X6IboUL3xed9yuYzH-HP8KDjYAD8-a9CDI,1213
|
|
741
|
+
supervisely/nn/artifacts/mmclassification.py,sha256=Lm-M1OD8-Qjr8xzajrOh53sFGtL-DSI8uSAxwPJJQxQ,1787
|
|
742
|
+
supervisely/nn/artifacts/mmdetection.py,sha256=Pmulm3Ppw6uTUr_TNDCVYFgZXTNS2-69PDR--xTVzcI,5272
|
|
743
|
+
supervisely/nn/artifacts/mmsegmentation.py,sha256=9yls1PUwcTHkeNvdDw4ZclQZxbNW8DYVi5s_yFVvWA0,1561
|
|
744
|
+
supervisely/nn/artifacts/ritm.py,sha256=rnZ8-cWzqYf-cqukdVN0VJhsPG4gXX3KeRODaIZX2Q4,2152
|
|
745
|
+
supervisely/nn/artifacts/rtdetr.py,sha256=lFyWyGMH0jfaUGUIqFGQVSVnG-MkC7u3piAAXyoh99M,1486
|
|
746
|
+
supervisely/nn/artifacts/unet.py,sha256=GmZ947o5Mdys3rqQN8gC7dhGOl1Uo2zWZ_bH9b9FcVI,1810
|
|
747
747
|
supervisely/nn/artifacts/utils.py,sha256=C4EaMi95MAwtK5TOnhK4sQ1BWvgwYBxXyRStkhYrYv8,1356
|
|
748
|
-
supervisely/nn/artifacts/yolov5.py,sha256=
|
|
749
|
-
supervisely/nn/artifacts/yolov8.py,sha256=
|
|
748
|
+
supervisely/nn/artifacts/yolov5.py,sha256=AVRbUSY4gxRz-yXhuP80oDZbtNElblun2Ie6eTOYU5g,2134
|
|
749
|
+
supervisely/nn/artifacts/yolov8.py,sha256=wLefz1CTGghZ8x61L5Oa0pe9nOetZzt4YN8yB9v_Dh8,1667
|
|
750
750
|
supervisely/nn/benchmark/__init__.py,sha256=7jDezvavJFtO9mDeB2TqW8N4sD8TsHQBPpA9RESleIQ,610
|
|
751
751
|
supervisely/nn/benchmark/base_benchmark.py,sha256=SltD3T2Nbo3h0RLfPShWmGrpHJOzSmF7IK6G4unXm4c,26055
|
|
752
752
|
supervisely/nn/benchmark/base_evaluator.py,sha256=MJeZnMcWr_cbeJ2r0GJ4SWgjWX5w33Y3pYVR6kCIQMQ,5246
|
|
@@ -875,7 +875,7 @@ supervisely/nn/benchmark/visualization/widgets/table/__init__.py,sha256=47DEQpj8
|
|
|
875
875
|
supervisely/nn/benchmark/visualization/widgets/table/table.py,sha256=atmDnF1Af6qLQBUjLhK18RMDKAYlxnsuVHMSEa5a-e8,4319
|
|
876
876
|
supervisely/nn/inference/__init__.py,sha256=mtEci4Puu-fRXDnGn8RP47o97rv3VTE0hjbYO34Zwqg,1622
|
|
877
877
|
supervisely/nn/inference/cache.py,sha256=h-pP_7th0ana3oJ75sFfTbead3hdKUvYA8Iq2OXDx3I,31317
|
|
878
|
-
supervisely/nn/inference/inference.py,sha256
|
|
878
|
+
supervisely/nn/inference/inference.py,sha256=-RXgGv9QUI-hJk1fVDYZVNguLlvSHjCDtElwSIz21Ow,148340
|
|
879
879
|
supervisely/nn/inference/session.py,sha256=jmkkxbe2kH-lEgUU6Afh62jP68dxfhF5v6OGDfLU62E,35757
|
|
880
880
|
supervisely/nn/inference/video_inference.py,sha256=8Bshjr6rDyLay5Za8IB8Dr6FURMO2R_v7aELasO8pR4,5746
|
|
881
881
|
supervisely/nn/inference/gui/__init__.py,sha256=wCxd-lF5Zhcwsis-wScDA8n1Gk_1O00PKgDviUZ3F1U,221
|
|
@@ -972,7 +972,7 @@ supervisely/nn/tracker/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
|
|
|
972
972
|
supervisely/nn/tracker/utils/gmc.py,sha256=3JX8979H3NA-YHNaRQyj9Z-xb9qtyMittPEjGw8y2Jo,11557
|
|
973
973
|
supervisely/nn/tracker/utils/kalman_filter.py,sha256=eSFmCjM0mikHCAFvj-KCVzw-0Jxpoc3Cfc2NWEjJC1Q,17268
|
|
974
974
|
supervisely/nn/training/__init__.py,sha256=gY4PCykJ-42MWKsqb9kl-skemKa8yB6t_fb5kzqR66U,111
|
|
975
|
-
supervisely/nn/training/train_app.py,sha256=
|
|
975
|
+
supervisely/nn/training/train_app.py,sha256=oFK1lGNmFAWvSel7nxYd-YA54gA0XJXnaZze1B3pqbg,103947
|
|
976
976
|
supervisely/nn/training/gui/__init__.py,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
|
|
977
977
|
supervisely/nn/training/gui/classes_selector.py,sha256=8UgzA4aogOAr1s42smwEcDbgaBj_i0JLhjwlZ9bFdIA,3772
|
|
978
978
|
supervisely/nn/training/gui/gui.py,sha256=CnT_QhihrxdSHKybpI0pXhPLwCaXEana_qdn0DhXByg,25558
|
|
@@ -1074,9 +1074,9 @@ supervisely/worker_proto/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
|
|
|
1074
1074
|
supervisely/worker_proto/worker_api_pb2.py,sha256=VQfi5JRBHs2pFCK1snec3JECgGnua3Xjqw_-b3aFxuM,59142
|
|
1075
1075
|
supervisely/worker_proto/worker_api_pb2_grpc.py,sha256=3BwQXOaP9qpdi0Dt9EKG--Lm8KGN0C5AgmUfRv77_Jk,28940
|
|
1076
1076
|
supervisely_lib/__init__.py,sha256=7-3QnN8Zf0wj8NCr2oJmqoQWMKKPKTECvjH9pd2S5vY,159
|
|
1077
|
-
supervisely-6.73.
|
|
1078
|
-
supervisely-6.73.
|
|
1079
|
-
supervisely-6.73.
|
|
1080
|
-
supervisely-6.73.
|
|
1081
|
-
supervisely-6.73.
|
|
1082
|
-
supervisely-6.73.
|
|
1077
|
+
supervisely-6.73.304.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
1078
|
+
supervisely-6.73.304.dist-info/METADATA,sha256=5JZrJ5bK3MX8Ao9Aib0L9MaK9HipNOi0gGvq6P568Vc,33573
|
|
1079
|
+
supervisely-6.73.304.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
1080
|
+
supervisely-6.73.304.dist-info/entry_points.txt,sha256=U96-5Hxrp2ApRjnCoUiUhWMqijqh8zLR03sEhWtAcms,102
|
|
1081
|
+
supervisely-6.73.304.dist-info/top_level.txt,sha256=kcFVwb7SXtfqZifrZaSE3owHExX4gcNYe7Q2uoby084,28
|
|
1082
|
+
supervisely-6.73.304.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|