supervisely 6.73.283__py3-none-any.whl → 6.73.285__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of supervisely might be problematic. Click here for more details.
- supervisely/_utils.py +9 -0
- supervisely/api/entity_annotation/figure_api.py +3 -0
- supervisely/api/module_api.py +35 -1
- supervisely/api/video/video_api.py +1 -1
- supervisely/api/video_annotation_tool_api.py +58 -7
- supervisely/nn/benchmark/base_benchmark.py +13 -2
- supervisely/nn/benchmark/base_evaluator.py +2 -0
- supervisely/nn/benchmark/comparison/detection_visualization/text_templates.py +5 -0
- supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +25 -0
- supervisely/nn/benchmark/comparison/detection_visualization/visualizer.py +9 -3
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -0
- supervisely/nn/benchmark/instance_segmentation/text_templates.py +7 -0
- supervisely/nn/benchmark/object_detection/evaluator.py +15 -3
- supervisely/nn/benchmark/object_detection/metric_provider.py +21 -1
- supervisely/nn/benchmark/object_detection/text_templates.py +7 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/key_metrics.py +12 -0
- supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +41 -2
- supervisely/nn/benchmark/object_detection/visualizer.py +20 -0
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +1 -0
- supervisely/nn/benchmark/utils/detection/calculate_metrics.py +31 -33
- supervisely/nn/benchmark/visualization/renderer.py +2 -0
- supervisely/nn/inference/cache.py +19 -1
- supervisely/nn/inference/inference.py +22 -0
- supervisely/nn/inference/tracking/base_tracking.py +362 -0
- supervisely/nn/inference/tracking/bbox_tracking.py +179 -129
- supervisely/nn/inference/tracking/mask_tracking.py +420 -329
- supervisely/nn/inference/tracking/point_tracking.py +325 -288
- supervisely/nn/inference/tracking/tracker_interface.py +346 -13
- {supervisely-6.73.283.dist-info → supervisely-6.73.285.dist-info}/METADATA +1 -1
- {supervisely-6.73.283.dist-info → supervisely-6.73.285.dist-info}/RECORD +34 -33
- {supervisely-6.73.283.dist-info → supervisely-6.73.285.dist-info}/LICENSE +0 -0
- {supervisely-6.73.283.dist-info → supervisely-6.73.285.dist-info}/WHEEL +0 -0
- {supervisely-6.73.283.dist-info → supervisely-6.73.285.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.283.dist-info → supervisely-6.73.285.dist-info}/top_level.txt +0 -0
|
@@ -48,8 +48,11 @@ def calculate_metrics(
|
|
|
48
48
|
:return: Results of the evaluation
|
|
49
49
|
:rtype: dict
|
|
50
50
|
"""
|
|
51
|
+
from pycocotools.coco import COCO # pylint: disable=import-error
|
|
51
52
|
from pycocotools.cocoeval import COCOeval # pylint: disable=import-error
|
|
52
53
|
|
|
54
|
+
cocoGt: COCO = cocoGt
|
|
55
|
+
|
|
53
56
|
cocoEval = COCOeval(cocoGt, cocoDt, iouType=iouType)
|
|
54
57
|
cocoEval.evaluate()
|
|
55
58
|
progress_cb(1) if progress_cb is not None else None
|
|
@@ -66,23 +69,33 @@ def calculate_metrics(
|
|
|
66
69
|
progress_cb(1) if progress_cb is not None else None
|
|
67
70
|
cocoEval_cls.summarize()
|
|
68
71
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
72
|
+
iouThrs = cocoEval.params.iouThrs
|
|
73
|
+
evaluation_params = evaluation_params or {}
|
|
74
|
+
iou_threshold = evaluation_params.get("iou_threshold", 0.5)
|
|
75
|
+
iou_threshold_per_class = evaluation_params.get("iou_threshold_per_class")
|
|
76
|
+
if iou_threshold_per_class is not None:
|
|
77
|
+
iou_idx_per_class = {
|
|
78
|
+
cocoGt.getCatIds(catNms=[class_name])[0]: np.where(np.isclose(iouThrs, iou_thres))[0][0]
|
|
79
|
+
for class_name, iou_thres in iou_threshold_per_class.items()
|
|
80
|
+
}
|
|
81
|
+
else:
|
|
82
|
+
iou_idx = np.where(np.isclose(iouThrs, iou_threshold))[0][0]
|
|
83
|
+
iou_idx_per_class = {cat_id: iou_idx for cat_id in cocoGt.getCatIds()}
|
|
77
84
|
|
|
78
85
|
eval_img_dict = get_eval_img_dict(cocoEval)
|
|
79
86
|
eval_img_dict_cls = get_eval_img_dict(cocoEval_cls)
|
|
80
|
-
matches = get_matches(
|
|
87
|
+
matches = get_matches(
|
|
88
|
+
eval_img_dict,
|
|
89
|
+
eval_img_dict_cls,
|
|
90
|
+
cocoEval_cls,
|
|
91
|
+
iou_idx_per_class=iou_idx_per_class,
|
|
92
|
+
)
|
|
81
93
|
|
|
82
94
|
params = {
|
|
83
95
|
"iouThrs": cocoEval.params.iouThrs,
|
|
84
96
|
"recThrs": cocoEval.params.recThrs,
|
|
85
|
-
"evaluation_params": evaluation_params
|
|
97
|
+
"evaluation_params": evaluation_params,
|
|
98
|
+
"iou_idx_per_class": iou_idx_per_class,
|
|
86
99
|
}
|
|
87
100
|
coco_metrics = {"mAP": cocoEval.stats[0], "precision": cocoEval.eval["precision"]}
|
|
88
101
|
coco_metrics["AP50"] = cocoEval.stats[1]
|
|
@@ -204,27 +217,6 @@ def get_eval_img_dict(cocoEval):
|
|
|
204
217
|
return eval_img_dict
|
|
205
218
|
|
|
206
219
|
|
|
207
|
-
def get_eval_img_dict_cls(cocoEval_cls):
|
|
208
|
-
"""
|
|
209
|
-
type cocoEval_cls: COCOeval
|
|
210
|
-
"""
|
|
211
|
-
# For miss-classification
|
|
212
|
-
aRng = cocoEval_cls.params.areaRng[0]
|
|
213
|
-
eval_img_dict_cls = defaultdict(list) # img_id : dt/gt
|
|
214
|
-
for i, eval_img in enumerate(cocoEval_cls.evalImgs):
|
|
215
|
-
if eval_img is None or eval_img["aRng"] != aRng:
|
|
216
|
-
continue
|
|
217
|
-
img_id = eval_img["image_id"]
|
|
218
|
-
cat_id = eval_img["category_id"]
|
|
219
|
-
ious = cocoEval_cls.ious[(img_id, cat_id)]
|
|
220
|
-
# ! inplace operation
|
|
221
|
-
eval_img["ious"] = ious
|
|
222
|
-
eval_img_dict_cls[img_id].append(eval_img)
|
|
223
|
-
eval_img_dict_cls = dict(eval_img_dict_cls)
|
|
224
|
-
assert np.all([len(x) == 1 for x in eval_img_dict_cls.values()])
|
|
225
|
-
return eval_img_dict_cls
|
|
226
|
-
|
|
227
|
-
|
|
228
220
|
def _get_missclassified_match(eval_img_cls, dt_id, gtIds_orig, dtIds_orig, iou_t):
|
|
229
221
|
# Correction on miss-classification
|
|
230
222
|
gt_idx = np.nonzero(eval_img_cls["gtMatches"][iou_t] == dt_id)[0]
|
|
@@ -242,7 +234,12 @@ def _get_missclassified_match(eval_img_cls, dt_id, gtIds_orig, dtIds_orig, iou_t
|
|
|
242
234
|
return None, None
|
|
243
235
|
|
|
244
236
|
|
|
245
|
-
def get_matches(
|
|
237
|
+
def get_matches(
|
|
238
|
+
eval_img_dict: dict,
|
|
239
|
+
eval_img_dict_cls: dict,
|
|
240
|
+
cocoEval_cls,
|
|
241
|
+
iou_idx_per_class: dict = None,
|
|
242
|
+
):
|
|
246
243
|
"""
|
|
247
244
|
type cocoEval_cls: COCOeval
|
|
248
245
|
"""
|
|
@@ -255,7 +252,8 @@ def get_matches(eval_img_dict: dict, eval_img_dict_cls: dict, cocoEval_cls, iou_
|
|
|
255
252
|
gt_ids_orig_cls = [_["id"] for i in cat_ids for _ in cocoEval_cls._gts[img_id, i]]
|
|
256
253
|
|
|
257
254
|
for eval_img in eval_imgs:
|
|
258
|
-
|
|
255
|
+
cat_id = eval_img["category_id"]
|
|
256
|
+
iou_t = iou_idx_per_class[cat_id]
|
|
259
257
|
dtIds = np.array(eval_img["dtIds"])
|
|
260
258
|
gtIds = np.array(eval_img["gtIds"])
|
|
261
259
|
dtm = eval_img["dtMatches"][iou_t]
|
|
@@ -8,6 +8,7 @@ from supervisely.api.api import Api
|
|
|
8
8
|
from supervisely.io.fs import dir_empty, get_directory_size
|
|
9
9
|
from supervisely.nn.benchmark.visualization.widgets import BaseWidget
|
|
10
10
|
from supervisely.task.progress import tqdm_sly
|
|
11
|
+
from supervisely import logger
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class Renderer:
|
|
@@ -95,6 +96,7 @@ class Renderer:
|
|
|
95
96
|
pth = Path(self.base_dir).joinpath(self.report_name)
|
|
96
97
|
with open(pth, "w") as f:
|
|
97
98
|
f.write(report_link)
|
|
99
|
+
logger.debug(f"Report link: {self._get_report_link(api, team_id, remote_dir)}")
|
|
98
100
|
return str(pth)
|
|
99
101
|
|
|
100
102
|
def _get_report_link(self, api: Api, team_id: int, remote_dir: str):
|
|
@@ -303,6 +303,22 @@ class InferenceImageCache:
|
|
|
303
303
|
self.get_frame_from_cache(video_id, frame_index) for frame_index in frame_indexes
|
|
304
304
|
]
|
|
305
305
|
|
|
306
|
+
def frames_loader(
    self, api: sly.Api, video_id: int, frame_indexes: List[int]
) -> Generator[np.ndarray, None, None]:
    """Lazily yield the frames of ``video_id`` at ``frame_indexes``, in order.

    When the cache is not a ``PersistentImageTTLCache`` every frame is obtained
    through ``download_frame``. Otherwise ``run_cache_task_manually`` is called
    for the video first, frames keep being obtained through ``download_frame``
    until ``video_id`` shows up in the cache, and the remaining frames are read
    through ``_read_frames_from_cached_video_iter``.

    :param api: API client forwarded to the download / caching helpers.
    :param video_id: ID of the video to read frames from.
    :param frame_indexes: Frame indexes to load; may be empty.
    :return: Generator producing one frame per requested index.
    """
    if not isinstance(self._cache, PersistentImageTTLCache):
        for frame_index in frame_indexes:
            yield self.download_frame(api, video_id, frame_index)
        return

    # presumably this starts caching the whole video in the background -- verify
    self.run_cache_task_manually(api, None, video_id=video_id)

    for position, frame_index in enumerate(frame_indexes):
        if video_id in self._cache:
            # The video is available locally: switch to reading from it,
            # starting with the frame that has not been yielded yet.
            break
        yield self.download_frame(api, video_id, frame_index)
    else:
        # The loop was never interrupted: either nothing was requested or every
        # frame has already been yielded one by one. Returning here avoids the
        # unbound loop index on empty input and avoids yielding the last frame
        # a second time.
        return

    yield from self._read_frames_from_cached_video_iter(video_id, frame_indexes[position:])
|
|
321
|
+
|
|
306
322
|
def download_frame(self, api: sly.Api, video_id: int, frame_index: int) -> np.ndarray:
|
|
307
323
|
name = self._frame_name(video_id, frame_index)
|
|
308
324
|
self._wait_if_in_queue(name, api.logger)
|
|
@@ -401,7 +417,9 @@ class InferenceImageCache:
|
|
|
401
417
|
"""
|
|
402
418
|
return_images = kwargs.get("return_images", True)
|
|
403
419
|
progress_cb = kwargs.get("progress_cb", None)
|
|
404
|
-
video_info = kwargs.get("video_info",
|
|
420
|
+
video_info = kwargs.get("video_info", None)
|
|
421
|
+
if video_info is None:
|
|
422
|
+
video_info = api.video.get_info_by_id(video_id)
|
|
405
423
|
|
|
406
424
|
self._wait_if_in_queue(video_id, api.logger)
|
|
407
425
|
if not video_id in self._cache:
|
|
@@ -2252,6 +2252,25 @@ class Inference:
|
|
|
2252
2252
|
def is_model_deployed(self):
|
|
2253
2253
|
return self._model_served
|
|
2254
2254
|
|
|
2255
|
+
def schedule_task(self, func, *args, **kwargs):
    """Submit ``func(*args, **kwargs)`` to the executor.

    If the keyword arguments carry a non-``None`` ``inference_request_uuid``,
    ``_on_inference_start`` is called for it, the call is submitted through
    ``_handle_error_in_async`` and ``_on_inference_end`` is attached as a
    done-callback of the resulting future. Without a uuid the function is
    submitted as is and the future is not kept.
    """
    request_uuid = kwargs.get("inference_request_uuid", None)
    if request_uuid is not None:
        self._on_inference_start(request_uuid)
        submitted = self._executor.submit(
            self._handle_error_in_async, request_uuid, func, *args, **kwargs
        )
        on_done = partial(self._on_inference_end, inference_request_uuid=request_uuid)
        submitted.add_done_callback(on_done)
    else:
        self._executor.submit(func, *args, **kwargs)
    logger.debug("Scheduled task.", extra={"inference_request_uuid": request_uuid})
|
|
2273
|
+
|
|
2255
2274
|
def serve(self):
|
|
2256
2275
|
if not self._use_gui:
|
|
2257
2276
|
Progress("Deploying model ...", 1)
|
|
@@ -2320,6 +2339,9 @@ class Inference:
|
|
|
2320
2339
|
else:
|
|
2321
2340
|
autostart_func()
|
|
2322
2341
|
|
|
2342
|
+
self.cache.add_cache_endpoint(server)
|
|
2343
|
+
self.cache.add_cache_files_endpoint(server)
|
|
2344
|
+
|
|
2323
2345
|
@server.post(f"/get_session_info")
|
|
2324
2346
|
@self._check_serve_before_call
|
|
2325
2347
|
def get_session_info(response: Response):
|
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import inspect
|
|
3
|
+
import json
|
|
4
|
+
import traceback
|
|
5
|
+
from threading import Lock
|
|
6
|
+
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
from fastapi import Form, Request, Response, UploadFile, status
|
|
9
|
+
from pydantic import ValidationError
|
|
10
|
+
|
|
11
|
+
from supervisely._utils import find_value_by_keys
|
|
12
|
+
from supervisely.api.api import Api
|
|
13
|
+
from supervisely.api.module_api import ApiField
|
|
14
|
+
from supervisely.io import env
|
|
15
|
+
from supervisely.nn.inference.inference import (
|
|
16
|
+
Inference,
|
|
17
|
+
_convert_sly_progress_to_dict,
|
|
18
|
+
_get_log_extra_for_inference_request,
|
|
19
|
+
)
|
|
20
|
+
from supervisely.sly_logger import logger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def validate_key(data: Dict, key: str, type_: type):
    """Check that ``data[key]`` exists and is an instance of ``type_``.

    :param data: Mapping to inspect (the request context at the call sites in this file).
    :param key: Key that must be present in ``data``.
    :param type_: Type (or tuple of types) the value must be an instance of.
    :raises ValidationError: If the key is missing or the value has another type.
    """
    # NOTE(review): ValidationError is imported from pydantic; confirm that it
    # can be constructed from a single message string in the supported versions.
    key_present = key in data
    if not key_present:
        raise ValidationError(f"Key {key} not found in inference request.")
    value = data[key]
    if isinstance(value, type_):
        return
    raise ValidationError(f"Key {key} is not of type {type_}.")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def handle_validation(func):
    """Decorate an endpoint handler so ``ValidationError`` becomes a 400 reply.

    The wrapper looks for a ``Response`` instance among the handler's
    positional and keyword arguments. When the handler raises
    ``ValidationError`` and a response object was found, the error is logged,
    the response status is set to 400 and ``{"error": ..., "success": False}``
    is returned. Without a response object the error is re-raised. Coroutine
    functions get an ``async`` wrapper, everything else a plain one.
    """

    def _find_response(args, kwargs):
        # Positional arguments are searched before keyword arguments.
        for candidate in (*args, *kwargs.values()):
            if isinstance(candidate, Response):
                return candidate
        return None

    def _handle_exception(e, response):
        if response is None:
            # Nothing to write the status code to: let the caller deal with it.
            raise e
        logger.error(f"ValidationError: {e}", exc_info=True)
        response.status_code = status.HTTP_400_BAD_REQUEST
        return {"error": str(e), "success": False}

    if not inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            found = _find_response(args, kwargs)
            try:
                result = func(*args, **kwargs)
            except ValidationError as e:
                return _handle_exception(e, found)
            return result

        return wrapper

    @functools.wraps(func)
    async def async_wrapper(*args, **kwargs):
        found = _find_response(args, kwargs)
        try:
            result = await func(*args, **kwargs)
        except ValidationError as e:
            return _handle_exception(e, found)
        return result

    return async_wrapper
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class BaseTracking(Inference):
    """Common base for tracking models served through ``Inference``.

    The class keeps per-request bookkeeping (a lock per inference request,
    popping / clearing / cancelling of pending results), offers helpers that
    report errors back through the API, and registers the tracking HTTP
    endpoints. Derived classes implement ``track``, ``track_api``,
    ``track_api_files`` and ``track_async``.
    """

    def __init__(
        self,
        model_dir: Optional[str] = None,
        custom_inference_settings: Optional[Union[Dict[str, Any], str]] = None,
    ):
        """Initialise ``Inference`` without GUI and load the model.

        The model is loaded on ``"cuda"`` first; if that raises
        ``RuntimeError`` it is loaded on ``"cpu"`` and a warning is logged.
        """
        Inference.__init__(
            self,
            model_dir,
            custom_inference_settings,
            sliding_window_mode=None,
            use_gui=False,
        )

        try:
            self.load_on_device(model_dir, "cuda")
        except RuntimeError:
            self.load_on_device(model_dir, "cpu")
            logger.warning("Failed to load model on CUDA device.")

        logger.debug(
            "Smart cache params",
            extra={"ttl": env.smart_cache_ttl(), "maxsize": env.smart_cache_size()},
        )

    def get_info(self):
        """Return the parent's info dict with ``"task type"`` set to ``"tracking"``."""
        info = super().get_info()
        info["task type"] = "tracking"
        return info

    def _on_inference_start(self, inference_request_uuid: str):
        """Run the parent hook and attach a fresh ``Lock`` to the request entry."""
        super()._on_inference_start(inference_request_uuid)
        self._inference_requests[inference_request_uuid]["lock"] = Lock()

    @staticmethod
    def _notify_error_default(
        api: Api, track_id: str, exception: Exception, with_traceback: bool = False
    ):
        """Report ``exception`` through ``api.video.notify_tracking_error``.

        :param with_traceback: Append ``traceback.format_exc()`` to the message.
        """
        error_name = type(exception).__name__
        message = str(exception)
        if with_traceback:
            message = f"{message}\n{traceback.format_exc()}"
        api.video.notify_tracking_error(track_id, error_name, message)

    @staticmethod
    def _notify_error_direct(
        api: Api,
        session_id: str,
        video_id,
        track_id: str,
        exception: Exception,
        with_traceback: bool = False,
    ):
        """Report ``exception`` through ``api.vid_ann_tool.set_direct_tracking_error``.

        The message sent is ``"<exception class name>: <message>"``.

        :param with_traceback: Append ``traceback.format_exc()`` to the message.
        """
        error_name = type(exception).__name__
        message = str(exception)
        if with_traceback:
            message = f"{message}\n{traceback.format_exc()}"
        api.vid_ann_tool.set_direct_tracking_error(
            session_id=session_id,
            video_id=video_id,
            track_id=track_id,
            message=f"{error_name}: {message}",
        )

    def _handle_error_in_async(self, uuid):
        """Return a decorator that records exceptions on inference request ``uuid``.

        When the decorated function raises, ``str(exception)`` is stored under
        ``"exception"`` in the request entry (if the entry still exists), the
        error is logged and the exception is re-raised.
        """

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    inf_request = self._inference_requests.get(uuid, None)
                    if inf_request is not None:
                        inf_request["exception"] = str(e)
                    logger.error(f"Error in {func.__name__} function: {e}", exc_info=True)
                    raise e

            return wrapper

        return decorator

    @staticmethod
    def send_error_data(api, context):
        """Return a decorator that reports exceptions of the wrapped call via the API.

        ``context["trackId"]`` identifies the track. If
        ``ApiField.USE_DIRECT_PROGRESS_MESSAGES`` is a key of ``context`` the
        error is sent with ``_notify_error_direct``, otherwise with
        ``_notify_error_default``. A failure while reporting is only logged;
        the original exception is always re-raised.
        """

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    try:
                        track_id = context["trackId"]
                        if ApiField.USE_DIRECT_PROGRESS_MESSAGES in context:
                            session_id = find_value_by_keys(context, ["sessionId", "session_id"])
                            video_id = find_value_by_keys(context, ["videoId", "video_id"])
                            BaseTracking._notify_error_direct(
                                api=api,
                                session_id=session_id,
                                video_id=video_id,
                                track_id=track_id,
                                exception=exc,
                                with_traceback=False,
                            )
                        else:
                            BaseTracking._notify_error_default(
                                api=api, track_id=track_id, exception=exc, with_traceback=False
                            )
                    except Exception:
                        # Reporting is best-effort; never mask the original error.
                        logger.error("An error occurred while sending error data", exc_info=True)
                    raise exc

            return wrapper

        return decorator

    def schedule_task(self, func, *args, **kwargs):
        """Submit ``func(*args, **kwargs)`` to the executor.

        With a non-``None`` ``inference_request_uuid`` keyword argument the
        request is registered via ``_on_inference_start``, ``func`` is wrapped
        with ``_handle_error_in_async`` and ``_on_inference_end`` is attached as
        a done-callback of the future.
        """
        inference_request_uuid = kwargs.get("inference_request_uuid", None)
        if inference_request_uuid is None:
            self._executor.submit(func, *args, **kwargs)
        else:
            self._on_inference_start(inference_request_uuid)
            fn = self._handle_error_in_async(inference_request_uuid)(func)
            future = self._executor.submit(
                fn,
                *args,
                **kwargs,
            )
            end_callback = functools.partial(
                self._on_inference_end, inference_request_uuid=inference_request_uuid
            )
            future.add_done_callback(end_callback)
        logger.debug("Scheduled task.", extra={"inference_request_uuid": inference_request_uuid})

    def _pop_tracking_results(self, inference_request_uuid: str, frame_range: Tuple = None):
        """Remove pending results from a request and return a JSON-ready snapshot.

        :param inference_request_uuid: Key of the request in ``_inference_requests``.
        :param frame_range: Optional ``(first, last)`` inclusive frame range. When
            given, only figures whose ``frame_index`` lies inside it are popped;
            otherwise all pending results are popped.
        :return: Shallow copy of the request entry without ``"lock"``, with
            ``"progress"`` converted by ``_convert_sly_progress_to_dict`` and
            ``"pending_results"`` converted with ``to_json()``.
        """
        inference_request = self._inference_requests[inference_request_uuid]
        logger.debug(
            "Pop tracking results",
            extra={
                "inference_request_uuid": inference_request_uuid,
                "pending_results_len": len(inference_request["pending_results"]),
                "frame_range": frame_range,
            },
        )
        with inference_request["lock"]:
            inference_request_copy = inference_request.copy()

            if frame_range is not None:

                def _in_range(figure):
                    return (
                        figure.frame_index >= frame_range[0]
                        and figure.frame_index <= frame_range[1]
                    )

                inference_request_copy["pending_results"] = list(
                    filter(_in_range, inference_request_copy["pending_results"])
                )
                # Keep only the figures that were not handed out.
                inference_request["pending_results"] = list(
                    filter(lambda x: not _in_range(x), inference_request["pending_results"])
                )
            else:
                inference_request["pending_results"] = []
            # The lock object must not leak into the returned snapshot.
            inference_request_copy.pop("lock")
            inference_request_copy["progress"] = _convert_sly_progress_to_dict(
                inference_request_copy["progress"]
            )

        inference_request_copy["pending_results"] = [
            figure.to_json() for figure in inference_request_copy["pending_results"]
        ]

        return inference_request_copy

    def _clear_tracking_results(self, inference_request_uuid):
        """Delete the request entry stored under ``inference_request_uuid``."""
        del self._inference_requests[inference_request_uuid]
        logger.debug("Removed an inference request:", extra={"uuid": inference_request_uuid})

    def _stop_tracking(self, inference_request_uuid: str):
        """Set the ``"cancel_inference"`` flag of the request to ``True``."""
        inference_request = self._inference_requests[inference_request_uuid]
        inference_request["cancel_inference"] = True
        logger.debug("Stopped tracking:", extra={"uuid": inference_request_uuid})

    # Implement the following methods in the derived class
    def track(self, api: Api, state: Dict, context: Dict):
        """Handle a ``/track`` request; must be implemented by the derived class."""
        raise NotImplementedError("Method `track` must be implemented.")

    def track_api(self, api: Api, state: Dict, context: Dict):
        """Handle a ``/track-api`` request; must be implemented by the derived class."""
        raise NotImplementedError("Method `track_api` must be implemented.")

    def track_api_files(
        self,
        files: List[BinaryIO],
        settings: Dict,
    ):
        """Handle a ``/track-api-files`` request; must be implemented by the derived class."""
        raise NotImplementedError("Method `track_api_files` must be implemented.")

    def track_async(self, api: Api, state: Dict, context: Dict):
        """Handle a ``/track_async`` request; must be implemented by the derived class."""
        raise NotImplementedError("Method `track_async` must be implemented.")

    def stop_tracking(self, state: Dict, context: Dict):
        """Flag the request named by ``context["inference_request_uuid"]`` as cancelled."""
        validate_key(context, "inference_request_uuid", str)
        inference_request_uuid = context["inference_request_uuid"]
        self._stop_tracking(inference_request_uuid)
        return {"message": "Inference will be stopped.", "success": True}

    def pop_tracking_results(self, state: Dict, context: Dict):
        """Pop pending results of the request named in ``context``.

        An optional frame range is read from the first of the keys
        ``"frameRange"``, ``"frame_range"``, ``"frames"`` via ``find_value_by_keys``.
        """
        validate_key(context, "inference_request_uuid", str)
        inference_request_uuid = context["inference_request_uuid"]
        frame_range = find_value_by_keys(context, ["frameRange", "frame_range", "frames"])
        tracking_results = self._pop_tracking_results(inference_request_uuid, frame_range)
        log_extra = _get_log_extra_for_inference_request(inference_request_uuid, tracking_results)
        logger.debug("Sending inference delta results with uuid:", extra=log_extra)
        return tracking_results

    def clear_tracking_results(self, state: Dict, context: Dict):
        """Delete the request named by ``context["inference_request_uuid"]``."""
        # The request store is keyed by uuid, so the uuid has to be extracted
        # from the context rather than passing the whole context dict along.
        validate_key(context, "inference_request_uuid", str)
        inference_request_uuid = context["inference_request_uuid"]
        self._clear_tracking_results(inference_request_uuid)
        return {"message": "Inference results cleared.", "success": True}

    def _register_endpoints(self):
        """Register the tracking POST endpoints on the app's server.

        Every handler is wrapped with ``handle_validation``; the handlers read
        ``api`` / ``state`` / ``context`` from ``request.state`` and delegate to
        the corresponding method of this class.
        """
        server = self._app.get_server()

        @server.post("/track")
        @handle_validation
        def track_handler(request: Request):
            api = request.state.api
            state = request.state.state
            context = request.state.context
            logger.info("Received track request.", extra={"context": context, "state": state})
            return self.track(api, state, context)

        @server.post("/track-api")
        @handle_validation
        async def track_api_handler(request: Request):
            api = request.state.api
            state = request.state.state
            context = request.state.context
            logger.info("Received track-api request.", extra={"context": context, "state": state})
            result = self.track_api(api, state, context)
            logger.info("Track-api request processed.")
            return result

        @server.post("/track-api-files")
        @handle_validation
        def track_api_files(
            files: List[UploadFile],
            settings: str = Form("{}"),
        ):
            # Unwrap the uploads to their file objects and decode the JSON settings.
            files = [file.file for file in files]
            settings = json.loads(settings)
            return self.track_api_files(files, settings)

        @server.post("/track_async")
        @handle_validation
        def track_async_handler(request: Request):
            api = request.state.api
            state = request.state.state
            context = request.state.context
            logger.info("Received track_async request.", extra={"context": context, "state": state})
            return self.track_async(api, state, context)

        @server.post("/stop_tracking")
        @handle_validation
        def stop_tracking_handler(response: Response, request: Request):
            state = request.state.state
            context = request.state.context
            logger.info(
                "Received stop_tracking request.", extra={"context": context, "state": state}
            )
            return self.stop_tracking(state, context)

        @server.post("/pop_tracking_results")
        @handle_validation
        def pop_tracking_results_handler(request: Request, response: Response):
            state = request.state.state
            context = request.state.context
            logger.info(
                "Received pop_tracking_results request.", extra={"context": context, "state": state}
            )
            return self.pop_tracking_results(state, context)

        @server.post("/clear_tracking_results")
        @handle_validation
        def clear_tracking_results_handler(request: Request, response: Response):
            context = request.state.context
            state = request.state.state
            logger.info(
                "Received clear_tracking_results request.",
                extra={"context": context, "state": state},
            )
            return self.clear_tracking_results(state, context)

    def serve(self):
        """Run the parent's ``serve`` and then register the tracking endpoints."""
        super().serve()
        self._register_endpoints()
|