supervisely 6.73.356__py3-none-any.whl → 6.73.358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/_utils.py +12 -0
- supervisely/api/annotation_api.py +3 -0
- supervisely/api/api.py +2 -2
- supervisely/api/app_api.py +27 -2
- supervisely/api/entity_annotation/tag_api.py +0 -1
- supervisely/api/labeling_job_api.py +4 -1
- supervisely/api/nn/__init__.py +0 -0
- supervisely/api/nn/deploy_api.py +821 -0
- supervisely/api/nn/neural_network_api.py +248 -0
- supervisely/api/task_api.py +26 -467
- supervisely/app/fastapi/subapp.py +1 -0
- supervisely/nn/__init__.py +2 -1
- supervisely/nn/artifacts/artifacts.py +5 -5
- supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
- supervisely/nn/experiments.py +28 -5
- supervisely/nn/inference/cache.py +178 -114
- supervisely/nn/inference/gui/gui.py +18 -35
- supervisely/nn/inference/gui/serving_gui.py +3 -1
- supervisely/nn/inference/inference.py +1421 -1265
- supervisely/nn/inference/inference_request.py +412 -0
- supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
- supervisely/nn/inference/session.py +2 -2
- supervisely/nn/inference/tracking/base_tracking.py +45 -79
- supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
- supervisely/nn/inference/tracking/mask_tracking.py +274 -250
- supervisely/nn/inference/tracking/tracker_interface.py +23 -0
- supervisely/nn/inference/uploader.py +164 -0
- supervisely/nn/model/__init__.py +0 -0
- supervisely/nn/model/model_api.py +259 -0
- supervisely/nn/model/prediction.py +311 -0
- supervisely/nn/model/prediction_session.py +632 -0
- supervisely/nn/tracking/__init__.py +1 -0
- supervisely/nn/tracking/boxmot.py +114 -0
- supervisely/nn/tracking/tracking.py +24 -0
- supervisely/nn/training/train_app.py +61 -19
- supervisely/nn/utils.py +43 -3
- supervisely/task/progress.py +12 -2
- supervisely/video/video.py +107 -1
- supervisely/volume_annotation/volume_figure.py +8 -2
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/METADATA +2 -1
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/RECORD +45 -34
- supervisely/api/neural_network_api.py +0 -202
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/LICENSE +0 -0
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/WHEEL +0 -0
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/top_level.txt +0 -0
|
@@ -1,21 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import argparse
|
|
2
4
|
import asyncio
|
|
3
5
|
import inspect
|
|
4
6
|
import json
|
|
5
7
|
import os
|
|
6
8
|
import re
|
|
9
|
+
import shutil
|
|
7
10
|
import subprocess
|
|
8
|
-
import
|
|
11
|
+
import tempfile
|
|
9
12
|
import threading
|
|
10
13
|
import time
|
|
11
|
-
import uuid
|
|
12
14
|
from collections import OrderedDict, defaultdict
|
|
13
15
|
from concurrent.futures import ThreadPoolExecutor
|
|
14
16
|
from dataclasses import asdict, dataclass
|
|
15
17
|
from functools import partial, wraps
|
|
16
18
|
from pathlib import Path
|
|
17
|
-
from
|
|
18
|
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
19
|
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|
19
20
|
from urllib.request import urlopen
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
@@ -25,6 +26,7 @@ import yaml
|
|
|
25
26
|
from fastapi import Form, HTTPException, Request, Response, UploadFile, status
|
|
26
27
|
from fastapi.responses import JSONResponse
|
|
27
28
|
from requests.structures import CaseInsensitiveDict
|
|
29
|
+
from tqdm import tqdm
|
|
28
30
|
|
|
29
31
|
import supervisely.app.development as sly_app_development
|
|
30
32
|
import supervisely.imaging.image as sly_image
|
|
@@ -32,7 +34,7 @@ import supervisely.io.env as sly_env
|
|
|
32
34
|
import supervisely.io.fs as sly_fs
|
|
33
35
|
import supervisely.io.json as sly_json
|
|
34
36
|
import supervisely.nn.inference.gui as GUI
|
|
35
|
-
from supervisely import DatasetInfo,
|
|
37
|
+
from supervisely import DatasetInfo, batched
|
|
36
38
|
from supervisely._utils import (
|
|
37
39
|
add_callback,
|
|
38
40
|
get_filename_from_headers,
|
|
@@ -49,8 +51,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
|
|
|
49
51
|
from supervisely.api.api import Api, ApiField
|
|
50
52
|
from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
|
|
51
53
|
from supervisely.api.image_api import ImageInfo
|
|
52
|
-
from supervisely.app.content import
|
|
53
|
-
from supervisely.app.exceptions import DialogWindowError
|
|
54
|
+
from supervisely.app.content import get_data_dir
|
|
54
55
|
from supervisely.app.fastapi.subapp import (
|
|
55
56
|
Application,
|
|
56
57
|
call_on_autostart,
|
|
@@ -68,7 +69,13 @@ from supervisely.geometry.any_geometry import AnyGeometry
|
|
|
68
69
|
from supervisely.imaging.color import get_predefined_colors
|
|
69
70
|
from supervisely.io.fs import list_files
|
|
70
71
|
from supervisely.nn.inference.cache import InferenceImageCache
|
|
71
|
-
from supervisely.nn.
|
|
72
|
+
from supervisely.nn.inference.inference_request import (
|
|
73
|
+
InferenceRequest,
|
|
74
|
+
InferenceRequestsManager,
|
|
75
|
+
)
|
|
76
|
+
from supervisely.nn.inference.uploader import Uploader
|
|
77
|
+
from supervisely.nn.model.model_api import Prediction
|
|
78
|
+
from supervisely.nn.prediction_dto import Prediction as PredictionDTO
|
|
72
79
|
from supervisely.nn.utils import (
|
|
73
80
|
CheckpointInfo,
|
|
74
81
|
DeployInfo,
|
|
@@ -76,13 +83,15 @@ from supervisely.nn.utils import (
|
|
|
76
83
|
ModelSource,
|
|
77
84
|
RuntimeType,
|
|
78
85
|
_get_model_name,
|
|
86
|
+
get_gpu_usage,
|
|
87
|
+
get_ram_usage,
|
|
79
88
|
)
|
|
80
89
|
from supervisely.project import ProjectType
|
|
81
90
|
from supervisely.project.download import download_to_cache, read_from_cached_project
|
|
82
91
|
from supervisely.project.project_meta import ProjectMeta
|
|
83
92
|
from supervisely.sly_logger import logger
|
|
84
93
|
from supervisely.task.progress import Progress
|
|
85
|
-
from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS
|
|
94
|
+
from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
|
|
86
95
|
|
|
87
96
|
try:
|
|
88
97
|
from typing import Literal
|
|
@@ -283,6 +292,8 @@ class Inference:
|
|
|
283
292
|
log_progress=True,
|
|
284
293
|
)
|
|
285
294
|
|
|
295
|
+
self.inference_requests_manager = InferenceRequestsManager(executor=self._executor)
|
|
296
|
+
|
|
286
297
|
def get_batch_size(self):
|
|
287
298
|
if self.max_batch_size is not None:
|
|
288
299
|
return min(self.DEFAULT_BATCH_SIZE, self.max_batch_size)
|
|
@@ -595,10 +606,55 @@ class Inference:
|
|
|
595
606
|
def _checkpoints_cache_dir(self):
|
|
596
607
|
return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
|
|
597
608
|
|
|
609
|
+
def _build_deploy_params_from_api(self, model_name: str, deploy_params: dict = None) -> dict:
|
|
610
|
+
if deploy_params is None:
|
|
611
|
+
deploy_params = {}
|
|
612
|
+
selected_model = None
|
|
613
|
+
for model in self.pretrained_models:
|
|
614
|
+
if model["meta"]["model_name"].lower() == model_name.lower():
|
|
615
|
+
selected_model = model
|
|
616
|
+
break
|
|
617
|
+
if selected_model is None:
|
|
618
|
+
raise ValueError(f"Model {model_name} not found in models.json of serving app")
|
|
619
|
+
deploy_params["model_files"] = selected_model["meta"]["model_files"]
|
|
620
|
+
deploy_params["model_info"] = selected_model
|
|
621
|
+
return deploy_params
|
|
622
|
+
|
|
623
|
+
def _build_legacy_deploy_params_from_api(self, model_name: str) -> dict:
|
|
624
|
+
selected_model = None
|
|
625
|
+
if hasattr(self, "pretrained_models_table"):
|
|
626
|
+
selected_model = self.pretrained_models_table.get_by_model_name(model_name)
|
|
627
|
+
if selected_model is None:
|
|
628
|
+
# @TODO: Improve error message
|
|
629
|
+
raise ValueError("This app doesn't support new deploy api")
|
|
630
|
+
|
|
631
|
+
self.pretrained_models_table.set_by_model_name(model_name)
|
|
632
|
+
deploy_params = self.pretrained_models_table.get_selected_model_params()
|
|
633
|
+
return deploy_params
|
|
634
|
+
|
|
635
|
+
# @TODO: method name should be better?
|
|
636
|
+
def _set_common_deploy_params(self, deploy_params: dict) -> dict:
|
|
637
|
+
load_model_params = inspect.signature(self.load_model).parameters
|
|
638
|
+
has_runtime_param = "runtime" in load_model_params
|
|
639
|
+
|
|
640
|
+
if has_runtime_param:
|
|
641
|
+
if deploy_params.get("runtime", None) is None:
|
|
642
|
+
deploy_params["runtime"] = RuntimeType.PYTORCH
|
|
643
|
+
if deploy_params.get("device", None) is None:
|
|
644
|
+
deploy_params["device"] = "cuda:0" if get_gpu_count() > 0 else "cpu"
|
|
645
|
+
return deploy_params
|
|
646
|
+
|
|
598
647
|
def _download_model_files(self, deploy_params: dict, log_progress: bool = True) -> dict:
|
|
599
|
-
if deploy_params["
|
|
600
|
-
|
|
601
|
-
|
|
648
|
+
if deploy_params["model_source"] == ModelSource.PRETRAINED:
|
|
649
|
+
headless = self.gui is None
|
|
650
|
+
return self._download_pretrained_model(
|
|
651
|
+
deploy_params["model_files"], log_progress, headless
|
|
652
|
+
)
|
|
653
|
+
elif deploy_params["model_source"] == ModelSource.CUSTOM:
|
|
654
|
+
if deploy_params["runtime"] != RuntimeType.PYTORCH:
|
|
655
|
+
export = deploy_params["model_info"].get("export", {})
|
|
656
|
+
if export is None:
|
|
657
|
+
export = {}
|
|
602
658
|
export_model = export.get(deploy_params["runtime"], None)
|
|
603
659
|
if export_model is not None:
|
|
604
660
|
if sly_fs.get_file_name(export_model) == sly_fs.get_file_name(
|
|
@@ -608,13 +664,11 @@ class Inference:
|
|
|
608
664
|
deploy_params["model_info"]["artifacts_dir"] + export_model
|
|
609
665
|
)
|
|
610
666
|
logger.info(f"Found model checkpoint for '{deploy_params['runtime']}'")
|
|
611
|
-
|
|
612
|
-
if deploy_params["model_source"] == ModelSource.PRETRAINED:
|
|
613
|
-
return self._download_pretrained_model(deploy_params["model_files"], log_progress)
|
|
614
|
-
elif deploy_params["model_source"] == ModelSource.CUSTOM:
|
|
615
667
|
return self._download_custom_model(deploy_params["model_files"], log_progress)
|
|
616
668
|
|
|
617
|
-
def _download_pretrained_model(
|
|
669
|
+
def _download_pretrained_model(
|
|
670
|
+
self, model_files: dict, log_progress: bool = True, headless: bool = False
|
|
671
|
+
):
|
|
618
672
|
"""
|
|
619
673
|
Downloads the pretrained model data.
|
|
620
674
|
"""
|
|
@@ -642,26 +696,39 @@ class Inference:
|
|
|
642
696
|
continue
|
|
643
697
|
|
|
644
698
|
if log_progress:
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
699
|
+
if not headless:
|
|
700
|
+
with self.gui.download_progress(
|
|
701
|
+
message=f"Downloading: '{file_name}'",
|
|
702
|
+
total=file_size,
|
|
703
|
+
unit="bytes",
|
|
704
|
+
unit_scale=True,
|
|
705
|
+
) as download_pbar:
|
|
706
|
+
self.gui.download_progress.show()
|
|
707
|
+
sly_fs.download(
|
|
708
|
+
url=file_url,
|
|
709
|
+
save_path=file_path,
|
|
710
|
+
progress=download_pbar.update,
|
|
711
|
+
)
|
|
712
|
+
else:
|
|
713
|
+
with tqdm(
|
|
714
|
+
total=file_size,
|
|
715
|
+
unit="bytes",
|
|
716
|
+
unit_scale=True,
|
|
717
|
+
) as download_pbar:
|
|
718
|
+
logger.info(f"Downloading: '{file_name}'")
|
|
719
|
+
sly_fs.download(
|
|
720
|
+
url=file_url, save_path=file_path, progress=download_pbar.update
|
|
721
|
+
)
|
|
657
722
|
else:
|
|
723
|
+
logger.info(f"Downloading: '{file_name}'")
|
|
658
724
|
sly_fs.download(url=file_url, save_path=file_path)
|
|
659
725
|
local_model_files[file] = file_path
|
|
660
726
|
else:
|
|
661
727
|
local_model_files[file] = file_url
|
|
662
728
|
|
|
663
729
|
if log_progress:
|
|
664
|
-
self.gui
|
|
730
|
+
if self.gui is not None:
|
|
731
|
+
self.gui.download_progress.hide()
|
|
665
732
|
return local_model_files
|
|
666
733
|
|
|
667
734
|
def _download_custom_model(self, model_files: dict, log_progress: bool = True):
|
|
@@ -732,7 +799,7 @@ class Inference:
|
|
|
732
799
|
self.gui.show_deployed_model_info(self)
|
|
733
800
|
|
|
734
801
|
def load_custom_checkpoint(
|
|
735
|
-
self, model_files: dict, model_meta: dict, device: str =
|
|
802
|
+
self, model_files: dict, model_meta: dict, device: Optional[str] = None, **kwargs
|
|
736
803
|
):
|
|
737
804
|
"""
|
|
738
805
|
Loads local custom model checkpoint.
|
|
@@ -886,7 +953,8 @@ class Inference:
|
|
|
886
953
|
classes = None
|
|
887
954
|
try:
|
|
888
955
|
classes = self.get_classes()
|
|
889
|
-
|
|
956
|
+
if classes is not None:
|
|
957
|
+
num_classes = len(classes)
|
|
890
958
|
except NotImplementedError:
|
|
891
959
|
logger.warn(f"get_classes() function not implemented for {type(self)} object.")
|
|
892
960
|
except AttributeError:
|
|
@@ -1002,13 +1070,13 @@ class Inference:
|
|
|
1002
1070
|
self._model_meta = self._model_meta.add_tag_meta(tag_meta)
|
|
1003
1071
|
return tag_meta
|
|
1004
1072
|
|
|
1005
|
-
def _create_label(self, dto:
|
|
1073
|
+
def _create_label(self, dto: PredictionDTO) -> Label:
|
|
1006
1074
|
raise NotImplementedError("Have to be implemented in child class")
|
|
1007
1075
|
|
|
1008
1076
|
def _predictions_to_annotation(
|
|
1009
1077
|
self,
|
|
1010
1078
|
image_path: Union[str, np.ndarray],
|
|
1011
|
-
predictions: List[
|
|
1079
|
+
predictions: List[PredictionDTO],
|
|
1012
1080
|
classes_whitelist: Optional[List[str]] = None,
|
|
1013
1081
|
) -> Annotation:
|
|
1014
1082
|
labels = []
|
|
@@ -1067,6 +1135,15 @@ class Inference:
|
|
|
1067
1135
|
logger.error(f"Error in {func.__name__} function: {e}", exc_info=True)
|
|
1068
1136
|
raise e
|
|
1069
1137
|
|
|
1138
|
+
def api_from_request(self, request) -> Api:
|
|
1139
|
+
"""
|
|
1140
|
+
Get API from request. If not found, use self.api.
|
|
1141
|
+
"""
|
|
1142
|
+
api = request.state.api
|
|
1143
|
+
if api is None:
|
|
1144
|
+
api = self.api
|
|
1145
|
+
return api
|
|
1146
|
+
|
|
1070
1147
|
def _inference_auto(
|
|
1071
1148
|
self,
|
|
1072
1149
|
source: List[Union[str, np.ndarray]],
|
|
@@ -1117,10 +1194,12 @@ class Inference:
|
|
|
1117
1194
|
settings = self._get_inference_settings({})
|
|
1118
1195
|
|
|
1119
1196
|
if isinstance(source[0], int):
|
|
1120
|
-
|
|
1121
|
-
self.api, {"batch_ids": source, "settings": settings}
|
|
1197
|
+
results = self.inference_requests_manager.run(
|
|
1198
|
+
self._inference_image_ids, self.api, {"batch_ids": source, "settings": settings}
|
|
1122
1199
|
)
|
|
1123
|
-
anns = [
|
|
1200
|
+
anns = [
|
|
1201
|
+
Annotation.from_json(result["annotation"], self.model_meta) for result in results
|
|
1202
|
+
]
|
|
1124
1203
|
else:
|
|
1125
1204
|
anns, _ = self._inference_auto(source, settings)
|
|
1126
1205
|
if not input_is_list:
|
|
@@ -1240,17 +1319,17 @@ class Inference:
|
|
|
1240
1319
|
return anns, benchmark
|
|
1241
1320
|
|
|
1242
1321
|
# pylint: disable=method-hidden
|
|
1243
|
-
def predict(self, image_path: str, settings: Dict[str, Any]) -> List[
|
|
1322
|
+
def predict(self, image_path: str, settings: Dict[str, Any]) -> List[PredictionDTO]:
|
|
1244
1323
|
raise NotImplementedError("Have to be implemented in child class")
|
|
1245
1324
|
|
|
1246
|
-
def predict_raw(self, image_path: str, settings: Dict[str, Any]) -> List[
|
|
1325
|
+
def predict_raw(self, image_path: str, settings: Dict[str, Any]) -> List[PredictionDTO]:
|
|
1247
1326
|
raise NotImplementedError(
|
|
1248
1327
|
"Have to be implemented in child class If sliding_window_mode is 'advanced'."
|
|
1249
1328
|
)
|
|
1250
1329
|
|
|
1251
1330
|
def predict_batch(
|
|
1252
1331
|
self, images_np: List[np.ndarray], settings: Dict[str, Any]
|
|
1253
|
-
) -> List[List[
|
|
1332
|
+
) -> List[List[PredictionDTO]]:
|
|
1254
1333
|
"""Predict batch of images. `images_np` is a list of numpy arrays in RGB format
|
|
1255
1334
|
|
|
1256
1335
|
If this method is not overridden in a subclass, the following fallback logic works:
|
|
@@ -1267,7 +1346,7 @@ class Inference:
|
|
|
1267
1346
|
|
|
1268
1347
|
def predict_batch_raw(
|
|
1269
1348
|
self, images_np: List[np.ndarray], settings: Dict[str, Any]
|
|
1270
|
-
) -> List[List[
|
|
1349
|
+
) -> List[List[PredictionDTO]]:
|
|
1271
1350
|
"""Predict batch of images. `source` is a list of numpy arrays in RGB format"""
|
|
1272
1351
|
raise NotImplementedError(
|
|
1273
1352
|
"Have to be implemented in child class If sliding_window_mode is 'advanced'."
|
|
@@ -1275,7 +1354,7 @@ class Inference:
|
|
|
1275
1354
|
|
|
1276
1355
|
def predict_benchmark(
|
|
1277
1356
|
self, images_np: List[np.ndarray], settings: dict
|
|
1278
|
-
) -> Tuple[List[List[
|
|
1357
|
+
) -> Tuple[List[List[PredictionDTO]], dict]:
|
|
1279
1358
|
"""
|
|
1280
1359
|
Inference a batch of images with speedtest benchmarking.
|
|
1281
1360
|
|
|
@@ -1318,15 +1397,24 @@ class Inference:
|
|
|
1318
1397
|
)
|
|
1319
1398
|
return is_predict_batch_overridden or is_predict_benchmark_overridden
|
|
1320
1399
|
|
|
1400
|
+
def set_conf_auto(self, conf: float, inference_settings: dict):
|
|
1401
|
+
conf_names = ["conf", "confidence", "confidence_threshold", "confidence_thresh"]
|
|
1402
|
+
for name in conf_names:
|
|
1403
|
+
if name in inference_settings:
|
|
1404
|
+
inference_settings[name] = conf
|
|
1405
|
+
return inference_settings
|
|
1406
|
+
|
|
1321
1407
|
# pylint: enable=method-hidden
|
|
1322
1408
|
def _get_inference_settings(self, state: dict):
|
|
1323
|
-
settings = state.get("settings"
|
|
1409
|
+
settings = state.get("settings")
|
|
1324
1410
|
if settings is None:
|
|
1325
1411
|
settings = {}
|
|
1326
1412
|
if "rectangle" in state.keys():
|
|
1327
1413
|
settings["rectangle"] = state["rectangle"]
|
|
1414
|
+
conf = settings.get("conf", None)
|
|
1415
|
+
if conf is not None:
|
|
1416
|
+
settings = self.set_conf_auto(conf, settings)
|
|
1328
1417
|
settings["sliding_window_mode"] = self.sliding_window_mode
|
|
1329
|
-
|
|
1330
1418
|
for key, value in self.custom_inference_settings_dict.items():
|
|
1331
1419
|
if key not in settings:
|
|
1332
1420
|
logger.debug(
|
|
@@ -1335,13 +1423,19 @@ class Inference:
|
|
|
1335
1423
|
settings[key] = value
|
|
1336
1424
|
return settings
|
|
1337
1425
|
|
|
1426
|
+
def _get_batch_size_from_state(self, state: dict):
|
|
1427
|
+
batch_size = state.get("batch_size", None)
|
|
1428
|
+
if batch_size is None:
|
|
1429
|
+
batch_size = self.get_batch_size()
|
|
1430
|
+
return batch_size
|
|
1431
|
+
|
|
1338
1432
|
@property
|
|
1339
1433
|
def app(self) -> Application:
|
|
1340
1434
|
return self._app
|
|
1341
1435
|
|
|
1342
1436
|
def visualize(
|
|
1343
1437
|
self,
|
|
1344
|
-
predictions: List[
|
|
1438
|
+
predictions: List[PredictionDTO],
|
|
1345
1439
|
image_path: str,
|
|
1346
1440
|
vis_path: str,
|
|
1347
1441
|
thickness: Optional[int] = None,
|
|
@@ -1358,194 +1452,79 @@ class Inference:
|
|
|
1358
1452
|
|
|
1359
1453
|
def _format_output(
|
|
1360
1454
|
self,
|
|
1361
|
-
|
|
1362
|
-
slides_data: List[dict] = None,
|
|
1455
|
+
predictions: List[Prediction],
|
|
1363
1456
|
) -> List[dict]:
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
image_np = sly_image.read_bytes(file.file.read())
|
|
1373
|
-
logger.debug("Inference settings:", extra=settings)
|
|
1374
|
-
logger.debug("Image info:", extra={"w": image_np.shape[1], "h": image_np.shape[0]})
|
|
1375
|
-
anns, slides_data = self._inference_auto(
|
|
1376
|
-
[image_np],
|
|
1377
|
-
settings=settings,
|
|
1378
|
-
)
|
|
1379
|
-
results = self._format_output(anns, slides_data)
|
|
1380
|
-
return results[0]
|
|
1457
|
+
output = [
|
|
1458
|
+
{
|
|
1459
|
+
**pred.to_json(),
|
|
1460
|
+
"data": pred.extra_data.get("slides_data", {}),
|
|
1461
|
+
}
|
|
1462
|
+
for pred in predictions
|
|
1463
|
+
]
|
|
1464
|
+
return output
|
|
1381
1465
|
|
|
1382
|
-
def
|
|
1466
|
+
def _inference_images(
|
|
1467
|
+
self,
|
|
1468
|
+
images: Iterable[Union[np.ndarray, str]],
|
|
1469
|
+
state: dict,
|
|
1470
|
+
inference_request: InferenceRequest,
|
|
1471
|
+
):
|
|
1383
1472
|
logger.debug("Inferring batch...", extra={"state": state})
|
|
1384
1473
|
settings = self._get_inference_settings(state)
|
|
1385
|
-
|
|
1386
|
-
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
)
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
settings = self._get_inference_settings(state)
|
|
1395
|
-
ids = state["batch_ids"]
|
|
1396
|
-
infos = api.image.get_info_by_id_batch(ids)
|
|
1397
|
-
datasets = defaultdict(list)
|
|
1398
|
-
for info in infos:
|
|
1399
|
-
datasets[info.dataset_id].append(info.id)
|
|
1400
|
-
results = []
|
|
1401
|
-
for dataset_id, ids in datasets.items():
|
|
1402
|
-
images_np = api.image.download_nps(dataset_id, ids)
|
|
1474
|
+
logger.debug("Inference settings:", extra={"inference_settings": settings})
|
|
1475
|
+
batch_size = self._get_batch_size_from_state(state)
|
|
1476
|
+
|
|
1477
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(images))
|
|
1478
|
+
for batch in batched_iter(images, batch_size=batch_size):
|
|
1479
|
+
batch = [
|
|
1480
|
+
self.cache.get_image_path(image) if isinstance(image, str) else image
|
|
1481
|
+
for image in batch
|
|
1482
|
+
]
|
|
1403
1483
|
anns, slides_data = self._inference_auto(
|
|
1404
|
-
|
|
1484
|
+
batch,
|
|
1405
1485
|
settings=settings,
|
|
1406
1486
|
)
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
|
|
1413
|
-
|
|
1414
|
-
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
)
|
|
1423
|
-
|
|
1424
|
-
inference_request = {}
|
|
1425
|
-
if async_inference_request_uuid is not None:
|
|
1426
|
-
try:
|
|
1427
|
-
inference_request = self._inference_requests[async_inference_request_uuid]
|
|
1428
|
-
except Exception as ex:
|
|
1429
|
-
import traceback
|
|
1430
|
-
|
|
1431
|
-
logger.error(traceback.format_exc())
|
|
1432
|
-
raise RuntimeError(
|
|
1433
|
-
f"async_inference_request_uuid {async_inference_request_uuid} was given, "
|
|
1434
|
-
f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
|
|
1435
|
-
)
|
|
1436
|
-
|
|
1437
|
-
anns, slides_data = self._inference_auto(
|
|
1438
|
-
[image_np],
|
|
1439
|
-
settings=settings,
|
|
1440
|
-
)
|
|
1441
|
-
ann = anns[0]
|
|
1442
|
-
|
|
1443
|
-
if upload:
|
|
1444
|
-
ds_info = api.dataset.get_info_by_id(image_info.dataset_id, raise_error=True)
|
|
1445
|
-
output_project_id = ds_info.project_id
|
|
1446
|
-
output_project_meta = self.cache.get_project_meta(api, output_project_id)
|
|
1447
|
-
logger.debug("Merging project meta...")
|
|
1448
|
-
|
|
1449
|
-
output_project_meta, ann, meta_changed = update_meta_and_ann(output_project_meta, ann)
|
|
1450
|
-
if meta_changed:
|
|
1451
|
-
output_project_meta = api.project.update_meta(
|
|
1452
|
-
output_project_id, output_project_meta
|
|
1453
|
-
)
|
|
1454
|
-
self.cache.set_project_meta(output_project_id, output_project_meta)
|
|
1455
|
-
|
|
1456
|
-
ann = self._exclude_duplicated_predictions(
|
|
1457
|
-
api, anns, settings, ds_info.id, [image_id], output_project_meta
|
|
1458
|
-
)[0]
|
|
1459
|
-
|
|
1460
|
-
logger.debug(
|
|
1461
|
-
"Uploading annotation...",
|
|
1462
|
-
extra={
|
|
1463
|
-
"image_id": image_id,
|
|
1464
|
-
"dataset_id": ds_info.id,
|
|
1465
|
-
"project_id": output_project_id,
|
|
1466
|
-
},
|
|
1467
|
-
)
|
|
1468
|
-
api.annotation.upload_ann(image_id, ann)
|
|
1469
|
-
else:
|
|
1470
|
-
ann = self._exclude_duplicated_predictions(
|
|
1471
|
-
api, anns, settings, image_info.dataset_id, [image_id]
|
|
1472
|
-
)[0]
|
|
1473
|
-
|
|
1474
|
-
result = self._format_output(anns, slides_data)[0]
|
|
1475
|
-
if async_inference_request_uuid is not None and ann is not None:
|
|
1476
|
-
inference_request["result"] = result
|
|
1477
|
-
return result
|
|
1478
|
-
|
|
1479
|
-
def _inference_image_url(self, api: Api, state: dict):
|
|
1480
|
-
logger.debug("Inferring image_url...", extra={"state": state})
|
|
1481
|
-
settings = self._get_inference_settings(state)
|
|
1482
|
-
image_url = state["image_url"]
|
|
1483
|
-
ext = sly_fs.get_file_ext(image_url)
|
|
1484
|
-
if ext == "":
|
|
1485
|
-
ext = ".jpg"
|
|
1486
|
-
image_path = os.path.join(get_data_dir(), rand_str(15) + ext)
|
|
1487
|
-
sly_fs.download(image_url, image_path)
|
|
1488
|
-
logger.debug("Inference settings:", extra=settings)
|
|
1489
|
-
logger.debug(f"Downloaded path: {image_path}")
|
|
1490
|
-
anns, slides_data = self._inference_auto(
|
|
1491
|
-
[image_path],
|
|
1492
|
-
settings=settings,
|
|
1493
|
-
)
|
|
1494
|
-
sly_fs.silent_remove(image_path)
|
|
1495
|
-
return self._format_output(anns, slides_data)[0]
|
|
1496
|
-
|
|
1497
|
-
def _inference_video_id(self, api: Api, state: dict, async_inference_request_uuid: str = None):
|
|
1498
|
-
from supervisely.nn.inference.video_inference import InferenceVideoInterface
|
|
1499
|
-
|
|
1500
|
-
logger.debug("Inferring video_id...", extra={"state": state})
|
|
1501
|
-
video_info = api.video.get_info_by_id(state["videoId"])
|
|
1502
|
-
n_frames = state.get("framesCount", video_info.frames_count)
|
|
1487
|
+
predictions = [Prediction(ann, model_meta=self.model_meta) for ann in anns]
|
|
1488
|
+
for pred, this_slides_data in zip(predictions, slides_data):
|
|
1489
|
+
pred.extra_data["slides_data"] = this_slides_data
|
|
1490
|
+
batch_results = self._format_output(predictions)
|
|
1491
|
+
inference_request.add_results(batch_results)
|
|
1492
|
+
inference_request.done(len(batch_results))
|
|
1493
|
+
|
|
1494
|
+
def _inference_video(
|
|
1495
|
+
self,
|
|
1496
|
+
path: str,
|
|
1497
|
+
state: Dict,
|
|
1498
|
+
inference_request: InferenceRequest,
|
|
1499
|
+
):
|
|
1500
|
+
logger.debug("Inferring video...", extra={"path": path, "state": state})
|
|
1501
|
+
inference_settings = self._get_inference_settings(state)
|
|
1502
|
+
logger.debug(f"Inference settings:", extra=inference_settings)
|
|
1503
|
+
batch_size = self._get_batch_size_from_state(state)
|
|
1503
1504
|
start_frame_index = state.get("startFrameIndex", 0)
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
),
|
|
1513
|
-
)
|
|
1505
|
+
step = state.get("stride", None)
|
|
1506
|
+
if step is None:
|
|
1507
|
+
step = state.get("step", None)
|
|
1508
|
+
if step is None:
|
|
1509
|
+
step = 1
|
|
1510
|
+
end_frame_index = state.get("endFrameIndex", None)
|
|
1511
|
+
duration = state.get("duration", None)
|
|
1512
|
+
frames_count = state.get("framesCount", None)
|
|
1514
1513
|
tracking = state.get("tracker", None)
|
|
1514
|
+
direction = state.get("direction", "forward")
|
|
1515
|
+
direction = 1 if direction == "forward" else -1
|
|
1515
1516
|
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
|
|
1527
|
-
)
|
|
1528
|
-
sly_progress: Progress = inference_request["progress"]
|
|
1529
|
-
|
|
1530
|
-
sly_progress.total = n_frames
|
|
1531
|
-
inference_request["preparing_progress"]["total"] = n_frames
|
|
1532
|
-
preparing_progress = inference_request["preparing_progress"]
|
|
1533
|
-
|
|
1534
|
-
# progress
|
|
1535
|
-
preparing_progress["status"] = "download_video"
|
|
1536
|
-
preparing_progress["current"] = 0
|
|
1537
|
-
preparing_progress["total"] = int(video_info.file_meta["size"])
|
|
1538
|
-
|
|
1539
|
-
def _progress_cb(chunk_size):
|
|
1540
|
-
preparing_progress["current"] += chunk_size
|
|
1541
|
-
|
|
1542
|
-
self.cache.download_video(api, video_info.id, return_images=False, progress_cb=_progress_cb)
|
|
1543
|
-
preparing_progress["status"] = "inference"
|
|
1544
|
-
|
|
1545
|
-
settings = self._get_inference_settings(state)
|
|
1546
|
-
logger.debug(f"Inference settings:", extra=settings)
|
|
1547
|
-
|
|
1548
|
-
logger.debug(f"Total frames to infer: {n_frames}")
|
|
1517
|
+
frames_reader = VideoFrameReader(path)
|
|
1518
|
+
video_height, video_witdth = frames_reader.frame_size()
|
|
1519
|
+
if frames_count is not None:
|
|
1520
|
+
n_frames = frames_count
|
|
1521
|
+
elif end_frame_index is not None:
|
|
1522
|
+
n_frames = end_frame_index - start_frame_index
|
|
1523
|
+
elif duration is not None:
|
|
1524
|
+
fps = frames_reader.fps()
|
|
1525
|
+
n_frames = int(duration * fps)
|
|
1526
|
+
else:
|
|
1527
|
+
n_frames = frames_reader.frames_count()
|
|
1549
1528
|
|
|
1550
1529
|
if tracking == "bot":
|
|
1551
1530
|
from supervisely.nn.tracker import BoTTracker
|
|
@@ -1557,444 +1536,374 @@ class Inference:
|
|
|
1557
1536
|
tracker = DeepSortTracker(state)
|
|
1558
1537
|
else:
|
|
1559
1538
|
if tracking is not None:
|
|
1560
|
-
logger.
|
|
1539
|
+
logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
|
|
1561
1540
|
tracker = None
|
|
1562
1541
|
|
|
1542
|
+
progress_total = (n_frames + step - 1) // step
|
|
1543
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
1544
|
+
|
|
1563
1545
|
results = []
|
|
1564
|
-
batch_size = state.get("batch_size", None)
|
|
1565
|
-
if batch_size is None:
|
|
1566
|
-
batch_size = self.get_batch_size()
|
|
1567
1546
|
tracks_data = {}
|
|
1568
|
-
direction = 1 if direction == "forward" else -1
|
|
1569
1547
|
for batch in batched(
|
|
1570
|
-
range(start_frame_index, start_frame_index + direction * n_frames, direction),
|
|
1548
|
+
range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
|
|
1571
1549
|
batch_size,
|
|
1572
1550
|
):
|
|
1573
|
-
if (
|
|
1574
|
-
async_inference_request_uuid is not None
|
|
1575
|
-
and inference_request["cancel_inference"] is True
|
|
1576
|
-
):
|
|
1551
|
+
if inference_request.is_stopped():
|
|
1577
1552
|
logger.debug(
|
|
1578
|
-
f"Cancelling inference
|
|
1579
|
-
extra={"inference_request_uuid":
|
|
1553
|
+
f"Cancelling inference...",
|
|
1554
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
1580
1555
|
)
|
|
1581
1556
|
results = []
|
|
1582
1557
|
break
|
|
1583
1558
|
logger.debug(
|
|
1584
1559
|
f"Inferring frames {batch[0]}-{batch[-1]}:",
|
|
1585
1560
|
)
|
|
1586
|
-
frames =
|
|
1561
|
+
frames = frames_reader.read_frames(batch)
|
|
1587
1562
|
anns, slides_data = self._inference_auto(
|
|
1588
1563
|
source=frames,
|
|
1589
|
-
settings=
|
|
1564
|
+
settings=inference_settings,
|
|
1590
1565
|
)
|
|
1566
|
+
predictions = [
|
|
1567
|
+
Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
|
|
1568
|
+
for ann, frame_index in zip(anns, batch)
|
|
1569
|
+
]
|
|
1570
|
+
for pred, this_slides_data in zip(predictions, slides_data):
|
|
1571
|
+
pred.extra_data["slides_data"] = this_slides_data
|
|
1572
|
+
batch_results = self._format_output(predictions)
|
|
1591
1573
|
if tracker is not None:
|
|
1592
1574
|
for frame_index, frame, ann in zip(batch, frames, anns):
|
|
1593
1575
|
tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
if async_inference_request_uuid is not None:
|
|
1597
|
-
sly_progress.iters_done(len(batch))
|
|
1598
|
-
inference_request["pending_results"].extend(batch_results)
|
|
1576
|
+
inference_request.add_results(batch_results)
|
|
1577
|
+
inference_request.done(len(batch_results))
|
|
1599
1578
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
1600
1579
|
video_ann_json = None
|
|
1601
1580
|
if tracker is not None:
|
|
1581
|
+
inference_request.set_stage("Postprocess...", 0, 1)
|
|
1602
1582
|
video_ann_json = tracker.get_annotation(
|
|
1603
|
-
tracks_data, (
|
|
1583
|
+
tracks_data, (video_height, video_witdth), n_frames
|
|
1604
1584
|
).to_json()
|
|
1585
|
+
inference_request.done()
|
|
1605
1586
|
result = {"ann": results, "video_ann": video_ann_json}
|
|
1606
|
-
|
|
1607
|
-
inference_request["result"] = result.copy()
|
|
1608
|
-
return result
|
|
1587
|
+
inference_request.final_result = result.copy()
|
|
1609
1588
|
|
|
1610
|
-
def
|
|
1589
|
+
def _inference_image_ids(
|
|
1611
1590
|
self,
|
|
1612
1591
|
api: Api,
|
|
1613
1592
|
state: dict,
|
|
1614
|
-
|
|
1615
|
-
async_inference_request_uuid: str = None,
|
|
1593
|
+
inference_request: InferenceRequest,
|
|
1616
1594
|
):
|
|
1617
1595
|
"""Inference images by ids.
|
|
1618
1596
|
If "output_project_id" in state, upload images and annotations to the output project.
|
|
1619
1597
|
If "output_project_id" equal to source project id, upload annotations to the source project.
|
|
1620
1598
|
If "output_project_id" is None, write annotations to inference request object.
|
|
1621
1599
|
"""
|
|
1622
|
-
logger.debug("Inferring
|
|
1623
|
-
|
|
1624
|
-
|
|
1625
|
-
|
|
1626
|
-
|
|
1627
|
-
|
|
1600
|
+
logger.debug("Inferring batch_ids", extra={"state": state})
|
|
1601
|
+
inference_settings = self._get_inference_settings(state)
|
|
1602
|
+
logger.debug("Inference settings:", extra={"inference_settings": inference_settings})
|
|
1603
|
+
batch_size = self._get_batch_size_from_state(state)
|
|
1604
|
+
image_ids = get_value_for_keys(
|
|
1605
|
+
state, ["batch_ids", "image_ids", "images_ids", "imageIds", "image_id", "imageId"]
|
|
1606
|
+
)
|
|
1607
|
+
if image_ids is None:
|
|
1608
|
+
raise ValueError("Image ids are not provided")
|
|
1609
|
+
if not isinstance(image_ids, list):
|
|
1610
|
+
image_ids = [image_ids]
|
|
1611
|
+
upload_mode = state.get("upload_mode", None)
|
|
1612
|
+
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
|
1613
|
+
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
|
1614
|
+
iou_merge_threshold = 0.7
|
|
1615
|
+
|
|
1616
|
+
images_infos = api.image.get_info_by_id_batch(image_ids)
|
|
1628
1617
|
images_infos_dict = {im_info.id: im_info for im_info in images_infos}
|
|
1618
|
+
inference_request.context.setdefault("image_info", {}).update(images_infos_dict)
|
|
1619
|
+
|
|
1629
1620
|
dataset_infos_dict = {
|
|
1630
1621
|
ds_id: api.dataset.get_info_by_id(ds_id)
|
|
1631
1622
|
for ds_id in set([im_info.dataset_id for im_info in images_infos])
|
|
1632
1623
|
}
|
|
1624
|
+
inference_request.context.setdefault("dataset_info", {}).update(dataset_infos_dict)
|
|
1633
1625
|
|
|
1634
|
-
|
|
1635
|
-
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1644
|
-
|
|
1645
|
-
|
|
1646
|
-
|
|
1647
|
-
|
|
1648
|
-
def _download_images(images_ids):
|
|
1649
|
-
with ThreadPoolExecutor(max(8, min(batch_size, 64))) as executor:
|
|
1650
|
-
for image_id in images_ids:
|
|
1651
|
-
executor.submit(
|
|
1652
|
-
self.cache.download_image,
|
|
1653
|
-
api,
|
|
1654
|
-
image_id,
|
|
1655
|
-
)
|
|
1656
|
-
|
|
1657
|
-
# start downloading in parallel
|
|
1658
|
-
threading.Thread(target=_download_images, args=[images_ids], daemon=True).start()
|
|
1659
|
-
|
|
1660
|
-
output_project_metas_dict = {}
|
|
1661
|
-
|
|
1662
|
-
def _upload_results_to_source(results: List[Dict]):
|
|
1663
|
-
nonlocal output_project_metas_dict
|
|
1664
|
-
for result in results:
|
|
1665
|
-
image_id = result["image_id"]
|
|
1666
|
-
image_info: ImageInfo = images_infos_dict[image_id]
|
|
1667
|
-
dataset_info: DatasetInfo = dataset_infos_dict[image_info.dataset_id]
|
|
1668
|
-
project_id = dataset_info.project_id
|
|
1669
|
-
ann = Annotation.from_json(result["annotation"], self.model_meta)
|
|
1670
|
-
output_project_meta = output_project_metas_dict.get(project_id, None)
|
|
1671
|
-
if output_project_meta is None:
|
|
1672
|
-
output_project_meta = ProjectMeta.from_json(
|
|
1673
|
-
api.project.get_meta(output_project_id)
|
|
1674
|
-
)
|
|
1675
|
-
output_project_meta, ann, meta_changed = update_meta_and_ann(
|
|
1676
|
-
output_project_meta, ann
|
|
1677
|
-
)
|
|
1678
|
-
output_project_metas_dict[project_id] = output_project_meta
|
|
1679
|
-
if meta_changed:
|
|
1680
|
-
output_project_meta = api.project.update_meta(project_id, output_project_meta)
|
|
1681
|
-
ann = update_classes(api, ann, output_project_meta, project_id)
|
|
1682
|
-
api.annotation.append_labels(image_id, ann.labels)
|
|
1683
|
-
if async_inference_request_uuid is not None:
|
|
1684
|
-
sly_progress.iters_done(1)
|
|
1685
|
-
inference_request["pending_results"].append(
|
|
1686
|
-
{
|
|
1687
|
-
"annotation": None, # to less response size
|
|
1688
|
-
"data": None, # to less response size
|
|
1689
|
-
"image_id": image_id,
|
|
1690
|
-
"image_name": result["image_name"],
|
|
1691
|
-
"dataset_id": result["dataset_id"],
|
|
1692
|
-
}
|
|
1693
|
-
)
|
|
1694
|
-
|
|
1695
|
-
def _add_results_to_request(results: List[Dict]):
|
|
1696
|
-
if async_inference_request_uuid is None:
|
|
1697
|
-
return
|
|
1698
|
-
inference_request["pending_results"].extend(results)
|
|
1699
|
-
sly_progress.iters_done(len(results))
|
|
1700
|
-
|
|
1701
|
-
new_dataset_id = {}
|
|
1702
|
-
|
|
1703
|
-
def _get_or_create_new_dataset(output_project_id, src_dataset_id):
|
|
1704
|
-
"""Copy dataset in output project if not exists and return its id"""
|
|
1705
|
-
if src_dataset_id in new_dataset_id:
|
|
1706
|
-
return new_dataset_id[src_dataset_id]
|
|
1707
|
-
dataset_info = api.dataset.get_info_by_id(src_dataset_id)
|
|
1708
|
-
|
|
1709
|
-
def _create_parent_recursively(output_project_id, src_parent_id):
|
|
1710
|
-
"""Create parent datasets recursively and return the ID of the top-level parent"""
|
|
1711
|
-
if src_parent_id in new_dataset_id:
|
|
1712
|
-
return new_dataset_id[src_parent_id]
|
|
1713
|
-
src_parent_info = dataset_infos_dict.get(src_parent_id)
|
|
1714
|
-
if src_parent_info is None:
|
|
1715
|
-
src_parent_info = api.dataset.get_info_by_id(src_parent_id)
|
|
1716
|
-
if src_parent_info.parent_id is not None:
|
|
1717
|
-
parent_id = _create_parent_recursively(
|
|
1718
|
-
output_project_id, src_parent_info.parent_id
|
|
1719
|
-
)
|
|
1720
|
-
else:
|
|
1721
|
-
parent_id = None
|
|
1722
|
-
dst_parent = api.dataset.create(
|
|
1723
|
-
output_project_id,
|
|
1724
|
-
src_parent_info.name,
|
|
1725
|
-
change_name_if_conflict=True,
|
|
1726
|
-
parent_id=parent_id,
|
|
1727
|
-
)
|
|
1728
|
-
new_dataset_id[src_parent_info.id] = dst_parent.id
|
|
1729
|
-
return dst_parent.id
|
|
1730
|
-
|
|
1731
|
-
parent_id = None
|
|
1732
|
-
if dataset_info.parent_id is not None:
|
|
1733
|
-
parent_id = _create_parent_recursively(output_project_id, dataset_info.parent_id)
|
|
1734
|
-
|
|
1735
|
-
output_dataset_id = api.dataset.create(
|
|
1736
|
-
output_project_id,
|
|
1737
|
-
dataset_info.name,
|
|
1626
|
+
output_project_id = state.get("output_project_id", None)
|
|
1627
|
+
output_dataset_id = None
|
|
1628
|
+
inference_request.context.setdefault("project_meta", {})
|
|
1629
|
+
if output_project_id is not None:
|
|
1630
|
+
if upload_mode is None:
|
|
1631
|
+
upload_mode = "append"
|
|
1632
|
+
if output_project_id is None and upload_mode == "create":
|
|
1633
|
+
image_info = images_infos[0]
|
|
1634
|
+
dataset_info = dataset_infos_dict[image_info.dataset_id]
|
|
1635
|
+
output_project_info = api.project.create(
|
|
1636
|
+
dataset_info.workspace_id,
|
|
1637
|
+
name=f"Predictions from task #{self.task_id}",
|
|
1638
|
+
description=f"Auto created project from inference request {inference_request.uuid}",
|
|
1738
1639
|
change_name_if_conflict=True,
|
|
1739
|
-
parent_id=parent_id,
|
|
1740
|
-
).id
|
|
1741
|
-
new_dataset_id[src_dataset_id] = output_dataset_id
|
|
1742
|
-
return output_dataset_id
|
|
1743
|
-
|
|
1744
|
-
def _copy_images_to_dst(
|
|
1745
|
-
src_dataset_id, dst_dataset_id, image_infos, dst_names
|
|
1746
|
-
) -> List[ImageInfo]:
|
|
1747
|
-
return api.image.copy_batch_optimized(
|
|
1748
|
-
src_dataset_id,
|
|
1749
|
-
image_infos,
|
|
1750
|
-
dst_dataset_id,
|
|
1751
|
-
dst_names=dst_names,
|
|
1752
|
-
with_annotations=False,
|
|
1753
|
-
skip_validation=True,
|
|
1754
1640
|
)
|
|
1755
|
-
|
|
1756
|
-
|
|
1757
|
-
|
|
1758
|
-
|
|
1759
|
-
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
image_infos = _copy_images_to_dst(
|
|
1765
|
-
src_dataset_id, dataset_id, src_image_infos, image_names
|
|
1641
|
+
output_project_id = output_project_info.id
|
|
1642
|
+
inference_request.context.setdefault("project_info", {})[
|
|
1643
|
+
output_project_id
|
|
1644
|
+
] = output_project_info
|
|
1645
|
+
output_dataset_info = api.dataset.create(
|
|
1646
|
+
output_project_id,
|
|
1647
|
+
"Predictions",
|
|
1648
|
+
description=f"Auto created dataset from inference request {inference_request.uuid}",
|
|
1649
|
+
change_name_if_conflict=True,
|
|
1766
1650
|
)
|
|
1767
|
-
|
|
1768
|
-
|
|
1769
|
-
|
|
1770
|
-
|
|
1771
|
-
|
|
1772
|
-
|
|
1773
|
-
|
|
1774
|
-
|
|
1775
|
-
|
|
1651
|
+
output_dataset_id = output_dataset_info.id
|
|
1652
|
+
inference_request.context.setdefault("dataset_info", {})[
|
|
1653
|
+
output_dataset_id
|
|
1654
|
+
] = output_dataset_info
|
|
1655
|
+
|
|
1656
|
+
# start download to cache in background
|
|
1657
|
+
dataset_image_infos: Dict[int, List[ImageInfo]] = defaultdict(list)
|
|
1658
|
+
for image_info in images_infos:
|
|
1659
|
+
dataset_image_infos[image_info.dataset_id].append(image_info)
|
|
1660
|
+
for dataset_id, ds_image_infos in dataset_image_infos.items():
|
|
1661
|
+
self.cache.run_cache_task_manually(
|
|
1662
|
+
api, [info.id for info in ds_image_infos], dataset_id=dataset_id
|
|
1776
1663
|
)
|
|
1777
|
-
meta_changed = False
|
|
1778
|
-
anns = []
|
|
1779
|
-
for result in results:
|
|
1780
|
-
ann = Annotation.from_json(result["annotation"], self.model_meta)
|
|
1781
|
-
output_project_meta = output_project_metas_dict.get(output_project_id, None)
|
|
1782
|
-
if output_project_meta is None:
|
|
1783
|
-
output_project_meta = ProjectMeta.from_json(
|
|
1784
|
-
api.project.get_meta(output_project_id)
|
|
1785
|
-
)
|
|
1786
|
-
output_project_meta, ann, c = update_meta_and_ann(output_project_meta, ann)
|
|
1787
|
-
output_project_metas_dict[output_project_id] = output_project_meta
|
|
1788
|
-
meta_changed = meta_changed or c
|
|
1789
|
-
anns.append(ann)
|
|
1790
|
-
if meta_changed:
|
|
1791
|
-
api.project.update_meta(output_project_id, output_project_meta)
|
|
1792
|
-
|
|
1793
|
-
# upload in batches to update progress with each batch
|
|
1794
|
-
# api.annotation.upload_anns() uploads in same batches anyways
|
|
1795
|
-
for batch in batched(list(zip(anns, results, image_infos))):
|
|
1796
|
-
batch_anns, batch_results, batch_image_infos = zip(*batch)
|
|
1797
|
-
api.annotation.upload_anns(
|
|
1798
|
-
img_ids=[info.id for info in batch_image_infos],
|
|
1799
|
-
anns=batch_anns,
|
|
1800
|
-
)
|
|
1801
|
-
if async_inference_request_uuid is not None:
|
|
1802
|
-
sly_progress.iters_done(len(batch_results))
|
|
1803
|
-
inference_request["pending_results"].extend(
|
|
1804
|
-
[{**result, "annotation": None, "data": None} for result in batch_results]
|
|
1805
|
-
)
|
|
1806
1664
|
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
-
|
|
1812
|
-
|
|
1813
|
-
|
|
1814
|
-
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
if output_project_id is None:
|
|
1819
|
-
upload_f = _add_results_to_request
|
|
1820
|
-
else:
|
|
1821
|
-
upload_f = upload_results_to_source_or_other
|
|
1822
|
-
|
|
1823
|
-
def _upload_loop(q: Queue, stop_event: threading.Event, api: Api, upload_f: Callable):
|
|
1824
|
-
try:
|
|
1825
|
-
while True:
|
|
1826
|
-
items = []
|
|
1827
|
-
while not q.empty():
|
|
1828
|
-
items.append(q.get_nowait())
|
|
1829
|
-
if len(items) > 0:
|
|
1830
|
-
ds_batches = {}
|
|
1831
|
-
for batch in items:
|
|
1832
|
-
if len(batch) == 0:
|
|
1833
|
-
continue
|
|
1834
|
-
for each in batch:
|
|
1835
|
-
ds_batches.setdefault(each["dataset_id"], []).append(each)
|
|
1836
|
-
for _, joined_batch in ds_batches.items():
|
|
1837
|
-
upload_f(joined_batch)
|
|
1838
|
-
continue
|
|
1839
|
-
if stop_event.is_set():
|
|
1840
|
-
self._on_inference_end(None, async_inference_request_uuid)
|
|
1841
|
-
return
|
|
1842
|
-
time.sleep(1)
|
|
1843
|
-
except Exception as e:
|
|
1844
|
-
api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
|
|
1845
|
-
raise
|
|
1846
|
-
|
|
1847
|
-
upload_queue = Queue()
|
|
1848
|
-
stop_upload_event = threading.Event()
|
|
1849
|
-
upload_thread = threading.Thread(
|
|
1850
|
-
target=_upload_loop,
|
|
1851
|
-
args=[upload_queue, stop_upload_event, api, upload_f],
|
|
1852
|
-
daemon=True,
|
|
1665
|
+
_upload_predictions = partial(
|
|
1666
|
+
self.upload_predictions,
|
|
1667
|
+
api=api,
|
|
1668
|
+
upload_mode=upload_mode,
|
|
1669
|
+
context=inference_request.context,
|
|
1670
|
+
dst_dataset_id=output_dataset_id,
|
|
1671
|
+
dst_project_id=output_project_id,
|
|
1672
|
+
progress_cb=inference_request.done,
|
|
1673
|
+
iou_merge_threshold=iou_merge_threshold,
|
|
1674
|
+
inference_request=inference_request,
|
|
1853
1675
|
)
|
|
1854
|
-
upload_thread.start()
|
|
1855
1676
|
|
|
1856
|
-
|
|
1857
|
-
|
|
1677
|
+
_add_results_to_request = partial(
|
|
1678
|
+
self.add_results_to_request, inference_request=inference_request
|
|
1679
|
+
)
|
|
1858
1680
|
|
|
1859
|
-
|
|
1860
|
-
|
|
1861
|
-
|
|
1862
|
-
|
|
1863
|
-
|
|
1864
|
-
|
|
1865
|
-
|
|
1866
|
-
|
|
1867
|
-
|
|
1868
|
-
|
|
1681
|
+
if upload_mode is None:
|
|
1682
|
+
upload_f = _add_results_to_request
|
|
1683
|
+
else:
|
|
1684
|
+
upload_f = _upload_predictions
|
|
1685
|
+
|
|
1686
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(image_ids))
|
|
1687
|
+
with Uploader(upload_f, logger=logger) as uploader:
|
|
1688
|
+
for image_ids_batch in batched(image_ids, batch_size=batch_size):
|
|
1689
|
+
if uploader.has_exception():
|
|
1690
|
+
exception = uploader.exception()
|
|
1691
|
+
raise RuntimeError(f"Error in upload loop: {exception}") from exception
|
|
1692
|
+
if inference_request.is_stopped():
|
|
1869
1693
|
logger.debug(
|
|
1870
1694
|
f"Cancelling inference project...",
|
|
1871
|
-
extra={"inference_request_uuid":
|
|
1695
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
1872
1696
|
)
|
|
1873
|
-
results = []
|
|
1874
|
-
stop = True
|
|
1875
1697
|
break
|
|
1876
1698
|
|
|
1877
1699
|
images_nps = [self.cache.download_image(api, img_id) for img_id in image_ids_batch]
|
|
1878
1700
|
anns, slides_data = self._inference_auto(
|
|
1879
1701
|
source=images_nps,
|
|
1880
|
-
settings=
|
|
1702
|
+
settings=inference_settings,
|
|
1881
1703
|
)
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
|
|
1885
|
-
|
|
1886
|
-
|
|
1887
|
-
|
|
1888
|
-
|
|
1889
|
-
|
|
1890
|
-
|
|
1891
|
-
|
|
1892
|
-
|
|
1893
|
-
|
|
1894
|
-
batch_results.append(
|
|
1895
|
-
{
|
|
1896
|
-
"annotation": ann.to_json(),
|
|
1897
|
-
"data": slides_data[i],
|
|
1898
|
-
"image_id": image_info.id,
|
|
1899
|
-
"image_name": image_info.name,
|
|
1900
|
-
"dataset_id": image_info.dataset_id,
|
|
1901
|
-
}
|
|
1704
|
+
|
|
1705
|
+
batch_predictions = []
|
|
1706
|
+
for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
|
|
1707
|
+
image_info: ImageInfo = images_infos_dict[image_id]
|
|
1708
|
+
dataset_info = dataset_infos_dict[image_info.dataset_id]
|
|
1709
|
+
prediction = Prediction(
|
|
1710
|
+
ann,
|
|
1711
|
+
model_meta=self.model_meta,
|
|
1712
|
+
name=image_info.name,
|
|
1713
|
+
image_id=image_info.id,
|
|
1714
|
+
dataset_id=image_info.dataset_id,
|
|
1715
|
+
project_id=dataset_info.project_id,
|
|
1902
1716
|
)
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
except Exception:
|
|
1906
|
-
stop_upload_event.set()
|
|
1907
|
-
upload_thread.join()
|
|
1908
|
-
raise
|
|
1909
|
-
if async_inference_request_uuid is not None and len(results) > 0:
|
|
1910
|
-
inference_request["result"] = {"ann": results}
|
|
1911
|
-
stop_upload_event.set()
|
|
1912
|
-
upload_thread.join()
|
|
1913
|
-
return results
|
|
1717
|
+
prediction.extra_data["slides_data"] = this_slides_data
|
|
1718
|
+
batch_predictions.append(prediction)
|
|
1914
1719
|
|
|
1915
|
-
|
|
1720
|
+
uploader.put(batch_predictions)
|
|
1721
|
+
|
|
1722
|
+
def _inference_video_id(
|
|
1916
1723
|
self,
|
|
1917
1724
|
api: Api,
|
|
1918
1725
|
state: dict,
|
|
1919
|
-
|
|
1920
|
-
async_inference_request_uuid: str = None,
|
|
1726
|
+
inference_request: InferenceRequest,
|
|
1921
1727
|
):
|
|
1728
|
+
logger.debug("Inferring video_id...", extra={"state": state})
|
|
1729
|
+
inference_settings = self._get_inference_settings(state)
|
|
1730
|
+
logger.debug(f"Inference settings:", extra=inference_settings)
|
|
1731
|
+
batch_size = self._get_batch_size_from_state(state)
|
|
1732
|
+
video_id = state["videoId"]
|
|
1733
|
+
video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
|
|
1734
|
+
if video_id is None:
|
|
1735
|
+
raise ValueError("Video id is not provided")
|
|
1736
|
+
video_info = api.video.get_info_by_id(video_id)
|
|
1737
|
+
start_frame_index = get_value_for_keys(
|
|
1738
|
+
state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
|
|
1739
|
+
)
|
|
1740
|
+
if start_frame_index is None:
|
|
1741
|
+
start_frame_index = 0
|
|
1742
|
+
step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
|
|
1743
|
+
if step is None:
|
|
1744
|
+
step = 1
|
|
1745
|
+
end_frame_index = get_value_for_keys(
|
|
1746
|
+
state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
|
|
1747
|
+
)
|
|
1748
|
+
duration = state.get("duration", None)
|
|
1749
|
+
frames_count = get_value_for_keys(
|
|
1750
|
+
state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
|
|
1751
|
+
)
|
|
1752
|
+
tracking = state.get("tracker", None)
|
|
1753
|
+
direction = state.get("direction", "forward")
|
|
1754
|
+
direction = 1 if direction == "forward" else -1
|
|
1755
|
+
|
|
1756
|
+
if frames_count is not None:
|
|
1757
|
+
n_frames = frames_count
|
|
1758
|
+
elif end_frame_index is not None:
|
|
1759
|
+
n_frames = end_frame_index - start_frame_index
|
|
1760
|
+
elif duration is not None:
|
|
1761
|
+
fps = video_info.frames_count / video_info.duration
|
|
1762
|
+
n_frames = int(duration * fps)
|
|
1763
|
+
else:
|
|
1764
|
+
n_frames = video_info.frames_count
|
|
1765
|
+
|
|
1766
|
+
if tracking == "bot":
|
|
1767
|
+
from supervisely.nn.tracker import BoTTracker
|
|
1768
|
+
|
|
1769
|
+
tracker = BoTTracker(state)
|
|
1770
|
+
elif tracking == "deepsort":
|
|
1771
|
+
from supervisely.nn.tracker import DeepSortTracker
|
|
1772
|
+
|
|
1773
|
+
tracker = DeepSortTracker(state)
|
|
1774
|
+
else:
|
|
1775
|
+
if tracking is not None:
|
|
1776
|
+
logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
|
|
1777
|
+
tracker = None
|
|
1778
|
+
logger.debug(
|
|
1779
|
+
f"Video info:",
|
|
1780
|
+
extra=dict(
|
|
1781
|
+
w=video_info.frame_width,
|
|
1782
|
+
h=video_info.frame_height,
|
|
1783
|
+
start_frame_index=start_frame_index,
|
|
1784
|
+
n_frames=n_frames,
|
|
1785
|
+
),
|
|
1786
|
+
)
|
|
1787
|
+
|
|
1788
|
+
# start downloading video in background
|
|
1789
|
+
self.cache.run_cache_task_manually(api, None, video_id=video_id)
|
|
1790
|
+
|
|
1791
|
+
progress_total = (n_frames + step - 1) // step
|
|
1792
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
1793
|
+
|
|
1794
|
+
tracks_data = {}
|
|
1795
|
+
for batch in batched(
|
|
1796
|
+
range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
|
|
1797
|
+
batch_size,
|
|
1798
|
+
):
|
|
1799
|
+
if inference_request.is_stopped():
|
|
1800
|
+
logger.debug(
|
|
1801
|
+
f"Cancelling inference video...",
|
|
1802
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
1803
|
+
)
|
|
1804
|
+
break
|
|
1805
|
+
logger.debug(
|
|
1806
|
+
f"Inferring frames {batch[0]}-{batch[-1]}:",
|
|
1807
|
+
)
|
|
1808
|
+
frames = self.cache.download_frames(api, video_info.id, batch, redownload_video=True)
|
|
1809
|
+
anns, slides_data = self._inference_auto(
|
|
1810
|
+
source=frames,
|
|
1811
|
+
settings=inference_settings,
|
|
1812
|
+
)
|
|
1813
|
+
predictions = [
|
|
1814
|
+
Prediction(
|
|
1815
|
+
ann,
|
|
1816
|
+
model_meta=self.model_meta,
|
|
1817
|
+
frame_index=frame_index,
|
|
1818
|
+
video_id=video_info.id,
|
|
1819
|
+
dataset_id=video_info.dataset_id,
|
|
1820
|
+
project_id=video_info.project_id,
|
|
1821
|
+
)
|
|
1822
|
+
for ann, frame_index in zip(anns, batch)
|
|
1823
|
+
]
|
|
1824
|
+
for pred, this_slides_data in zip(predictions, slides_data):
|
|
1825
|
+
pred.extra_data["slides_data"] = this_slides_data
|
|
1826
|
+
batch_results = self._format_output(predictions)
|
|
1827
|
+
if tracker is not None:
|
|
1828
|
+
for frame_index, frame, ann in zip(batch, frames, anns):
|
|
1829
|
+
tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
|
|
1830
|
+
inference_request.add_results(batch_results)
|
|
1831
|
+
inference_request.done(len(batch_results))
|
|
1832
|
+
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
1833
|
+
video_ann_json = None
|
|
1834
|
+
if tracker is not None:
|
|
1835
|
+
inference_request.set_stage("Postprocess...", 0, 1)
|
|
1836
|
+
video_ann_json = tracker.get_annotation(
|
|
1837
|
+
tracks_data, (video_info.frame_height, video_info.frame_width), n_frames
|
|
1838
|
+
).to_json()
|
|
1839
|
+
inference_request.done()
|
|
1840
|
+
inference_request.final_result = {"video_ann": video_ann_json}
|
|
1841
|
+
|
|
1842
|
+
def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
|
|
1922
1843
|
"""Inference project images.
|
|
1923
1844
|
If "output_project_id" in state, upload images and annotations to the output project.
|
|
1924
1845
|
If "output_project_id" equal to source project id, upload annotations to the source project.
|
|
1925
1846
|
If "output_project_id" is None, write annotations to inference request object.
|
|
1926
1847
|
"""
|
|
1927
1848
|
logger.debug("Inferring project...", extra={"state": state})
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
|
|
1849
|
+
inference_settings = self._get_inference_settings(state)
|
|
1850
|
+
logger.debug("Inference settings:", extra={"inference_settings": inference_settings})
|
|
1851
|
+
batch_size = self._get_batch_size_from_state(state)
|
|
1852
|
+
project_id = get_value_for_keys(state, keys=["projectId", "project_id"])
|
|
1853
|
+
if project_id is None:
|
|
1854
|
+
raise ValueError("Project id is not provided")
|
|
1855
|
+
project_info = api.project.get_info_by_id(project_id)
|
|
1856
|
+
if project_info.type != str(ProjectType.IMAGES):
|
|
1857
|
+
raise ValueError("Only images projects are supported.")
|
|
1858
|
+
upload_mode = state.get("upload_mode", None)
|
|
1859
|
+
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
|
1860
|
+
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
|
1861
|
+
iou_merge_threshold = 0.7
|
|
1931
1862
|
cache_project_on_model = state.get("cache_project_on_model", False)
|
|
1932
|
-
batch_size = state.get("batch_size", None)
|
|
1933
|
-
if batch_size is None:
|
|
1934
|
-
batch_size = self.get_batch_size()
|
|
1935
1863
|
|
|
1864
|
+
project_info = api.project.get_info_by_id(project_id)
|
|
1865
|
+
inference_request.context.setdefault("project_info", {})[project_id] = project_info
|
|
1866
|
+
dataset_ids = state.get("dataset_ids", None)
|
|
1867
|
+
if dataset_ids is None:
|
|
1868
|
+
dataset_ids = state.get("datasetIds", None)
|
|
1936
1869
|
datasets_infos = api.dataset.get_list(project_info.id, recursive=True)
|
|
1870
|
+
inference_request.context.setdefault("dataset_info", {}).update(
|
|
1871
|
+
{ds_info.id: ds_info for ds_info in datasets_infos}
|
|
1872
|
+
)
|
|
1937
1873
|
if dataset_ids is not None:
|
|
1938
1874
|
datasets_infos = [ds_info for ds_info in datasets_infos if ds_info.id in dataset_ids]
|
|
1939
1875
|
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
preparing_progress["current"] = 0
|
|
1944
|
-
preparing_progress["total"] = len(datasets_infos)
|
|
1945
|
-
progress_cb = None
|
|
1946
|
-
if async_inference_request_uuid is not None:
|
|
1947
|
-
try:
|
|
1948
|
-
inference_request = self._inference_requests[async_inference_request_uuid]
|
|
1949
|
-
except Exception as ex:
|
|
1950
|
-
import traceback
|
|
1951
|
-
|
|
1952
|
-
logger.error(traceback.format_exc())
|
|
1953
|
-
raise RuntimeError(
|
|
1954
|
-
f"async_inference_request_uuid {async_inference_request_uuid} was given, "
|
|
1955
|
-
f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
|
|
1956
|
-
)
|
|
1957
|
-
sly_progress: Progress = inference_request["progress"]
|
|
1958
|
-
sly_progress.total = sum([ds_info.items_count for ds_info in datasets_infos])
|
|
1959
|
-
|
|
1960
|
-
inference_request["preparing_progress"]["total"] = len(datasets_infos)
|
|
1961
|
-
preparing_progress = inference_request["preparing_progress"]
|
|
1962
|
-
|
|
1963
|
-
if cache_project_on_model:
|
|
1964
|
-
progress_cb = sly_progress.iters_done
|
|
1965
|
-
preparing_progress["total"] = sly_progress.total
|
|
1966
|
-
preparing_progress["status"] = "download_project"
|
|
1876
|
+
preparing_progress_total = sum([ds_info.items_count for ds_info in datasets_infos])
|
|
1877
|
+
inference_progress_total = preparing_progress_total
|
|
1878
|
+
inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, preparing_progress_total)
|
|
1967
1879
|
|
|
1968
1880
|
output_project_id = state.get("output_project_id", None)
|
|
1969
|
-
|
|
1881
|
+
inference_request.context.setdefault("project_meta", {})
|
|
1970
1882
|
if output_project_id is not None:
|
|
1971
|
-
|
|
1972
|
-
|
|
1973
|
-
|
|
1974
|
-
|
|
1975
|
-
|
|
1976
|
-
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
|
|
1984
|
-
output_project_id, output_project_meta
|
|
1985
|
-
)
|
|
1883
|
+
if upload_mode is None:
|
|
1884
|
+
upload_mode = "append"
|
|
1885
|
+
if output_project_id is None and upload_mode == "create":
|
|
1886
|
+
output_project_info = api.project.create(
|
|
1887
|
+
project_info.workspace_id,
|
|
1888
|
+
name=f"Predictions from task #{self.task_id}",
|
|
1889
|
+
description=f"Auto created project from inference request {inference_request.uuid}",
|
|
1890
|
+
change_name_if_conflict=True,
|
|
1891
|
+
)
|
|
1892
|
+
output_project_id = output_project_info.id
|
|
1893
|
+
inference_request.context.setdefault("project_info", {})[
|
|
1894
|
+
output_project_id
|
|
1895
|
+
] = output_project_info
|
|
1986
1896
|
|
|
1987
1897
|
if cache_project_on_model:
|
|
1988
|
-
download_to_cache(
|
|
1898
|
+
download_to_cache(
|
|
1899
|
+
api, project_info.id, datasets_infos, progress_cb=inference_request.done
|
|
1900
|
+
)
|
|
1989
1901
|
|
|
1990
1902
|
images_infos_dict = {}
|
|
1991
1903
|
for dataset_info in datasets_infos:
|
|
1992
1904
|
images_infos_dict[dataset_info.id] = api.image.get_list(dataset_info.id)
|
|
1993
1905
|
if not cache_project_on_model:
|
|
1994
|
-
|
|
1995
|
-
|
|
1996
|
-
preparing_progress["status"] = "inference"
|
|
1997
|
-
preparing_progress["current"] = 0
|
|
1906
|
+
inference_request.done(dataset_info.items_count)
|
|
1998
1907
|
|
|
1999
1908
|
def _download_images(datasets_infos: List[DatasetInfo]):
|
|
2000
1909
|
for dataset_info in datasets_infos:
|
|
@@ -2011,166 +1920,41 @@ class Inference:
|
|
|
2011
1920
|
# start downloading in parallel
|
|
2012
1921
|
threading.Thread(target=_download_images, args=[datasets_infos], daemon=True).start()
|
|
2013
1922
|
|
|
2014
|
-
|
|
2015
|
-
|
|
2016
|
-
|
|
2017
|
-
|
|
2018
|
-
|
|
2019
|
-
|
|
2020
|
-
|
|
2021
|
-
|
|
2022
|
-
|
|
2023
|
-
|
|
2024
|
-
project_info.id, output_project_meta
|
|
2025
|
-
)
|
|
2026
|
-
ann = update_classes(api, ann, output_project_meta, output_project_id)
|
|
2027
|
-
api.annotation.append_labels(image_id, ann.labels)
|
|
2028
|
-
if async_inference_request_uuid is not None:
|
|
2029
|
-
sly_progress.iters_done(1)
|
|
2030
|
-
inference_request["pending_results"].append(
|
|
2031
|
-
{
|
|
2032
|
-
"annotation": None, # to less response size
|
|
2033
|
-
"data": None, # to less response size
|
|
2034
|
-
"image_id": image_id,
|
|
2035
|
-
"image_name": result["image_name"],
|
|
2036
|
-
"dataset_id": result["dataset_id"],
|
|
2037
|
-
}
|
|
2038
|
-
)
|
|
2039
|
-
|
|
2040
|
-
new_dataset_id = {}
|
|
2041
|
-
|
|
2042
|
-
def _get_or_create_new_dataset(output_project_id, src_dataset_id):
|
|
2043
|
-
"""Copy dataset in output project if not exists and return its id"""
|
|
2044
|
-
if src_dataset_id in new_dataset_id:
|
|
2045
|
-
return new_dataset_id[src_dataset_id]
|
|
2046
|
-
dataset_info = api.dataset.get_info_by_id(src_dataset_id)
|
|
2047
|
-
if dataset_info.parent_id is None:
|
|
2048
|
-
output_dataset_id = api.dataset.copy(
|
|
2049
|
-
output_project_id,
|
|
2050
|
-
src_dataset_id,
|
|
2051
|
-
dataset_info.name,
|
|
2052
|
-
change_name_if_conflict=True,
|
|
2053
|
-
).id
|
|
2054
|
-
else:
|
|
2055
|
-
parent_dataset_id = _get_or_create_new_dataset(
|
|
2056
|
-
output_project_id, dataset_info.parent_id
|
|
2057
|
-
)
|
|
2058
|
-
output_dataset_info = api.dataset.create(
|
|
2059
|
-
output_project_id, dataset_info.name, parent_id=parent_dataset_id
|
|
2060
|
-
)
|
|
2061
|
-
api.image.copy_batch_optimized(
|
|
2062
|
-
dataset_info.id,
|
|
2063
|
-
images_infos_dict[dataset_info.id],
|
|
2064
|
-
output_dataset_info.id,
|
|
2065
|
-
with_annotations=False,
|
|
2066
|
-
)
|
|
2067
|
-
output_dataset_id = output_dataset_info.id
|
|
2068
|
-
new_dataset_id[src_dataset_id] = output_dataset_id
|
|
2069
|
-
return output_dataset_id
|
|
2070
|
-
|
|
2071
|
-
def _upload_results_to_other(results: List[Dict]):
|
|
2072
|
-
nonlocal output_project_meta
|
|
2073
|
-
if len(results) == 0:
|
|
2074
|
-
return
|
|
2075
|
-
src_dataset_id = results[0]["dataset_id"]
|
|
2076
|
-
dataset_id = _get_or_create_new_dataset(output_project_id, src_dataset_id)
|
|
2077
|
-
image_names = [result["image_name"] for result in results]
|
|
2078
|
-
image_infos = api.image.get_list(
|
|
2079
|
-
dataset_id,
|
|
2080
|
-
filters=[{"field": "name", "operator": "in", "value": image_names}],
|
|
2081
|
-
)
|
|
2082
|
-
meta_changed = False
|
|
2083
|
-
anns = []
|
|
2084
|
-
for result in results:
|
|
2085
|
-
ann = Annotation.from_json(result["annotation"], self.model_meta)
|
|
2086
|
-
output_project_meta, ann, c = update_meta_and_ann(output_project_meta, ann)
|
|
2087
|
-
meta_changed = meta_changed or c
|
|
2088
|
-
anns.append(ann)
|
|
2089
|
-
if meta_changed:
|
|
2090
|
-
api.project.update_meta(output_project_id, output_project_meta)
|
|
2091
|
-
|
|
2092
|
-
# upload in batches to update progress with each batch
|
|
2093
|
-
# api.annotation.upload_anns() uploads in same batches anyways
|
|
2094
|
-
for batch in batched(list(zip(anns, results, image_infos))):
|
|
2095
|
-
batch_anns, batch_results, batch_image_infos = zip(*batch)
|
|
2096
|
-
api.annotation.upload_anns(
|
|
2097
|
-
img_ids=[info.id for info in batch_image_infos],
|
|
2098
|
-
anns=batch_anns,
|
|
2099
|
-
)
|
|
2100
|
-
if async_inference_request_uuid is not None:
|
|
2101
|
-
sly_progress.iters_done(len(batch_results))
|
|
2102
|
-
inference_request["pending_results"].extend(
|
|
2103
|
-
[{**result, "annotation": None, "data": None} for result in batch_results]
|
|
2104
|
-
)
|
|
2105
|
-
|
|
2106
|
-
def _add_results_to_request(results: List[Dict]):
|
|
2107
|
-
if async_inference_request_uuid is None:
|
|
2108
|
-
return
|
|
2109
|
-
inference_request["pending_results"].extend(results)
|
|
2110
|
-
sly_progress.iters_done(len(results))
|
|
1923
|
+
_upload_predictions = partial(
|
|
1924
|
+
self.upload_predictions,
|
|
1925
|
+
api=api,
|
|
1926
|
+
upload_mode=upload_mode,
|
|
1927
|
+
context=inference_request.context,
|
|
1928
|
+
dst_project_id=output_project_id,
|
|
1929
|
+
progress_cb=inference_request.done,
|
|
1930
|
+
iou_merge_threshold=iou_merge_threshold,
|
|
1931
|
+
inference_request=inference_request,
|
|
1932
|
+
)
|
|
2111
1933
|
|
|
2112
|
-
|
|
2113
|
-
|
|
2114
|
-
|
|
2115
|
-
items = []
|
|
2116
|
-
while not q.empty():
|
|
2117
|
-
items.append(q.get_nowait())
|
|
2118
|
-
if len(items) > 0:
|
|
2119
|
-
ds_batches = {}
|
|
2120
|
-
for batch in items:
|
|
2121
|
-
if len(batch) == 0:
|
|
2122
|
-
continue
|
|
2123
|
-
ds_batches.setdefault(batch[0].get("dataset_id"), []).extend(batch)
|
|
2124
|
-
for _, joined_batch in ds_batches.items():
|
|
2125
|
-
upload_f(joined_batch)
|
|
2126
|
-
continue
|
|
2127
|
-
if stop_event.is_set():
|
|
2128
|
-
self._on_inference_end(None, async_inference_request_uuid)
|
|
2129
|
-
return
|
|
2130
|
-
time.sleep(1)
|
|
2131
|
-
except Exception as e:
|
|
2132
|
-
api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
|
|
2133
|
-
raise
|
|
1934
|
+
_add_results_to_request = partial(
|
|
1935
|
+
self.add_results_to_request, inference_request=inference_request
|
|
1936
|
+
)
|
|
2134
1937
|
|
|
2135
|
-
if
|
|
1938
|
+
if upload_mode is None:
|
|
2136
1939
|
upload_f = _add_results_to_request
|
|
2137
|
-
elif output_project_id != project_info.id:
|
|
2138
|
-
upload_f = _upload_results_to_other
|
|
2139
1940
|
else:
|
|
2140
|
-
upload_f =
|
|
2141
|
-
|
|
2142
|
-
upload_queue = Queue()
|
|
2143
|
-
stop_upload_event = threading.Event()
|
|
2144
|
-
upload_thread = threading.Thread(
|
|
2145
|
-
target=_upload_loop,
|
|
2146
|
-
args=[upload_queue, stop_upload_event, api, upload_f],
|
|
2147
|
-
daemon=True,
|
|
2148
|
-
)
|
|
2149
|
-
upload_thread.start()
|
|
1941
|
+
upload_f = _upload_predictions
|
|
2150
1942
|
|
|
2151
|
-
|
|
2152
|
-
|
|
2153
|
-
results = []
|
|
2154
|
-
data_to_return = {}
|
|
2155
|
-
stop = False
|
|
2156
|
-
try:
|
|
1943
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, inference_progress_total)
|
|
1944
|
+
with Uploader(upload_f, logger=logger) as uploader:
|
|
2157
1945
|
for dataset_info in datasets_infos:
|
|
2158
|
-
if stop:
|
|
2159
|
-
break
|
|
2160
1946
|
for images_infos_batch in batched(
|
|
2161
1947
|
images_infos_dict[dataset_info.id], batch_size=batch_size
|
|
2162
1948
|
):
|
|
2163
|
-
if (
|
|
2164
|
-
async_inference_request_uuid is not None
|
|
2165
|
-
and inference_request["cancel_inference"] is True
|
|
2166
|
-
):
|
|
1949
|
+
if inference_request.is_stopped():
|
|
2167
1950
|
logger.debug(
|
|
2168
1951
|
f"Cancelling inference project...",
|
|
2169
|
-
extra={"inference_request_uuid":
|
|
1952
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
2170
1953
|
)
|
|
2171
|
-
|
|
2172
|
-
|
|
2173
|
-
|
|
1954
|
+
return
|
|
1955
|
+
if uploader.has_exception():
|
|
1956
|
+
exception = uploader.exception
|
|
1957
|
+
raise RuntimeError(f"Error in upload loop: {exception}") from exception
|
|
2174
1958
|
if cache_project_on_model:
|
|
2175
1959
|
images_paths, _ = zip(
|
|
2176
1960
|
*read_from_cached_project(
|
|
@@ -2189,52 +1973,36 @@ class Inference:
|
|
|
2189
1973
|
)
|
|
2190
1974
|
anns, slides_data = self._inference_auto(
|
|
2191
1975
|
source=images_nps,
|
|
2192
|
-
settings=
|
|
1976
|
+
settings=inference_settings,
|
|
2193
1977
|
)
|
|
2194
|
-
|
|
2195
|
-
|
|
2196
|
-
|
|
2197
|
-
|
|
1978
|
+
predictions = [
|
|
1979
|
+
Prediction(
|
|
1980
|
+
ann,
|
|
1981
|
+
model_meta=self.model_meta,
|
|
1982
|
+
image_id=image_info.id,
|
|
1983
|
+
name=image_info.name,
|
|
1984
|
+
dataset_id=dataset_info.id,
|
|
1985
|
+
project_id=dataset_info.project_id,
|
|
1986
|
+
image_name=image_info.name,
|
|
2198
1987
|
)
|
|
2199
|
-
|
|
2200
|
-
|
|
2201
|
-
|
|
2202
|
-
|
|
2203
|
-
|
|
2204
|
-
|
|
2205
|
-
output_project_meta,
|
|
2206
|
-
)
|
|
2207
|
-
batch_results = []
|
|
2208
|
-
for i, ann in enumerate(anns):
|
|
2209
|
-
batch_results.append(
|
|
2210
|
-
{
|
|
2211
|
-
"annotation": ann.to_json(),
|
|
2212
|
-
"data": slides_data[i],
|
|
2213
|
-
"image_id": images_infos_batch[i].id,
|
|
2214
|
-
"image_name": images_infos_batch[i].name,
|
|
2215
|
-
"dataset_id": dataset_info.id,
|
|
2216
|
-
}
|
|
2217
|
-
)
|
|
2218
|
-
results.extend(batch_results)
|
|
2219
|
-
upload_queue.put(batch_results)
|
|
2220
|
-
except Exception:
|
|
2221
|
-
stop_upload_event.set()
|
|
2222
|
-
upload_thread.join()
|
|
2223
|
-
raise
|
|
2224
|
-
if async_inference_request_uuid is not None and len(results) > 0:
|
|
2225
|
-
inference_request["result"] = {"ann": results}
|
|
2226
|
-
stop_upload_event.set()
|
|
2227
|
-
upload_thread.join()
|
|
2228
|
-
return results
|
|
1988
|
+
for ann, image_info in zip(anns, images_infos_batch)
|
|
1989
|
+
]
|
|
1990
|
+
for pred, this_slides_data in zip(predictions, slides_data):
|
|
1991
|
+
pred.extra_data["slides_data"] = this_slides_data
|
|
1992
|
+
|
|
1993
|
+
uploader.put(predictions)
|
|
2229
1994
|
|
|
2230
1995
|
def _run_speedtest(
|
|
2231
1996
|
self,
|
|
2232
1997
|
api: Api,
|
|
2233
1998
|
state: dict,
|
|
2234
|
-
|
|
1999
|
+
inference_request: InferenceRequest,
|
|
2235
2000
|
):
|
|
2236
2001
|
"""Run speedtest on project images."""
|
|
2237
2002
|
logger.debug("Running speedtest...", extra={"state": state})
|
|
2003
|
+
settings = self._get_inference_settings(state)
|
|
2004
|
+
logger.debug(f"Inference settings:", extra=settings)
|
|
2005
|
+
|
|
2238
2006
|
project_id = state["projectId"]
|
|
2239
2007
|
batch_size = state["batch_size"]
|
|
2240
2008
|
num_iterations = state["num_iterations"]
|
|
@@ -2252,49 +2020,22 @@ class Inference:
|
|
|
2252
2020
|
if dataset_id in datasets_infos_dict
|
|
2253
2021
|
]
|
|
2254
2022
|
|
|
2255
|
-
|
|
2256
|
-
|
|
2257
|
-
|
|
2258
|
-
|
|
2259
|
-
|
|
2260
|
-
|
|
2261
|
-
import traceback
|
|
2262
|
-
|
|
2263
|
-
logger.error(traceback.format_exc())
|
|
2264
|
-
raise RuntimeError(
|
|
2265
|
-
f"async_inference_request_uuid {async_inference_request_uuid} was given, "
|
|
2266
|
-
f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
|
|
2267
|
-
)
|
|
2268
|
-
sly_progress: Progress = inference_request["progress"]
|
|
2269
|
-
sly_progress.total = num_iterations
|
|
2270
|
-
sly_progress.current = 0
|
|
2271
|
-
|
|
2272
|
-
preparing_progress = inference_request["preparing_progress"]
|
|
2023
|
+
preparing_progress_total = len(datasets_infos)
|
|
2024
|
+
if cache_project_on_model:
|
|
2025
|
+
preparing_progress_total += sum(
|
|
2026
|
+
dataset_info.items_count for dataset_info in datasets_infos
|
|
2027
|
+
)
|
|
2028
|
+
inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, preparing_progress_total)
|
|
2273
2029
|
|
|
2274
|
-
preparing_progress["current"] = 0
|
|
2275
|
-
preparing_progress["total"] = len(datasets_infos)
|
|
2276
|
-
preparing_progress["status"] = "download_info"
|
|
2277
2030
|
images_infos_dict = {}
|
|
2278
2031
|
for dataset_info in datasets_infos:
|
|
2279
2032
|
images_infos_dict[dataset_info.id] = api.image.get_list(dataset_info.id)
|
|
2280
|
-
|
|
2281
|
-
preparing_progress["current"] += 1
|
|
2033
|
+
inference_request.done()
|
|
2282
2034
|
|
|
2283
2035
|
if cache_project_on_model:
|
|
2036
|
+
download_to_cache(api, project_id, datasets_infos, progress_cb=inference_request.done)
|
|
2284
2037
|
|
|
2285
|
-
|
|
2286
|
-
preparing_progress["current"] += count
|
|
2287
|
-
|
|
2288
|
-
preparing_progress["current"] = 0
|
|
2289
|
-
preparing_progress["total"] = sum(
|
|
2290
|
-
dataset_info.items_count for dataset_info in datasets_infos
|
|
2291
|
-
)
|
|
2292
|
-
preparing_progress["status"] = "download_project"
|
|
2293
|
-
download_to_cache(api, project_id, datasets_infos, progress_cb=_progress_cb)
|
|
2294
|
-
|
|
2295
|
-
preparing_progress["status"] = "warmup"
|
|
2296
|
-
preparing_progress["current"] = 0
|
|
2297
|
-
preparing_progress["total"] = num_warmup
|
|
2038
|
+
inference_request.set_stage("warmup", 0, num_warmup)
|
|
2298
2039
|
|
|
2299
2040
|
images_infos: List[ImageInfo] = [
|
|
2300
2041
|
image_info for infos in images_infos_dict.values() for image_info in infos
|
|
@@ -2313,44 +2054,9 @@ class Inference:
|
|
|
2313
2054
|
# start downloading in parallel
|
|
2314
2055
|
threading.Thread(target=_download_images, daemon=True).start()
|
|
2315
2056
|
|
|
2316
|
-
def
|
|
2317
|
-
|
|
2318
|
-
|
|
2319
|
-
inference_request["pending_results"].append(results)
|
|
2320
|
-
sly_progress.iters_done(1)
|
|
2321
|
-
|
|
2322
|
-
def _upload_loop(q: Queue, stop_event: threading.Event, api: Api, upload_f: Callable):
|
|
2323
|
-
try:
|
|
2324
|
-
while True:
|
|
2325
|
-
items = []
|
|
2326
|
-
while not q.empty():
|
|
2327
|
-
items.append(q.get_nowait())
|
|
2328
|
-
if len(items) > 0:
|
|
2329
|
-
for batch in items:
|
|
2330
|
-
upload_f(batch)
|
|
2331
|
-
continue
|
|
2332
|
-
if stop_event.is_set():
|
|
2333
|
-
self._on_inference_end(None, async_inference_request_uuid)
|
|
2334
|
-
return
|
|
2335
|
-
time.sleep(1)
|
|
2336
|
-
except Exception as e:
|
|
2337
|
-
api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
|
|
2338
|
-
raise
|
|
2339
|
-
|
|
2340
|
-
upload_f = _add_results_to_request
|
|
2341
|
-
|
|
2342
|
-
upload_queue = Queue()
|
|
2343
|
-
stop_upload_event = threading.Event()
|
|
2344
|
-
threading.Thread(
|
|
2345
|
-
target=_upload_loop,
|
|
2346
|
-
args=[upload_queue, stop_upload_event, api, upload_f],
|
|
2347
|
-
daemon=True,
|
|
2348
|
-
).start()
|
|
2349
|
-
|
|
2350
|
-
settings = self._get_inference_settings(state)
|
|
2351
|
-
logger.debug(f"Inference settings:", extra=settings)
|
|
2352
|
-
results = []
|
|
2353
|
-
stop = False
|
|
2057
|
+
def upload_f(benchmarks: List):
|
|
2058
|
+
inference_request.add_results(benchmarks)
|
|
2059
|
+
inference_request.done(len(benchmarks))
|
|
2354
2060
|
|
|
2355
2061
|
def image_batch_generator(batch_size):
|
|
2356
2062
|
logger.debug(
|
|
@@ -2366,23 +2072,20 @@ class Inference:
|
|
|
2366
2072
|
batch = []
|
|
2367
2073
|
|
|
2368
2074
|
batch_generator = image_batch_generator(batch_size)
|
|
2369
|
-
|
|
2075
|
+
|
|
2076
|
+
with Uploader(upload_f=upload_f, logger=logger) as uploader:
|
|
2370
2077
|
for i in range(num_iterations + num_warmup):
|
|
2371
|
-
if
|
|
2372
|
-
break
|
|
2373
|
-
if (
|
|
2374
|
-
async_inference_request_uuid is not None
|
|
2375
|
-
and inference_request["cancel_inference"] is True
|
|
2376
|
-
):
|
|
2078
|
+
if inference_request.is_stopped():
|
|
2377
2079
|
logger.debug(
|
|
2378
2080
|
f"Cancelling inference project...",
|
|
2379
|
-
extra={"inference_request_uuid":
|
|
2081
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
2380
2082
|
)
|
|
2381
|
-
|
|
2382
|
-
|
|
2383
|
-
|
|
2083
|
+
return
|
|
2084
|
+
if uploader.has_exception():
|
|
2085
|
+
exception = uploader.exception
|
|
2086
|
+
raise RuntimeError(f"Error in upload loop: {exception}") from exception
|
|
2384
2087
|
if i == num_warmup:
|
|
2385
|
-
|
|
2088
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, num_iterations)
|
|
2386
2089
|
|
|
2387
2090
|
images_infos_batch: List[ImageInfo] = next(batch_generator)
|
|
2388
2091
|
|
|
@@ -2429,35 +2132,9 @@ class Inference:
|
|
|
2429
2132
|
)
|
|
2430
2133
|
# Collect results if warmup is done
|
|
2431
2134
|
if i >= num_warmup:
|
|
2432
|
-
|
|
2433
|
-
upload_queue.put(benchmark)
|
|
2135
|
+
uploader.put([benchmark])
|
|
2434
2136
|
else:
|
|
2435
|
-
|
|
2436
|
-
except Exception:
|
|
2437
|
-
stop_upload_event.set()
|
|
2438
|
-
raise
|
|
2439
|
-
if async_inference_request_uuid is not None and len(results) > 0:
|
|
2440
|
-
inference_request["result"] = results
|
|
2441
|
-
stop_upload_event.set()
|
|
2442
|
-
return results
|
|
2443
|
-
|
|
2444
|
-
def _on_inference_start(self, inference_request_uuid):
|
|
2445
|
-
inference_request = {
|
|
2446
|
-
"progress": Progress("Inferring model...", total_cnt=1),
|
|
2447
|
-
"is_inferring": True,
|
|
2448
|
-
"cancel_inference": False,
|
|
2449
|
-
"result": None,
|
|
2450
|
-
"pending_results": [],
|
|
2451
|
-
"preparing_progress": {"current": 0, "total": 1},
|
|
2452
|
-
"exception": None,
|
|
2453
|
-
}
|
|
2454
|
-
self._inference_requests[inference_request_uuid] = inference_request
|
|
2455
|
-
|
|
2456
|
-
def _on_inference_end(self, future, inference_request_uuid):
|
|
2457
|
-
logger.debug("callback: on_inference_end()")
|
|
2458
|
-
inference_request = self._inference_requests.get(inference_request_uuid)
|
|
2459
|
-
if inference_request is not None:
|
|
2460
|
-
inference_request["is_inferring"] = False
|
|
2137
|
+
inference_request.done()
|
|
2461
2138
|
|
|
2462
2139
|
def _check_serve_before_call(self, func):
|
|
2463
2140
|
@wraps(func)
|
|
@@ -2481,6 +2158,24 @@ class Inference:
|
|
|
2481
2158
|
def is_model_deployed(self):
|
|
2482
2159
|
return self._model_served
|
|
2483
2160
|
|
|
2161
|
+
def _on_inference_start(self, inference_request_uuid):
|
|
2162
|
+
inference_request = {
|
|
2163
|
+
"progress": Progress("Inferring model...", total_cnt=1),
|
|
2164
|
+
"is_inferring": True,
|
|
2165
|
+
"cancel_inference": False,
|
|
2166
|
+
"result": None,
|
|
2167
|
+
"pending_results": [],
|
|
2168
|
+
"preparing_progress": {"current": 0, "total": 1},
|
|
2169
|
+
"exception": None,
|
|
2170
|
+
}
|
|
2171
|
+
self._inference_requests[inference_request_uuid] = inference_request
|
|
2172
|
+
|
|
2173
|
+
def _on_inference_end(self, future, inference_request_uuid):
|
|
2174
|
+
logger.debug("callback: on_inference_end()")
|
|
2175
|
+
inference_request = self._inference_requests.get(inference_request_uuid)
|
|
2176
|
+
if inference_request is not None:
|
|
2177
|
+
inference_request["is_inferring"] = False
|
|
2178
|
+
|
|
2484
2179
|
def schedule_task(self, func, *args, **kwargs):
|
|
2485
2180
|
inference_request_uuid = kwargs.get("inference_request_uuid", None)
|
|
2486
2181
|
if inference_request_uuid is None:
|
|
@@ -2523,6 +2218,228 @@ class Inference:
|
|
|
2523
2218
|
self.gui._success_label.hide()
|
|
2524
2219
|
raise e
|
|
2525
2220
|
|
|
2221
|
+
def validate_inference_state(self, state: Union[Dict, str], log_error=True):
|
|
2222
|
+
try:
|
|
2223
|
+
if isinstance(state, str):
|
|
2224
|
+
try:
|
|
2225
|
+
state = json.loads(state)
|
|
2226
|
+
except (json.decoder.JSONDecodeError, TypeError) as e:
|
|
2227
|
+
raise HTTPException(
|
|
2228
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
2229
|
+
detail=f"Cannot decode settings: {e}",
|
|
2230
|
+
)
|
|
2231
|
+
if not isinstance(state, dict):
|
|
2232
|
+
raise HTTPException(
|
|
2233
|
+
status_code=status.HTTP_400_BAD_REQUEST, detail="Settings is not json object"
|
|
2234
|
+
)
|
|
2235
|
+
batch_size = state.get("batch_size", None)
|
|
2236
|
+
if batch_size is None:
|
|
2237
|
+
batch_size = self.get_batch_size()
|
|
2238
|
+
if self.max_batch_size is not None and batch_size > self.max_batch_size:
|
|
2239
|
+
raise HTTPException(
|
|
2240
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
2241
|
+
detail=f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
|
|
2242
|
+
)
|
|
2243
|
+
except Exception as e:
|
|
2244
|
+
if log_error:
|
|
2245
|
+
logger.error(f"Error validating request state: {e}", exc_info=True)
|
|
2246
|
+
raise
|
|
2247
|
+
|
|
2248
|
+
def upload_predictions(
|
|
2249
|
+
self,
|
|
2250
|
+
predictions: List[Prediction],
|
|
2251
|
+
api: Api,
|
|
2252
|
+
upload_mode: str,
|
|
2253
|
+
context: Dict = None,
|
|
2254
|
+
dst_dataset_id: int = None,
|
|
2255
|
+
dst_project_id: int = None,
|
|
2256
|
+
progress_cb=None,
|
|
2257
|
+
iou_merge_threshold: float = None,
|
|
2258
|
+
inference_request: InferenceRequest = None,
|
|
2259
|
+
):
|
|
2260
|
+
ds_predictions: Dict[int, List[Prediction]] = defaultdict(list)
|
|
2261
|
+
for prediction in predictions:
|
|
2262
|
+
ds_predictions[prediction.dataset_id].append(prediction)
|
|
2263
|
+
|
|
2264
|
+
def _new_name(image_info: ImageInfo):
|
|
2265
|
+
name = Path(image_info.name)
|
|
2266
|
+
stem = name.stem
|
|
2267
|
+
parent = name.parent
|
|
2268
|
+
suffix = name.suffix
|
|
2269
|
+
return str(parent / f"{stem}(dataset_id:{image_info.dataset_id}){suffix}")
|
|
2270
|
+
|
|
2271
|
+
def _get_or_create_dataset(src_dataset_id, dst_project_id):
|
|
2272
|
+
if src_dataset_id is None:
|
|
2273
|
+
return None
|
|
2274
|
+
created_dataset_id = context.setdefault("created_dataset", {}).get(src_dataset_id, None)
|
|
2275
|
+
if created_dataset_id is not None:
|
|
2276
|
+
return created_dataset_id
|
|
2277
|
+
src_dataset_info: DatasetInfo = context.setdefault("dataset_info", {}).get(
|
|
2278
|
+
src_dataset_id
|
|
2279
|
+
)
|
|
2280
|
+
if src_dataset_info is None:
|
|
2281
|
+
src_dataset_info = api.dataset.get_info_by_id(src_dataset_id)
|
|
2282
|
+
context["dataset_info"][src_dataset_id] = src_dataset_info
|
|
2283
|
+
src_parent_id = src_dataset_info.parent_id
|
|
2284
|
+
dst_parent_id = _get_or_create_dataset(src_parent_id, dst_project_id)
|
|
2285
|
+
created_dataset = api.dataset.create(
|
|
2286
|
+
dst_project_id,
|
|
2287
|
+
src_dataset_info.name,
|
|
2288
|
+
description=f"Auto created dataset from inference request {inference_request.uuid if inference_request is not None else ''}",
|
|
2289
|
+
change_name_if_conflict=True,
|
|
2290
|
+
parent_id=dst_parent_id,
|
|
2291
|
+
)
|
|
2292
|
+
context["dataset_info"][created_dataset.id] = created_dataset
|
|
2293
|
+
context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
|
|
2294
|
+
return created_dataset.id
|
|
2295
|
+
|
|
2296
|
+
created_names = []
|
|
2297
|
+
if context is None:
|
|
2298
|
+
context = {}
|
|
2299
|
+
for dataset_id, preds in ds_predictions.items():
|
|
2300
|
+
if dst_project_id is not None:
|
|
2301
|
+
# upload to the destination project
|
|
2302
|
+
dst_dataset_id = _get_or_create_dataset(
|
|
2303
|
+
src_dataset_id=dataset_id, dst_project_id=dst_project_id
|
|
2304
|
+
)
|
|
2305
|
+
if dst_dataset_id is not None:
|
|
2306
|
+
# upload to the destination dataset
|
|
2307
|
+
dataset_info = context.setdefault("dataset_info", {}).get(dst_dataset_id, None)
|
|
2308
|
+
if dataset_info is None:
|
|
2309
|
+
dataset_info = api.dataset.get_info_by_id(dst_dataset_id)
|
|
2310
|
+
context["dataset_info"][dst_dataset_id] = dataset_info
|
|
2311
|
+
project_id = dataset_info.project_id
|
|
2312
|
+
project_meta = context.setdefault("project_meta", {}).get(project_id, None)
|
|
2313
|
+
if project_meta is None:
|
|
2314
|
+
project_meta = ProjectMeta.from_json(api.project.get_meta(project_id))
|
|
2315
|
+
context["project_meta"][project_id] = project_meta
|
|
2316
|
+
|
|
2317
|
+
meta_changed = False
|
|
2318
|
+
for pred in preds:
|
|
2319
|
+
ann = pred.annotation
|
|
2320
|
+
project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
|
|
2321
|
+
meta_changed = meta_changed or meta_changed_
|
|
2322
|
+
pred.annotation = ann
|
|
2323
|
+
prediction.model_meta = project_meta
|
|
2324
|
+
|
|
2325
|
+
if meta_changed:
|
|
2326
|
+
project_meta = api.project.update_meta(project_id, project_meta)
|
|
2327
|
+
context["project_meta"][project_id] = project_meta
|
|
2328
|
+
|
|
2329
|
+
anns = _exclude_duplicated_predictions(
|
|
2330
|
+
api,
|
|
2331
|
+
[pred.annotation for pred in preds],
|
|
2332
|
+
dataset_id,
|
|
2333
|
+
[pred.image_id for pred in preds],
|
|
2334
|
+
iou=iou_merge_threshold,
|
|
2335
|
+
meta=project_meta,
|
|
2336
|
+
)
|
|
2337
|
+
for pred, ann in zip(preds, anns):
|
|
2338
|
+
pred.annotation = ann
|
|
2339
|
+
|
|
2340
|
+
context.setdefault("image_info", {})
|
|
2341
|
+
missing = [
|
|
2342
|
+
pred.image_id for pred in preds if pred.image_id not in context["image_info"]
|
|
2343
|
+
]
|
|
2344
|
+
if missing:
|
|
2345
|
+
context["image_info"].update(
|
|
2346
|
+
{
|
|
2347
|
+
image_info.id: image_info
|
|
2348
|
+
for image_info in api.image.get_info_by_id_batch(missing)
|
|
2349
|
+
}
|
|
2350
|
+
)
|
|
2351
|
+
image_infos: List[ImageInfo] = [
|
|
2352
|
+
context["image_info"][pred.image_id] for pred in preds
|
|
2353
|
+
]
|
|
2354
|
+
dst_names = [
|
|
2355
|
+
_new_name(image_info) if image_info.name in created_names else image_info.name
|
|
2356
|
+
for image_info in image_infos
|
|
2357
|
+
]
|
|
2358
|
+
dst_image_infos = api.image.copy_batch_optimized(
|
|
2359
|
+
dataset_id,
|
|
2360
|
+
image_infos,
|
|
2361
|
+
dst_dataset_id,
|
|
2362
|
+
dst_names=dst_names,
|
|
2363
|
+
with_annotations=False,
|
|
2364
|
+
save_source_date=False,
|
|
2365
|
+
)
|
|
2366
|
+
created_names.extend([image_info.name for image_info in dst_image_infos])
|
|
2367
|
+
api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
|
|
2368
|
+
else:
|
|
2369
|
+
# upload to the source dataset
|
|
2370
|
+
ds_info = context.setdefault("dataset_info", {}).get(dataset_id, None)
|
|
2371
|
+
if ds_info is None:
|
|
2372
|
+
ds_info = api.dataset.get_info_by_id(dataset_id)
|
|
2373
|
+
context["dataset_info"][dataset_id] = ds_info
|
|
2374
|
+
project_id = ds_info.project_id
|
|
2375
|
+
|
|
2376
|
+
project_meta = context.setdefault("project_meta", {}).get(project_id, None)
|
|
2377
|
+
if project_meta is None:
|
|
2378
|
+
project_meta = ProjectMeta.from_json(api.project.get_meta(project_id))
|
|
2379
|
+
context["project_meta"][project_id] = project_meta
|
|
2380
|
+
|
|
2381
|
+
meta_changed = False
|
|
2382
|
+
for pred in preds:
|
|
2383
|
+
ann = pred.annotation
|
|
2384
|
+
project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
|
|
2385
|
+
meta_changed = meta_changed or meta_changed_
|
|
2386
|
+
pred.annotation = ann
|
|
2387
|
+
prediction.model_meta = project_meta
|
|
2388
|
+
|
|
2389
|
+
if meta_changed:
|
|
2390
|
+
project_meta = api.project.update_meta(project_id, project_meta)
|
|
2391
|
+
context["project_meta"][project_id] = project_meta
|
|
2392
|
+
|
|
2393
|
+
anns = _exclude_duplicated_predictions(
|
|
2394
|
+
api,
|
|
2395
|
+
[pred.annotation for pred in preds],
|
|
2396
|
+
dataset_id,
|
|
2397
|
+
[pred.image_id for pred in preds],
|
|
2398
|
+
iou=iou_merge_threshold,
|
|
2399
|
+
meta=project_meta,
|
|
2400
|
+
)
|
|
2401
|
+
for pred, ann in zip(preds, anns):
|
|
2402
|
+
pred.annotation = ann
|
|
2403
|
+
|
|
2404
|
+
if upload_mode in ["iou_merge", "append"]:
|
|
2405
|
+
context.setdefault("annotation", {})
|
|
2406
|
+
missing = []
|
|
2407
|
+
for pred in preds:
|
|
2408
|
+
if pred.image_id not in context["annotation"]:
|
|
2409
|
+
missing.append(pred.image_id)
|
|
2410
|
+
for image_id, ann_info in zip(
|
|
2411
|
+
missing, api.annotation.download_batch(dataset_id, missing)
|
|
2412
|
+
):
|
|
2413
|
+
context["annotation"][image_id] = Annotation.from_json(
|
|
2414
|
+
ann_info.annotation, project_meta
|
|
2415
|
+
)
|
|
2416
|
+
for pred in preds:
|
|
2417
|
+
pred.annotation = context["annotation"][pred.image_id].merge(
|
|
2418
|
+
pred.annotation
|
|
2419
|
+
)
|
|
2420
|
+
|
|
2421
|
+
api.annotation.upload_anns(
|
|
2422
|
+
[pred.image_id for pred in preds],
|
|
2423
|
+
[pred.annotation for pred in preds],
|
|
2424
|
+
)
|
|
2425
|
+
|
|
2426
|
+
if progress_cb is not None:
|
|
2427
|
+
progress_cb(len(preds))
|
|
2428
|
+
|
|
2429
|
+
if inference_request is not None:
|
|
2430
|
+
results = self._format_output(predictions)
|
|
2431
|
+
for result in results:
|
|
2432
|
+
result["annotation"] = None
|
|
2433
|
+
result["data"] = None
|
|
2434
|
+
inference_request.add_results(results)
|
|
2435
|
+
|
|
2436
|
+
def add_results_to_request(
|
|
2437
|
+
self, predictions: List[Prediction], inference_request: InferenceRequest
|
|
2438
|
+
):
|
|
2439
|
+
results = self._format_output(predictions)
|
|
2440
|
+
inference_request.add_results(results)
|
|
2441
|
+
inference_request.done(len(results))
|
|
2442
|
+
|
|
2526
2443
|
def serve(self):
|
|
2527
2444
|
if not self._use_gui and not self._is_local_deploy:
|
|
2528
2445
|
Progress("Deploying model ...", 1)
|
|
@@ -2583,28 +2500,46 @@ class Inference:
|
|
|
2583
2500
|
server = self._app.get_server()
|
|
2584
2501
|
self._app.set_ready_check_function(self.is_model_deployed)
|
|
2585
2502
|
|
|
2586
|
-
|
|
2587
|
-
|
|
2588
|
-
|
|
2589
|
-
|
|
2590
|
-
|
|
2591
|
-
|
|
2592
|
-
|
|
2593
|
-
|
|
2594
|
-
|
|
2595
|
-
|
|
2596
|
-
self.
|
|
2597
|
-
|
|
2598
|
-
|
|
2599
|
-
|
|
2600
|
-
|
|
2601
|
-
|
|
2503
|
+
if self.api is not None:
|
|
2504
|
+
|
|
2505
|
+
@call_on_autostart()
|
|
2506
|
+
def autostart_func():
|
|
2507
|
+
gpu_count = get_gpu_count()
|
|
2508
|
+
if gpu_count > 1:
|
|
2509
|
+
# run autostart after 5 min
|
|
2510
|
+
def delayed_autostart():
|
|
2511
|
+
logger.debug("Found more than one GPU, autostart will be delayed.")
|
|
2512
|
+
time.sleep(self._autostart_delay_time)
|
|
2513
|
+
if not self._model_served:
|
|
2514
|
+
logger.debug("Deploying the model via autostart...")
|
|
2515
|
+
self.gui.deploy_with_current_params()
|
|
2516
|
+
|
|
2517
|
+
self._executor.submit(delayed_autostart)
|
|
2518
|
+
else:
|
|
2519
|
+
# run autostart immediately
|
|
2520
|
+
self.gui.deploy_with_current_params()
|
|
2602
2521
|
|
|
2603
2522
|
if not self._use_gui:
|
|
2604
2523
|
Progress("Model deployed", 1).iter_done_report()
|
|
2605
2524
|
else:
|
|
2606
2525
|
autostart_func()
|
|
2607
2526
|
|
|
2527
|
+
@server.exception_handler(HTTPException)
|
|
2528
|
+
def http_exception_handler(request: Request, exc: HTTPException):
|
|
2529
|
+
response_content = {
|
|
2530
|
+
"detail": exc.detail,
|
|
2531
|
+
"success": False,
|
|
2532
|
+
}
|
|
2533
|
+
if isinstance(exc.detail, dict):
|
|
2534
|
+
if "message" in exc.detail:
|
|
2535
|
+
response_content["message"] = exc.detail["message"]
|
|
2536
|
+
if "success" in exc.detail:
|
|
2537
|
+
response_content["success"] = exc.detail["success"]
|
|
2538
|
+
elif isinstance(exc.detail, str):
|
|
2539
|
+
response_content["message"] = exc.detail
|
|
2540
|
+
|
|
2541
|
+
return JSONResponse(status_code=exc.status_code, content=response_content)
|
|
2542
|
+
|
|
2608
2543
|
self.cache.add_cache_endpoint(server)
|
|
2609
2544
|
self.cache.add_cache_files_endpoint(server)
|
|
2610
2545
|
|
|
@@ -2617,311 +2552,353 @@ class Inference:
|
|
|
2617
2552
|
def get_custom_inference_settings():
|
|
2618
2553
|
return {"settings": self.custom_inference_settings}
|
|
2619
2554
|
|
|
2555
|
+
@server.post("/get_model_meta")
|
|
2620
2556
|
@server.post("/get_output_classes_and_tags")
|
|
2621
2557
|
def get_output_classes_and_tags():
|
|
2622
2558
|
return self.model_meta.to_json()
|
|
2623
2559
|
|
|
2624
2560
|
@server.post("/inference_image_id")
|
|
2625
2561
|
def inference_image_id(request: Request):
|
|
2626
|
-
|
|
2627
|
-
|
|
2562
|
+
state = request.state.state
|
|
2563
|
+
logger.debug("Received a request to '/inference_image_id'", extra={"state": state})
|
|
2564
|
+
self.validate_inference_state(state)
|
|
2565
|
+
api = self.api_from_request(request)
|
|
2566
|
+
return self.inference_requests_manager.run(self._inference_image_ids, api, state)[0]
|
|
2567
|
+
|
|
2568
|
+
@server.post("/inference_image_id_async")
|
|
2569
|
+
def inference_image_id_async(request: Request):
|
|
2570
|
+
state = request.state.state
|
|
2571
|
+
logger.debug(
|
|
2572
|
+
"Received a request to 'inference_image_id_async'",
|
|
2573
|
+
extra={"state": state},
|
|
2574
|
+
)
|
|
2575
|
+
self.validate_inference_state(state)
|
|
2576
|
+
api = self.api_from_request(request)
|
|
2577
|
+
inference_request, _ = self.inference_requests_manager.schedule_task(
|
|
2578
|
+
self._inference_image_ids,
|
|
2579
|
+
api,
|
|
2580
|
+
state,
|
|
2581
|
+
)
|
|
2582
|
+
return {
|
|
2583
|
+
"message": "Scheduled inference task.",
|
|
2584
|
+
"inference_request_uuid": inference_request.uuid,
|
|
2585
|
+
}
|
|
2586
|
+
|
|
2587
|
+
@server.post("/inference_image")
|
|
2588
|
+
def inference_image(
|
|
2589
|
+
files: List[UploadFile], settings: str = Form("{}"), state: str = Form("{}")
|
|
2590
|
+
):
|
|
2591
|
+
if state == "{}" or not state:
|
|
2592
|
+
state = settings
|
|
2593
|
+
state = str(state)
|
|
2594
|
+
logger.debug("Received a request to 'inference_image'", extra={"state": state})
|
|
2595
|
+
self.validate_inference_state(state)
|
|
2596
|
+
state = json.loads(state)
|
|
2597
|
+
if len(files) != 1:
|
|
2598
|
+
raise HTTPException(
|
|
2599
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
2600
|
+
detail=f"Only one file expected but got {len(files)}",
|
|
2601
|
+
)
|
|
2602
|
+
try:
|
|
2603
|
+
file = files[0]
|
|
2604
|
+
inference_request = self.inference_requests_manager.create()
|
|
2605
|
+
inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, file.size)
|
|
2606
|
+
|
|
2607
|
+
img_bytes = b""
|
|
2608
|
+
while buf := file.read(64 * 1024 * 1024):
|
|
2609
|
+
img_bytes += buf
|
|
2610
|
+
inference_request.done(len(buf))
|
|
2611
|
+
|
|
2612
|
+
image = sly_image.read_bytes(img_bytes)
|
|
2613
|
+
inference_request, future = self.inference_requests_manager.schedule_task(
|
|
2614
|
+
self._inference_images, [image], state, inference_request=inference_request
|
|
2615
|
+
)
|
|
2616
|
+
future.result()
|
|
2617
|
+
return inference_request.pop_pending_results()[0]
|
|
2618
|
+
except sly_image.UnsupportedImageFormat:
|
|
2619
|
+
raise HTTPException(
|
|
2620
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
2621
|
+
detail=f"File has unsupported format. Supported formats: {sly_image.SUPPORTED_IMG_EXTS}",
|
|
2622
|
+
)
|
|
2628
2623
|
|
|
2629
2624
|
@server.post("/inference_image_url")
|
|
2630
2625
|
def inference_image_url(request: Request):
|
|
2631
|
-
|
|
2632
|
-
|
|
2626
|
+
state = request.state.state
|
|
2627
|
+
logger.debug("Received a request to 'inference_image_url'", extra={"state": state})
|
|
2628
|
+
self.validate_inference_state(state)
|
|
2629
|
+
image_url = state["image_url"]
|
|
2630
|
+
ext = sly_fs.get_file_ext(image_url)
|
|
2631
|
+
if ext == "":
|
|
2632
|
+
ext = ".jpg"
|
|
2633
|
+
with requests.get(image_url, stream=True) as response:
|
|
2634
|
+
response.raise_for_status()
|
|
2635
|
+
response.raw.decode_content = True
|
|
2636
|
+
image = self.cache.add_image_to_cache(image_url, response.raw, ext=ext)
|
|
2637
|
+
return self.inference_requests_manager.run(self._inference_images, [image], state)[0]
|
|
2633
2638
|
|
|
2634
2639
|
@server.post("/inference_batch_ids")
|
|
2635
|
-
def inference_batch_ids(
|
|
2636
|
-
|
|
2637
|
-
|
|
2638
|
-
|
|
2639
|
-
|
|
2640
|
-
|
|
2641
|
-
"message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
|
|
2642
|
-
"success": False,
|
|
2643
|
-
}
|
|
2644
|
-
logger.debug(f"'inference_batch_ids' request in json format:{request.state.state}")
|
|
2645
|
-
return self._inference_batch_ids(request.state.api, request.state.state)
|
|
2640
|
+
def inference_batch_ids(request: Request):
|
|
2641
|
+
state = request.state.state
|
|
2642
|
+
logger.debug("Received a request to 'inference_batch_ids'", extra={"state": state})
|
|
2643
|
+
self.validate_inference_state(state)
|
|
2644
|
+
api = self.api_from_request(request)
|
|
2645
|
+
return self.inference_requests_manager.run(self._inference_image_ids, api, state)
|
|
2646
2646
|
|
|
2647
2647
|
@server.post("/inference_batch_ids_async")
|
|
2648
|
-
def inference_batch_ids_async(
|
|
2648
|
+
def inference_batch_ids_async(request: Request):
|
|
2649
|
+
state = request.state.state
|
|
2649
2650
|
logger.debug(
|
|
2650
|
-
f"
|
|
2651
|
+
f"Received a request to 'inference_batch_ids_async'", extra={"state": state}
|
|
2651
2652
|
)
|
|
2652
|
-
|
|
2653
|
-
|
|
2654
|
-
|
|
2655
|
-
|
|
2656
|
-
batch_size = self.get_batch_size()
|
|
2657
|
-
if self.max_batch_size is not None and batch_size > self.max_batch_size:
|
|
2658
|
-
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2659
|
-
return {
|
|
2660
|
-
"message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
|
|
2661
|
-
"success": False,
|
|
2662
|
-
}
|
|
2663
|
-
inference_request_uuid = uuid.uuid5(
|
|
2664
|
-
namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
|
|
2665
|
-
).hex
|
|
2666
|
-
self._on_inference_start(inference_request_uuid)
|
|
2667
|
-
future = self._executor.submit(
|
|
2668
|
-
self._handle_error_in_async,
|
|
2669
|
-
inference_request_uuid,
|
|
2670
|
-
self._inference_images_ids,
|
|
2671
|
-
request.state.api,
|
|
2672
|
-
request.state.state,
|
|
2673
|
-
images_ids,
|
|
2674
|
-
inference_request_uuid,
|
|
2675
|
-
)
|
|
2676
|
-
end_callback = partial(
|
|
2677
|
-
self._on_inference_end, inference_request_uuid=inference_request_uuid
|
|
2678
|
-
)
|
|
2679
|
-
future.add_done_callback(end_callback)
|
|
2680
|
-
logger.debug(
|
|
2681
|
-
"Inference has scheduled from 'inference_batch_ids_async' endpoint",
|
|
2682
|
-
extra={"inference_request_uuid": inference_request_uuid},
|
|
2653
|
+
self.validate_inference_state(state)
|
|
2654
|
+
api = self.api_from_request(request)
|
|
2655
|
+
inference_request, _ = self.inference_requests_manager.schedule_task(
|
|
2656
|
+
self._inference_image_ids, api, state
|
|
2683
2657
|
)
|
|
2684
2658
|
return {
|
|
2685
|
-
"message": "
|
|
2686
|
-
"inference_request_uuid":
|
|
2659
|
+
"message": "Scheduled inference task.",
|
|
2660
|
+
"inference_request_uuid": inference_request.uuid,
|
|
2687
2661
|
}
|
|
2688
2662
|
|
|
2689
|
-
@server.post("/
|
|
2690
|
-
def
|
|
2691
|
-
|
|
2692
|
-
|
|
2693
|
-
|
|
2694
|
-
|
|
2695
|
-
batch_size = self.get_batch_size()
|
|
2696
|
-
if self.max_batch_size is not None and batch_size > self.max_batch_size:
|
|
2697
|
-
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2698
|
-
return {
|
|
2699
|
-
"message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
|
|
2700
|
-
"success": False,
|
|
2701
|
-
}
|
|
2702
|
-
return self._inference_video_id(request.state.api, request.state.state)
|
|
2703
|
-
|
|
2704
|
-
@server.post("/inference_image")
|
|
2705
|
-
def inference_image(
|
|
2706
|
-
response: Response, files: List[UploadFile], settings: str = Form("{}")
|
|
2663
|
+
@server.post("/inference_batch")
|
|
2664
|
+
def inference_batch(
|
|
2665
|
+
response: Response,
|
|
2666
|
+
files: List[UploadFile],
|
|
2667
|
+
settings: str = Form("{}"),
|
|
2668
|
+
state: str = Form("{}"),
|
|
2707
2669
|
):
|
|
2708
|
-
if
|
|
2709
|
-
|
|
2710
|
-
|
|
2670
|
+
if state == "{}" or not state:
|
|
2671
|
+
state = settings
|
|
2672
|
+
state = str(state)
|
|
2673
|
+
logger.debug("Received a request to 'inference_batch'", extra={"state": state})
|
|
2674
|
+
self.validate_inference_state(state)
|
|
2675
|
+
state = json.loads(state)
|
|
2676
|
+
if len(files) == 0:
|
|
2677
|
+
raise HTTPException(
|
|
2678
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
2679
|
+
detail=f"At least one file is expected but got {len(files)}",
|
|
2680
|
+
)
|
|
2711
2681
|
try:
|
|
2712
|
-
|
|
2713
|
-
|
|
2714
|
-
|
|
2715
|
-
|
|
2716
|
-
|
|
2717
|
-
|
|
2718
|
-
|
|
2719
|
-
|
|
2682
|
+
inference_request = self.inference_requests_manager.create()
|
|
2683
|
+
inference_request.set_stage(
|
|
2684
|
+
InferenceRequest.Stage.PREPARING, 0, sum([file.size for file in files])
|
|
2685
|
+
)
|
|
2686
|
+
|
|
2687
|
+
names = []
|
|
2688
|
+
for file in files:
|
|
2689
|
+
name = file.filename
|
|
2690
|
+
if name is None or name == "":
|
|
2691
|
+
name = rand_str(10)
|
|
2692
|
+
ext = Path(name).suffix
|
|
2693
|
+
img_bytes = b""
|
|
2694
|
+
while buf := file.file.read(64 * 1024 * 1024):
|
|
2695
|
+
img_bytes += buf
|
|
2696
|
+
inference_request.done(len(buf))
|
|
2697
|
+
self.cache.add_image_to_cache(name, img_bytes, ext=ext)
|
|
2698
|
+
names.append(name)
|
|
2699
|
+
|
|
2700
|
+
inference_request, future = self.inference_requests_manager.schedule_task(
|
|
2701
|
+
self._inference_images, names, state, inference_request=inference_request
|
|
2702
|
+
)
|
|
2703
|
+
future.result()
|
|
2704
|
+
return inference_request.pop_pending_results()
|
|
2720
2705
|
except sly_image.UnsupportedImageFormat:
|
|
2721
2706
|
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2722
2707
|
return f"File has unsupported format. Supported formats: {sly_image.SUPPORTED_IMG_EXTS}"
|
|
2723
2708
|
|
|
2724
|
-
@server.post("/
|
|
2725
|
-
def
|
|
2726
|
-
response: Response,
|
|
2709
|
+
@server.post("/inference_batch_async")
|
|
2710
|
+
def inference_batch_async(
|
|
2711
|
+
response: Response,
|
|
2712
|
+
files: List[UploadFile],
|
|
2713
|
+
settings: str = Form("{}"),
|
|
2714
|
+
state: str = Form("{}"),
|
|
2727
2715
|
):
|
|
2716
|
+
if state == "{}" or not state:
|
|
2717
|
+
state = settings
|
|
2718
|
+
state = str(state)
|
|
2719
|
+
logger.debug("Received a request to 'inference_batch'", extra={"state": state})
|
|
2720
|
+
self.validate_inference_state(state)
|
|
2721
|
+
state = json.loads(state)
|
|
2722
|
+
if len(files) == 0:
|
|
2723
|
+
raise HTTPException(
|
|
2724
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
2725
|
+
detail=f"At least one file is expected but got {len(files)}",
|
|
2726
|
+
)
|
|
2728
2727
|
try:
|
|
2729
|
-
|
|
2730
|
-
|
|
2731
|
-
|
|
2732
|
-
|
|
2733
|
-
|
|
2734
|
-
|
|
2735
|
-
|
|
2736
|
-
|
|
2737
|
-
|
|
2738
|
-
|
|
2739
|
-
|
|
2740
|
-
|
|
2741
|
-
|
|
2742
|
-
|
|
2743
|
-
|
|
2744
|
-
|
|
2728
|
+
inference_request = self.inference_requests_manager.create()
|
|
2729
|
+
inference_request.set_stage(
|
|
2730
|
+
InferenceRequest.Stage.PREPARING, 0, sum([file.size for file in files])
|
|
2731
|
+
)
|
|
2732
|
+
|
|
2733
|
+
names = []
|
|
2734
|
+
for file in files:
|
|
2735
|
+
name = file.filename
|
|
2736
|
+
if name is None or name == "":
|
|
2737
|
+
name = rand_str(10)
|
|
2738
|
+
ext = Path(name).suffix
|
|
2739
|
+
img_bytes = b""
|
|
2740
|
+
while buf := file.file.read(64 * 1024 * 1024):
|
|
2741
|
+
img_bytes += buf
|
|
2742
|
+
inference_request.done(len(buf))
|
|
2743
|
+
self.cache.add_image_to_cache(name, img_bytes, ext=ext)
|
|
2744
|
+
names.append(name)
|
|
2745
|
+
|
|
2746
|
+
inference_request, _ = self.inference_requests_manager.schedule_task(
|
|
2747
|
+
self._inference_images, names, state, inference_request=inference_request
|
|
2748
|
+
)
|
|
2749
|
+
return {
|
|
2750
|
+
"message": "Scheduled inference task.",
|
|
2751
|
+
"inference_request_uuid": inference_request.uuid,
|
|
2752
|
+
}
|
|
2745
2753
|
except sly_image.UnsupportedImageFormat:
|
|
2746
2754
|
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2747
2755
|
return f"File has unsupported format. Supported formats: {sly_image.SUPPORTED_IMG_EXTS}"
|
|
2748
2756
|
|
|
2749
|
-
@server.post("/
|
|
2750
|
-
def
|
|
2751
|
-
|
|
2752
|
-
|
|
2753
|
-
|
|
2754
|
-
)
|
|
2755
|
-
self.
|
|
2756
|
-
|
|
2757
|
-
self._handle_error_in_async,
|
|
2758
|
-
inference_request_uuid,
|
|
2759
|
-
self._inference_image_id,
|
|
2760
|
-
request.state.api,
|
|
2761
|
-
request.state.state,
|
|
2762
|
-
inference_request_uuid,
|
|
2757
|
+
@server.post("/inference_video_id")
|
|
2758
|
+
def inference_video_id(request: Request):
|
|
2759
|
+
state = request.state.state
|
|
2760
|
+
logger.debug(f"Received a request to 'inference_video_id'", extra={"state": state})
|
|
2761
|
+
self.validate_inference_state(state)
|
|
2762
|
+
api = self.api_from_request(request)
|
|
2763
|
+
inference_request, future = self.inference_requests_manager.schedule_task(
|
|
2764
|
+
self._inference_video_id, api, state
|
|
2763
2765
|
)
|
|
2764
|
-
|
|
2765
|
-
|
|
2766
|
+
future.result()
|
|
2767
|
+
results = {"ann": inference_request.pop_pending_results()}
|
|
2768
|
+
final_result = inference_request.final_result
|
|
2769
|
+
if final_result is not None:
|
|
2770
|
+
results.update(final_result)
|
|
2771
|
+
return results
|
|
2772
|
+
|
|
2773
|
+
@server.post("/inference_video_async")
|
|
2774
|
+
def inference_video_async(
|
|
2775
|
+
files: List[UploadFile],
|
|
2776
|
+
settings: str = Form("{}"),
|
|
2777
|
+
state: str = Form("{}"),
|
|
2778
|
+
):
|
|
2779
|
+
if state == "{}" or not state:
|
|
2780
|
+
state = settings
|
|
2781
|
+
state = str(state)
|
|
2782
|
+
logger.debug("Received a request to 'inference_video_async'", extra={"state": state})
|
|
2783
|
+
self.validate_inference_state(state)
|
|
2784
|
+
state = json.loads(state)
|
|
2785
|
+
|
|
2786
|
+
file = files[0]
|
|
2787
|
+
video_name = files[0].filename
|
|
2788
|
+
video_source = files[0].file
|
|
2789
|
+
file_size = file.size
|
|
2790
|
+
|
|
2791
|
+
inference_request = self.inference_requests_manager.create()
|
|
2792
|
+
inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, file_size)
|
|
2793
|
+
|
|
2794
|
+
video_source.read = progress_wrapper(
|
|
2795
|
+
video_source.read, inference_request.progress.iters_done_report
|
|
2766
2796
|
)
|
|
2767
|
-
|
|
2768
|
-
|
|
2769
|
-
|
|
2770
|
-
|
|
2797
|
+
|
|
2798
|
+
if self.cache.is_persistent:
|
|
2799
|
+
self.cache.add_video_to_cache(video_name, video_source)
|
|
2800
|
+
video_path = self.cache.get_video_path(video_name)
|
|
2801
|
+
else:
|
|
2802
|
+
video_path = os.path.join(tempfile.gettempdir(), video_name)
|
|
2803
|
+
with open(video_path, "wb") as video_file:
|
|
2804
|
+
shutil.copyfileobj(
|
|
2805
|
+
video_source, open(video_path, "wb"), length=(64 * 1024 * 1024)
|
|
2806
|
+
)
|
|
2807
|
+
|
|
2808
|
+
inference_request, _ = self.inference_requests_manager.schedule_task(
|
|
2809
|
+
self._inference_video,
|
|
2810
|
+
path=video_path,
|
|
2811
|
+
state=state,
|
|
2812
|
+
inference_request=inference_request,
|
|
2771
2813
|
)
|
|
2814
|
+
|
|
2772
2815
|
return {
|
|
2773
|
-
"message": "
|
|
2774
|
-
"inference_request_uuid":
|
|
2816
|
+
"message": "Scheduled inference task.",
|
|
2817
|
+
"inference_request_uuid": inference_request.uuid,
|
|
2775
2818
|
}
|
|
2776
2819
|
|
|
2777
2820
|
@server.post("/inference_video_id_async")
|
|
2778
2821
|
def inference_video_id_async(response: Response, request: Request):
|
|
2779
|
-
|
|
2780
|
-
|
|
2781
|
-
|
|
2782
|
-
|
|
2783
|
-
|
|
2784
|
-
|
|
2785
|
-
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2786
|
-
return {
|
|
2787
|
-
"message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
|
|
2788
|
-
"success": False,
|
|
2789
|
-
}
|
|
2790
|
-
inference_request_uuid = uuid.uuid5(
|
|
2791
|
-
namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
|
|
2792
|
-
).hex
|
|
2793
|
-
self._on_inference_start(inference_request_uuid)
|
|
2794
|
-
future = self._executor.submit(
|
|
2795
|
-
self._handle_error_in_async,
|
|
2796
|
-
inference_request_uuid,
|
|
2797
|
-
self._inference_video_id,
|
|
2798
|
-
request.state.api,
|
|
2799
|
-
request.state.state,
|
|
2800
|
-
inference_request_uuid,
|
|
2801
|
-
)
|
|
2802
|
-
end_callback = partial(
|
|
2803
|
-
self._on_inference_end, inference_request_uuid=inference_request_uuid
|
|
2804
|
-
)
|
|
2805
|
-
future.add_done_callback(end_callback)
|
|
2806
|
-
logger.debug(
|
|
2807
|
-
"Inference has scheduled from 'inference_video_id_async' endpoint",
|
|
2808
|
-
extra={"inference_request_uuid": inference_request_uuid},
|
|
2822
|
+
state = request.state.state
|
|
2823
|
+
logger.debug("Received a request to 'inference_video_id_async'", extra={"state": state})
|
|
2824
|
+
self.validate_inference_state(state)
|
|
2825
|
+
api = self.api_from_request(request)
|
|
2826
|
+
inference_request, _ = self.inference_requests_manager.schedule_task(
|
|
2827
|
+
self._inference_video_id, api, state
|
|
2809
2828
|
)
|
|
2810
2829
|
return {
|
|
2811
2830
|
"message": "Inference has started.",
|
|
2812
|
-
"inference_request_uuid":
|
|
2831
|
+
"inference_request_uuid": inference_request.uuid,
|
|
2813
2832
|
}
|
|
2814
2833
|
|
|
2815
2834
|
@server.post("/inference_project_id_async")
|
|
2816
2835
|
def inference_project_id_async(response: Response, request: Request):
|
|
2836
|
+
state = request.state.state
|
|
2817
2837
|
logger.debug(
|
|
2818
|
-
|
|
2838
|
+
"Received a request to 'inference_project_id_async'", extra={"state": state}
|
|
2819
2839
|
)
|
|
2820
|
-
|
|
2821
|
-
|
|
2822
|
-
|
|
2823
|
-
|
|
2824
|
-
# check batch size
|
|
2825
|
-
batch_size = request.state.state.get("batch_size", None)
|
|
2826
|
-
if batch_size is None:
|
|
2827
|
-
batch_size = self.get_batch_size()
|
|
2828
|
-
if self.max_batch_size is not None and batch_size > self.max_batch_size:
|
|
2829
|
-
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2830
|
-
return {
|
|
2831
|
-
"message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
|
|
2832
|
-
"success": False,
|
|
2833
|
-
}
|
|
2834
|
-
inference_request_uuid = uuid.uuid5(
|
|
2835
|
-
namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
|
|
2836
|
-
).hex
|
|
2837
|
-
self._on_inference_start(inference_request_uuid)
|
|
2838
|
-
future = self._executor.submit(
|
|
2839
|
-
self._handle_error_in_async,
|
|
2840
|
-
inference_request_uuid,
|
|
2841
|
-
self._inference_project_id,
|
|
2842
|
-
request.state.api,
|
|
2843
|
-
request.state.state,
|
|
2844
|
-
project_info,
|
|
2845
|
-
inference_request_uuid,
|
|
2846
|
-
)
|
|
2847
|
-
logger.debug(
|
|
2848
|
-
"Inference has scheduled from 'inference_project_id_async' endpoint",
|
|
2849
|
-
extra={"inference_request_uuid": inference_request_uuid},
|
|
2840
|
+
self.validate_inference_state(state)
|
|
2841
|
+
api = self.api_from_request(request)
|
|
2842
|
+
inference_request, _ = self.inference_requests_manager.schedule_task(
|
|
2843
|
+
self._inference_project_id, api, state
|
|
2850
2844
|
)
|
|
2851
2845
|
return {
|
|
2852
2846
|
"message": "Inference has started.",
|
|
2853
|
-
"inference_request_uuid":
|
|
2847
|
+
"inference_request_uuid": inference_request.uuid,
|
|
2854
2848
|
}
|
|
2855
2849
|
|
|
2856
2850
|
@server.post("/run_speedtest")
|
|
2857
2851
|
def run_speedtest(response: Response, request: Request):
|
|
2858
|
-
|
|
2859
|
-
|
|
2860
|
-
|
|
2861
|
-
|
|
2862
|
-
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2863
|
-
response.body = {"message": "Only images projects are supported."}
|
|
2864
|
-
raise ValueError("Only images projects are supported.")
|
|
2865
|
-
batch_size = request.state.state["batch_size"]
|
|
2852
|
+
state = request.state.state
|
|
2853
|
+
logger.debug(f"'run_speedtest' request in json format:{state}")
|
|
2854
|
+
|
|
2855
|
+
batch_size = state["batch_size"]
|
|
2866
2856
|
if batch_size > 1 and not self.is_batch_inference_supported():
|
|
2867
2857
|
response.status_code = status.HTTP_501_NOT_IMPLEMENTED
|
|
2868
2858
|
return {
|
|
2869
2859
|
"message": "Batch inference is not implemented for this model.",
|
|
2870
2860
|
"success": False,
|
|
2871
2861
|
}
|
|
2872
|
-
|
|
2873
|
-
|
|
2862
|
+
|
|
2863
|
+
self.validate_inference_state(state)
|
|
2864
|
+
api = self.api_from_request(request)
|
|
2865
|
+
|
|
2866
|
+
project_id = state["projectId"]
|
|
2867
|
+
project_info = api.project.get_info_by_id(project_id)
|
|
2868
|
+
if project_info.type != str(ProjectType.IMAGES):
|
|
2874
2869
|
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2875
|
-
|
|
2876
|
-
|
|
2877
|
-
|
|
2878
|
-
|
|
2879
|
-
|
|
2880
|
-
namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
|
|
2881
|
-
).hex
|
|
2882
|
-
self._on_inference_start(inference_request_uuid)
|
|
2883
|
-
future = self._executor.submit(
|
|
2884
|
-
self._handle_error_in_async,
|
|
2885
|
-
inference_request_uuid,
|
|
2886
|
-
self._run_speedtest,
|
|
2887
|
-
request.state.api,
|
|
2888
|
-
request.state.state,
|
|
2889
|
-
inference_request_uuid,
|
|
2890
|
-
)
|
|
2891
|
-
logger.debug(
|
|
2892
|
-
"Speedtest has scheduled from 'run_speedtest' endpoint",
|
|
2893
|
-
extra={"inference_request_uuid": inference_request_uuid},
|
|
2870
|
+
response.body = {"message": "Only images projects are supported."}
|
|
2871
|
+
raise ValueError("Only images projects are supported.")
|
|
2872
|
+
|
|
2873
|
+
inference_request, _ = self.inference_requests_manager.schedule_task(
|
|
2874
|
+
self._run_speedtest, api, state
|
|
2894
2875
|
)
|
|
2895
2876
|
return {
|
|
2896
2877
|
"message": "Inference has started.",
|
|
2897
|
-
"inference_request_uuid":
|
|
2878
|
+
"inference_request_uuid": inference_request.uuid,
|
|
2898
2879
|
}
|
|
2899
2880
|
|
|
2900
2881
|
@server.post(f"/get_inference_progress")
|
|
2901
2882
|
def get_inference_progress(response: Response, request: Request):
|
|
2902
|
-
|
|
2883
|
+
state = request.state.state
|
|
2884
|
+
logger.debug("Received a request to '/get_inference_progress'", extra={"state": state})
|
|
2885
|
+
inference_request_uuid = state.get("inference_request_uuid")
|
|
2903
2886
|
if inference_request_uuid is None:
|
|
2904
2887
|
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2905
2888
|
return {"message": "Error: 'inference_request_uuid' is required."}
|
|
2906
2889
|
|
|
2907
|
-
inference_request = self.
|
|
2908
|
-
inference_request["progress"] = _convert_sly_progress_to_dict(
|
|
2909
|
-
inference_request["progress"]
|
|
2910
|
-
)
|
|
2911
|
-
|
|
2912
|
-
# Logging
|
|
2890
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
2913
2891
|
log_extra = _get_log_extra_for_inference_request(
|
|
2914
|
-
|
|
2892
|
+
inference_request.uuid, inference_request
|
|
2915
2893
|
)
|
|
2894
|
+
data = {**inference_request.to_json(), **log_extra}
|
|
2895
|
+
if inference_request.stage != InferenceRequest.Stage.INFERENCE:
|
|
2896
|
+
data["progress"] = {"current": 0, "total": 1}
|
|
2916
2897
|
logger.debug(
|
|
2917
2898
|
f"Sending inference progress with uuid:",
|
|
2918
|
-
extra=
|
|
2899
|
+
extra=data,
|
|
2919
2900
|
)
|
|
2920
|
-
|
|
2921
|
-
# Ger rid of `pending_results` to less response size
|
|
2922
|
-
inference_request["pending_results"] = []
|
|
2923
|
-
inference_request.pop("lock", None)
|
|
2924
|
-
return inference_request
|
|
2901
|
+
return data
|
|
2925
2902
|
|
|
2926
2903
|
@server.post(f"/pop_inference_results")
|
|
2927
2904
|
def pop_inference_results(response: Response, request: Request):
|
|
@@ -2930,23 +2907,34 @@ class Inference:
|
|
|
2930
2907
|
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2931
2908
|
return {"message": "Error: 'inference_request_uuid' is required."}
|
|
2932
2909
|
|
|
2933
|
-
|
|
2934
|
-
|
|
2935
|
-
|
|
2910
|
+
if inference_request_uuid in self._inference_requests:
|
|
2911
|
+
inference_request = self._inference_requests[inference_request_uuid].copy()
|
|
2912
|
+
inference_request["pending_results"] = inference_request["pending_results"].copy()
|
|
2936
2913
|
|
|
2937
|
-
|
|
2938
|
-
|
|
2914
|
+
# Clear the queue `pending_results`
|
|
2915
|
+
self._inference_requests[inference_request_uuid]["pending_results"].clear()
|
|
2939
2916
|
|
|
2940
|
-
|
|
2941
|
-
|
|
2942
|
-
|
|
2917
|
+
inference_request["progress"] = _convert_sly_progress_to_dict(
|
|
2918
|
+
inference_request["progress"]
|
|
2919
|
+
)
|
|
2920
|
+
log_extra = _get_log_extra_for_inference_request(
|
|
2921
|
+
inference_request_uuid, inference_request
|
|
2922
|
+
)
|
|
2923
|
+
logger.debug(f"Sending inference delta results with uuid:", extra=log_extra)
|
|
2924
|
+
return inference_request
|
|
2943
2925
|
|
|
2944
|
-
|
|
2926
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
2945
2927
|
log_extra = _get_log_extra_for_inference_request(
|
|
2946
|
-
|
|
2928
|
+
inference_request.uuid, inference_request
|
|
2947
2929
|
)
|
|
2930
|
+
data = {
|
|
2931
|
+
**inference_request.to_json(),
|
|
2932
|
+
**log_extra,
|
|
2933
|
+
"pending_results": inference_request.pop_pending_results(),
|
|
2934
|
+
}
|
|
2935
|
+
|
|
2948
2936
|
logger.debug(f"Sending inference delta results with uuid:", extra=log_extra)
|
|
2949
|
-
return
|
|
2937
|
+
return data
|
|
2950
2938
|
|
|
2951
2939
|
@server.post(f"/get_inference_result")
|
|
2952
2940
|
def get_inference_result(response: Response, request: Request):
|
|
@@ -2955,22 +2943,34 @@ class Inference:
|
|
|
2955
2943
|
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
2956
2944
|
return {"message": "Error: 'inference_request_uuid' is required."}
|
|
2957
2945
|
|
|
2958
|
-
|
|
2946
|
+
if inference_request_uuid in self._inference_requests:
|
|
2947
|
+
inference_request = self._inference_requests[inference_request_uuid].copy()
|
|
2959
2948
|
|
|
2960
|
-
|
|
2961
|
-
|
|
2962
|
-
|
|
2949
|
+
inference_request["progress"] = _convert_sly_progress_to_dict(
|
|
2950
|
+
inference_request["progress"]
|
|
2951
|
+
)
|
|
2963
2952
|
|
|
2964
|
-
|
|
2953
|
+
# Logging
|
|
2954
|
+
log_extra = _get_log_extra_for_inference_request(
|
|
2955
|
+
inference_request_uuid, inference_request
|
|
2956
|
+
)
|
|
2957
|
+
|
|
2958
|
+
logger.debug(
|
|
2959
|
+
f"Sending inference result with uuid:",
|
|
2960
|
+
extra=log_extra,
|
|
2961
|
+
)
|
|
2962
|
+
return inference_request["result"]
|
|
2963
|
+
|
|
2964
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
2965
2965
|
log_extra = _get_log_extra_for_inference_request(
|
|
2966
|
-
|
|
2966
|
+
inference_request.uuid, inference_request
|
|
2967
2967
|
)
|
|
2968
2968
|
logger.debug(
|
|
2969
2969
|
f"Sending inference result with uuid:",
|
|
2970
2970
|
extra=log_extra,
|
|
2971
2971
|
)
|
|
2972
2972
|
|
|
2973
|
-
return inference_request
|
|
2973
|
+
return inference_request.final_result
|
|
2974
2974
|
|
|
2975
2975
|
@server.post(f"/stop_inference")
|
|
2976
2976
|
def stop_inference(response: Response, request: Request):
|
|
@@ -2981,8 +2981,12 @@ class Inference:
|
|
|
2981
2981
|
"message": "Error: 'inference_request_uuid' is required.",
|
|
2982
2982
|
"success": False,
|
|
2983
2983
|
}
|
|
2984
|
-
|
|
2985
|
-
|
|
2984
|
+
if inference_request_uuid in self._inference_requests:
|
|
2985
|
+
inference_request = self._inference_requests[inference_request_uuid]
|
|
2986
|
+
inference_request["cancel_inference"] = True
|
|
2987
|
+
else:
|
|
2988
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
2989
|
+
inference_request.stop()
|
|
2986
2990
|
return {"message": "Inference will be stopped.", "success": True}
|
|
2987
2991
|
|
|
2988
2992
|
@server.post(f"/clear_inference_request")
|
|
@@ -2994,7 +2998,10 @@ class Inference:
|
|
|
2994
2998
|
"message": "Error: 'inference_request_uuid' is required.",
|
|
2995
2999
|
"success": False,
|
|
2996
3000
|
}
|
|
2997
|
-
|
|
3001
|
+
if inference_request_uuid in self._inference_requests:
|
|
3002
|
+
del self._inference_requests[inference_request_uuid]
|
|
3003
|
+
else:
|
|
3004
|
+
self.inference_requests_manager.remove_after(inference_request_uuid, 60)
|
|
2998
3005
|
logger.debug("Removed an inference request:", extra={"uuid": inference_request_uuid})
|
|
2999
3006
|
return {"success": True}
|
|
3000
3007
|
|
|
@@ -3005,8 +3012,13 @@ class Inference:
|
|
|
3005
3012
|
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
3006
3013
|
return {"message": "Error: 'inference_request_uuid' is required."}
|
|
3007
3014
|
|
|
3008
|
-
|
|
3009
|
-
|
|
3015
|
+
if inference_request_uuid in self._inference_requests:
|
|
3016
|
+
inference_request = self._inference_requests[inference_request_uuid].copy()
|
|
3017
|
+
return inference_request["preparing_progress"]
|
|
3018
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
3019
|
+
return _get_log_extra_for_inference_request(inference_request.uuid, inference_request)[
|
|
3020
|
+
"preparing_progress"
|
|
3021
|
+
]
|
|
3010
3022
|
|
|
3011
3023
|
@server.post("/get_deploy_settings")
|
|
3012
3024
|
def _get_deploy_settings(response: Response, request: Request):
|
|
@@ -3052,22 +3064,84 @@ class Inference:
|
|
|
3052
3064
|
self.shutdown_model()
|
|
3053
3065
|
state = request.state.state
|
|
3054
3066
|
deploy_params = state["deploy_params"]
|
|
3067
|
+
model_name = state.get("model_name", None)
|
|
3055
3068
|
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
3069
|
+
if deploy_params["model_source"] == ModelSource.PRETRAINED and model_name:
|
|
3070
|
+
deploy_params = self._build_deploy_params_from_api(
|
|
3071
|
+
model_name, deploy_params
|
|
3072
|
+
)
|
|
3056
3073
|
model_files = self._download_model_files(deploy_params)
|
|
3057
3074
|
deploy_params["model_files"] = model_files
|
|
3075
|
+
deploy_params = self._set_common_deploy_params(deploy_params)
|
|
3058
3076
|
self._load_model_headless(**deploy_params)
|
|
3059
3077
|
elif isinstance(self.gui, GUI.ServingGUI):
|
|
3078
|
+
if deploy_params["model_source"] == ModelSource.PRETRAINED and model_name:
|
|
3079
|
+
deploy_params = self._build_legacy_deploy_params_from_api(model_name)
|
|
3080
|
+
deploy_params = self._set_common_deploy_params(deploy_params)
|
|
3060
3081
|
self._load_model(deploy_params)
|
|
3082
|
+
elif self.gui is None and self.api is None:
|
|
3083
|
+
if deploy_params["model_source"] == ModelSource.PRETRAINED and model_name:
|
|
3084
|
+
deploy_params = self._build_deploy_params_from_api(
|
|
3085
|
+
model_name, deploy_params
|
|
3086
|
+
)
|
|
3087
|
+
model_files = self._download_model_files(deploy_params)
|
|
3088
|
+
deploy_params["model_files"] = model_files
|
|
3089
|
+
|
|
3090
|
+
deploy_params = self._set_common_deploy_params(deploy_params)
|
|
3091
|
+
self._load_model_headless(**deploy_params)
|
|
3092
|
+
logger.info(
|
|
3093
|
+
f"Model has been successfully loaded on {deploy_params['device']} device"
|
|
3094
|
+
)
|
|
3095
|
+
return {"result": "model was successfully deployed"}
|
|
3061
3096
|
|
|
3062
|
-
|
|
3063
|
-
|
|
3064
|
-
|
|
3065
|
-
|
|
3097
|
+
else:
|
|
3098
|
+
raise ValueError("Unknown GUI type")
|
|
3099
|
+
if self.gui is not None:
|
|
3100
|
+
self.set_params_to_gui(deploy_params)
|
|
3101
|
+
# update to set correct device
|
|
3102
|
+
device = deploy_params.get("device", "cpu")
|
|
3103
|
+
self.gui.set_deployed(device)
|
|
3066
3104
|
return {"result": "model was successfully deployed"}
|
|
3067
3105
|
except Exception as e:
|
|
3068
|
-
self.gui
|
|
3106
|
+
if self.gui is not None:
|
|
3107
|
+
self.gui._success_label.hide()
|
|
3069
3108
|
raise e
|
|
3070
3109
|
|
|
3110
|
+
@server.post("/list_pretrained_models")
|
|
3111
|
+
def _list_pretrained_models():
|
|
3112
|
+
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
3113
|
+
return [
|
|
3114
|
+
_get_model_name(model) for model in self._gui.pretrained_models_table._models
|
|
3115
|
+
]
|
|
3116
|
+
elif hasattr(self, "pretrained_models"):
|
|
3117
|
+
return [_get_model_name(model) for model in self.pretrained_models]
|
|
3118
|
+
else:
|
|
3119
|
+
if hasattr(self, "pretrained_models_table"):
|
|
3120
|
+
return [
|
|
3121
|
+
_get_model_name(model)
|
|
3122
|
+
for model in self.pretrained_models_table._models # pylint: disable=no-member
|
|
3123
|
+
]
|
|
3124
|
+
else:
|
|
3125
|
+
raise HTTPException(
|
|
3126
|
+
status_code=400,
|
|
3127
|
+
detail="Pretrained models table is not available in this app.",
|
|
3128
|
+
)
|
|
3129
|
+
|
|
3130
|
+
@server.post("/list_pretrained_model_infos")
|
|
3131
|
+
def _list_pretrained_model_infos():
|
|
3132
|
+
if isinstance(self.gui, GUI.ServingGUITemplate):
|
|
3133
|
+
return self._gui.pretrained_models_table._models
|
|
3134
|
+
elif hasattr(self, "pretrained_models"):
|
|
3135
|
+
return self.pretrained_models
|
|
3136
|
+
else:
|
|
3137
|
+
if hasattr(self, "pretrained_models_table"):
|
|
3138
|
+
return self.pretrained_models_table._models
|
|
3139
|
+
else:
|
|
3140
|
+
raise HTTPException(
|
|
3141
|
+
status_code=400,
|
|
3142
|
+
detail="Pretrained models table is not available in this app.",
|
|
3143
|
+
)
|
|
3144
|
+
|
|
3071
3145
|
@server.post("/is_deployed")
|
|
3072
3146
|
def _is_deployed(response: Response, request: Request):
|
|
3073
3147
|
return {
|
|
@@ -3080,6 +3154,37 @@ class Inference:
|
|
|
3080
3154
|
def _get_deploy_info():
|
|
3081
3155
|
return asdict(self._get_deploy_info())
|
|
3082
3156
|
|
|
3157
|
+
@server.post("/get_inference_status")
|
|
3158
|
+
def _get_inference_status(request: Request, response: Response):
|
|
3159
|
+
state = request.state.state
|
|
3160
|
+
inference_request_uuid = state.get("inference_request_uuid")
|
|
3161
|
+
if inference_request_uuid is None:
|
|
3162
|
+
response.status_code = status.HTTP_400_BAD_REQUEST
|
|
3163
|
+
return {"message": "Error: 'inference_request_uuid' is required."}
|
|
3164
|
+
inference_request = self.inference_requests_manager.get(inference_request_uuid)
|
|
3165
|
+
if inference_request is None:
|
|
3166
|
+
response.status_code = status.HTTP_404_NOT_FOUND
|
|
3167
|
+
return {"message": "Error: 'inference_request_uuid' is not found."}
|
|
3168
|
+
return inference_request.status()
|
|
3169
|
+
|
|
3170
|
+
@server.post("/get_status")
|
|
3171
|
+
def _get_status(request: Request):
|
|
3172
|
+
progress = self.inference_requests_manager.global_progress.to_json()
|
|
3173
|
+
ram_allocated, ram_total = get_ram_usage()
|
|
3174
|
+
gpu_allocated, gpu_total = get_gpu_usage()
|
|
3175
|
+
return {
|
|
3176
|
+
"is_deployed": self.is_model_deployed(),
|
|
3177
|
+
"progress": progress,
|
|
3178
|
+
"gpu_memory": {
|
|
3179
|
+
"allocated": gpu_allocated,
|
|
3180
|
+
"total": gpu_total,
|
|
3181
|
+
},
|
|
3182
|
+
"ram_memory": {
|
|
3183
|
+
"allocated": ram_allocated,
|
|
3184
|
+
"total": ram_total,
|
|
3185
|
+
},
|
|
3186
|
+
}
|
|
3187
|
+
|
|
3083
3188
|
# Local deploy without predict args
|
|
3084
3189
|
if self._is_local_deploy:
|
|
3085
3190
|
self._run_server()
|
|
@@ -3433,7 +3538,7 @@ class Inference:
|
|
|
3433
3538
|
change_name_if_conflict=True,
|
|
3434
3539
|
)
|
|
3435
3540
|
state["output_project_id"] = output_project.id
|
|
3436
|
-
results = self.
|
|
3541
|
+
results = self.inference_requests_manager.run(self._inference_project_id, api, state)
|
|
3437
3542
|
|
|
3438
3543
|
dataset_infos = api.dataset.get_list(project_id)
|
|
3439
3544
|
datasets_map = {dataset_info.id: dataset_info.name for dataset_info in dataset_infos}
|
|
@@ -3617,136 +3722,157 @@ class Inference:
|
|
|
3617
3722
|
f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input"
|
|
3618
3723
|
)
|
|
3619
3724
|
|
|
3620
|
-
def _exclude_duplicated_predictions(
|
|
3621
|
-
self,
|
|
3622
|
-
api: Api,
|
|
3623
|
-
pred_anns: List[Annotation],
|
|
3624
|
-
settings: dict,
|
|
3625
|
-
dataset_id: int,
|
|
3626
|
-
gt_image_ids: List[int],
|
|
3627
|
-
meta: Optional[ProjectMeta] = None,
|
|
3628
|
-
):
|
|
3629
|
-
"""
|
|
3630
|
-
Filter out predictions that significantly overlap with ground truth (GT) objects.
|
|
3631
|
-
|
|
3632
|
-
This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
|
|
3633
|
-
- Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
|
|
3634
|
-
- Gets ProjectMeta object if not provided
|
|
3635
|
-
- Downloads GT annotations for the specified image IDs
|
|
3636
|
-
- Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
|
|
3637
|
-
|
|
3638
|
-
:param api: Supervisely API object
|
|
3639
|
-
:type api: Api
|
|
3640
|
-
:param pred_anns: List of Annotation objects containing predictions
|
|
3641
|
-
:type pred_anns: List[Annotation]
|
|
3642
|
-
:param settings: Inference settings
|
|
3643
|
-
:type settings: dict
|
|
3644
|
-
:param dataset_id: ID of the dataset containing the images
|
|
3645
|
-
:type dataset_id: int
|
|
3646
|
-
:param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
|
|
3647
|
-
:type gt_image_ids: List[int]
|
|
3648
|
-
:param meta: ProjectMeta object
|
|
3649
|
-
:type meta: Optional[ProjectMeta]
|
|
3650
|
-
:return: List of Annotation objects containing filtered predictions
|
|
3651
|
-
:rtype: List[Annotation]
|
|
3652
|
-
|
|
3653
|
-
Notes:
|
|
3654
|
-
------
|
|
3655
|
-
- Requires PyTorch and torchvision for IoU calculations
|
|
3656
|
-
- This method is useful for identifying new objects that aren't already annotated in the ground truth
|
|
3657
|
-
"""
|
|
3658
|
-
iou = settings.get("existing_objects_iou_thresh")
|
|
3659
|
-
if isinstance(iou, float) and 0 < iou <= 1:
|
|
3660
|
-
if meta is None:
|
|
3661
|
-
ds = api.dataset.get_info_by_id(dataset_id)
|
|
3662
|
-
meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
|
|
3663
|
-
gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
|
|
3664
|
-
gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
|
|
3665
|
-
for i in range(0, len(pred_anns)):
|
|
3666
|
-
before = len(pred_anns[i].labels)
|
|
3667
|
-
with Timer() as timer:
|
|
3668
|
-
pred_anns[i] = self._filter_duplicated_predictions_from_ann(
|
|
3669
|
-
gt_anns[i], pred_anns[i], iou
|
|
3670
|
-
)
|
|
3671
|
-
after = len(pred_anns[i].labels)
|
|
3672
|
-
logger.debug(
|
|
3673
|
-
f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
|
|
3674
|
-
)
|
|
3675
|
-
return pred_anns
|
|
3676
|
-
|
|
3677
|
-
def _filter_duplicated_predictions_from_ann(
|
|
3678
|
-
self, gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
|
|
3679
|
-
) -> Annotation:
|
|
3680
|
-
"""
|
|
3681
|
-
Filter out predictions that significantly overlap with ground truth annotations.
|
|
3682
|
-
|
|
3683
|
-
This function compares each prediction with ground truth annotations of the same class
|
|
3684
|
-
and removes predictions that have an IoU (Intersection over Union) greater than or equal
|
|
3685
|
-
to the specified threshold with any ground truth annotation. This is useful for identifying
|
|
3686
|
-
new objects that aren't already annotated in the ground truth.
|
|
3687
|
-
|
|
3688
|
-
:param gt_ann: Annotation object containing ground truth labels
|
|
3689
|
-
:type gt_ann: Annotation
|
|
3690
|
-
:param pred_ann: Annotation object containing prediction labels to be filtered
|
|
3691
|
-
:type pred_ann: Annotation
|
|
3692
|
-
:param iou_threshold: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
|
|
3693
|
-
ground truth box of the same class will be removed
|
|
3694
|
-
:type iou_threshold: float
|
|
3695
|
-
:return: A new annotation object containing only predictions that don't significantly
|
|
3696
|
-
overlap with ground truth annotations
|
|
3697
|
-
:rtype: Annotation
|
|
3698
|
-
|
|
3699
|
-
|
|
3700
|
-
Notes:
|
|
3701
|
-
------
|
|
3702
|
-
- Predictions with classes not present in ground truth will be kept
|
|
3703
|
-
- Requires PyTorch and torchvision for IoU calculations
|
|
3704
|
-
"""
|
|
3705
3725
|
|
|
3706
|
-
|
|
3707
|
-
|
|
3708
|
-
|
|
3709
|
-
|
|
3710
|
-
|
|
3711
|
-
|
|
3726
|
+
def _exclude_duplicated_predictions(
|
|
3727
|
+
api: Api,
|
|
3728
|
+
pred_anns: List[Annotation],
|
|
3729
|
+
dataset_id: int,
|
|
3730
|
+
gt_image_ids: List[int],
|
|
3731
|
+
iou: float = None,
|
|
3732
|
+
meta: Optional[ProjectMeta] = None,
|
|
3733
|
+
):
|
|
3734
|
+
"""
|
|
3735
|
+
Filter out predictions that significantly overlap with ground truth (GT) objects.
|
|
3736
|
+
|
|
3737
|
+
This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
|
|
3738
|
+
- Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
|
|
3739
|
+
- Gets ProjectMeta object if not provided
|
|
3740
|
+
- Downloads GT annotations for the specified image IDs
|
|
3741
|
+
- Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
|
|
3742
|
+
|
|
3743
|
+
:param api: Supervisely API object
|
|
3744
|
+
:type api: Api
|
|
3745
|
+
:param pred_anns: List of Annotation objects containing predictions
|
|
3746
|
+
:type pred_anns: List[Annotation]
|
|
3747
|
+
:param dataset_id: ID of the dataset containing the images
|
|
3748
|
+
:type dataset_id: int
|
|
3749
|
+
:param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
|
|
3750
|
+
:type gt_image_ids: List[int]
|
|
3751
|
+
:param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
|
|
3752
|
+
ground truth box of the same class will be removed. None if no filtering is needed
|
|
3753
|
+
:type iou: Optional[float]
|
|
3754
|
+
:param meta: ProjectMeta object
|
|
3755
|
+
:type meta: Optional[ProjectMeta]
|
|
3756
|
+
:return: List of Annotation objects containing filtered predictions
|
|
3757
|
+
:rtype: List[Annotation]
|
|
3758
|
+
|
|
3759
|
+
Notes:
|
|
3760
|
+
------
|
|
3761
|
+
- Requires PyTorch and torchvision for IoU calculations
|
|
3762
|
+
- This method is useful for identifying new objects that aren't already annotated in the ground truth
|
|
3763
|
+
"""
|
|
3764
|
+
if isinstance(iou, float) and 0 < iou <= 1:
|
|
3765
|
+
if meta is None:
|
|
3766
|
+
ds = api.dataset.get_info_by_id(dataset_id)
|
|
3767
|
+
meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
|
|
3768
|
+
gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
|
|
3769
|
+
gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
|
|
3770
|
+
for i in range(0, len(pred_anns)):
|
|
3771
|
+
before = len(pred_anns[i].labels)
|
|
3772
|
+
with Timer() as timer:
|
|
3773
|
+
pred_anns[i] = _filter_duplicated_predictions_from_ann(
|
|
3774
|
+
gt_anns[i], pred_anns[i], iou
|
|
3775
|
+
)
|
|
3776
|
+
after = len(pred_anns[i].labels)
|
|
3777
|
+
logger.debug(
|
|
3778
|
+
f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
|
|
3779
|
+
)
|
|
3780
|
+
return pred_anns
|
|
3712
3781
|
|
|
3713
|
-
def _to_tensor(geom):
|
|
3714
|
-
return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
|
|
3715
3782
|
|
|
3716
|
-
|
|
3717
|
-
|
|
3718
|
-
|
|
3719
|
-
|
|
3783
|
+
def _filter_duplicated_predictions_from_ann(
|
|
3784
|
+
gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
|
|
3785
|
+
) -> Annotation:
|
|
3786
|
+
"""
|
|
3787
|
+
Filter out predictions that significantly overlap with ground truth annotations.
|
|
3788
|
+
|
|
3789
|
+
This function compares each prediction with ground truth annotations of the same class
|
|
3790
|
+
and removes predictions that have an IoU (Intersection over Union) greater than or equal
|
|
3791
|
+
to the specified threshold with any ground truth annotation. This is useful for identifying
|
|
3792
|
+
new objects that aren't already annotated in the ground truth.
|
|
3793
|
+
|
|
3794
|
+
:param gt_ann: Annotation object containing ground truth labels
|
|
3795
|
+
:type gt_ann: Annotation
|
|
3796
|
+
:param pred_ann: Annotation object containing prediction labels to be filtered
|
|
3797
|
+
:type pred_ann: Annotation
|
|
3798
|
+
:param iou_threshold: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
|
|
3799
|
+
ground truth box of the same class will be removed
|
|
3800
|
+
:type iou_threshold: float
|
|
3801
|
+
:return: A new annotation object containing only predictions that don't significantly
|
|
3802
|
+
overlap with ground truth annotations
|
|
3803
|
+
:rtype: Annotation
|
|
3804
|
+
|
|
3805
|
+
|
|
3806
|
+
Notes:
|
|
3807
|
+
------
|
|
3808
|
+
- Predictions with classes not present in ground truth will be kept
|
|
3809
|
+
- Requires PyTorch and torchvision for IoU calculations
|
|
3810
|
+
"""
|
|
3720
3811
|
|
|
3721
|
-
|
|
3722
|
-
|
|
3723
|
-
|
|
3724
|
-
continue
|
|
3725
|
-
gt_cls_bboxes[label.obj_class.name].append(label)
|
|
3812
|
+
try:
|
|
3813
|
+
import torch
|
|
3814
|
+
from torchvision.ops import box_iou
|
|
3726
3815
|
|
|
3727
|
-
|
|
3728
|
-
|
|
3729
|
-
if len(gt) == 0:
|
|
3730
|
-
new_labels.extend(pred)
|
|
3731
|
-
continue
|
|
3732
|
-
pred_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in pred]).float()
|
|
3733
|
-
gt_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in gt]).float()
|
|
3734
|
-
iou_matrix = box_iou(pred_bboxes, gt_bboxes)
|
|
3735
|
-
iou_matrix = iou_matrix.cpu().numpy()
|
|
3736
|
-
keep_indices = np.where(np.all(iou_matrix < iou_threshold, axis=1))[0]
|
|
3737
|
-
new_labels.extend([pred[i] for i in keep_indices])
|
|
3816
|
+
except ImportError:
|
|
3817
|
+
raise ImportError("Please install PyTorch and torchvision to use this feature.")
|
|
3738
3818
|
|
|
3739
|
-
|
|
3819
|
+
def _to_tensor(geom):
|
|
3820
|
+
return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
|
|
3740
3821
|
|
|
3822
|
+
new_labels = []
|
|
3823
|
+
pred_cls_bboxes = defaultdict(list)
|
|
3824
|
+
for label in pred_ann.labels:
|
|
3825
|
+
pred_cls_bboxes[label.obj_class.name].append(label)
|
|
3826
|
+
|
|
3827
|
+
gt_cls_bboxes = defaultdict(list)
|
|
3828
|
+
for label in gt_ann.labels:
|
|
3829
|
+
if label.obj_class.name not in pred_cls_bboxes:
|
|
3830
|
+
continue
|
|
3831
|
+
gt_cls_bboxes[label.obj_class.name].append(label)
|
|
3832
|
+
|
|
3833
|
+
for name, pred in pred_cls_bboxes.items():
|
|
3834
|
+
gt = gt_cls_bboxes[name]
|
|
3835
|
+
if len(gt) == 0:
|
|
3836
|
+
new_labels.extend(pred)
|
|
3837
|
+
continue
|
|
3838
|
+
pred_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in pred]).float()
|
|
3839
|
+
gt_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in gt]).float()
|
|
3840
|
+
iou_matrix = box_iou(pred_bboxes, gt_bboxes)
|
|
3841
|
+
iou_matrix = iou_matrix.cpu().numpy()
|
|
3842
|
+
keep_indices = np.where(np.all(iou_matrix < iou_threshold, axis=1))[0]
|
|
3843
|
+
new_labels.extend([pred[i] for i in keep_indices])
|
|
3844
|
+
|
|
3845
|
+
return pred_ann.clone(labels=new_labels)
|
|
3846
|
+
|
|
3847
|
+
|
|
3848
|
+
def _get_log_extra_for_inference_request(
|
|
3849
|
+
inference_request_uuid, inference_request: Union[InferenceRequest, dict]
|
|
3850
|
+
):
|
|
3851
|
+
if isinstance(inference_request, dict):
|
|
3852
|
+
log_extra = {
|
|
3853
|
+
"uuid": inference_request_uuid,
|
|
3854
|
+
"progress": inference_request["progress"],
|
|
3855
|
+
"is_inferring": inference_request["is_inferring"],
|
|
3856
|
+
"cancel_inference": inference_request["cancel_inference"],
|
|
3857
|
+
"has_result": inference_request["result"] is not None,
|
|
3858
|
+
"pending_results": len(inference_request["pending_results"]),
|
|
3859
|
+
}
|
|
3860
|
+
return log_extra
|
|
3741
3861
|
|
|
3742
|
-
|
|
3862
|
+
progress = inference_request.progress_json()
|
|
3863
|
+
del progress["message"]
|
|
3743
3864
|
log_extra = {
|
|
3744
|
-
"uuid":
|
|
3745
|
-
"progress":
|
|
3746
|
-
"is_inferring": inference_request
|
|
3747
|
-
"
|
|
3748
|
-
"
|
|
3749
|
-
"
|
|
3865
|
+
"uuid": inference_request.uuid,
|
|
3866
|
+
"progress": progress,
|
|
3867
|
+
"is_inferring": inference_request.is_inferring(),
|
|
3868
|
+
"stopped": inference_request.is_stopped(),
|
|
3869
|
+
"finished": inference_request.is_finished(),
|
|
3870
|
+
"cancel_inference": inference_request.is_stopped(),
|
|
3871
|
+
"has_result": inference_request.final_result is not None,
|
|
3872
|
+
"pending_results": inference_request.pending_num(),
|
|
3873
|
+
"exception": inference_request.exception_json(),
|
|
3874
|
+
"result": inference_request._final_result,
|
|
3875
|
+
"preparing_progress": progress,
|
|
3750
3876
|
}
|
|
3751
3877
|
return log_extra
|
|
3752
3878
|
|
|
@@ -4059,3 +4185,33 @@ def get_hardware_info(device: str) -> str:
|
|
|
4059
4185
|
except Exception as e:
|
|
4060
4186
|
logger.error("Error while getting hardware info", exc_info=True)
|
|
4061
4187
|
return "Unknown"
|
|
4188
|
+
|
|
4189
|
+
|
|
4190
|
+
def progress_wrapper(func, progress_cb):
|
|
4191
|
+
@wraps(func)
|
|
4192
|
+
def wrapped_func(*args, **kwargs):
|
|
4193
|
+
result = func(*args, **kwargs)
|
|
4194
|
+
progress_cb(len(result))
|
|
4195
|
+
return result
|
|
4196
|
+
|
|
4197
|
+
return wrapped_func
|
|
4198
|
+
|
|
4199
|
+
|
|
4200
|
+
def batched_iter(iterable, batch_size):
|
|
4201
|
+
batch = []
|
|
4202
|
+
for item in iterable:
|
|
4203
|
+
batch.append(item)
|
|
4204
|
+
if len(batch) == batch_size:
|
|
4205
|
+
yield batch
|
|
4206
|
+
batch = []
|
|
4207
|
+
if batch:
|
|
4208
|
+
yield batch
|
|
4209
|
+
|
|
4210
|
+
|
|
4211
|
+
def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
|
|
4212
|
+
for key in keys:
|
|
4213
|
+
if key in data:
|
|
4214
|
+
if ignore_none and data[key] is None:
|
|
4215
|
+
continue
|
|
4216
|
+
return data[key]
|
|
4217
|
+
return None
|