supervisely 6.73.357__py3-none-any.whl → 6.73.358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/_utils.py +12 -0
- supervisely/api/annotation_api.py +3 -0
- supervisely/api/api.py +2 -2
- supervisely/api/app_api.py +27 -2
- supervisely/api/entity_annotation/tag_api.py +0 -1
- supervisely/api/nn/__init__.py +0 -0
- supervisely/api/nn/deploy_api.py +821 -0
- supervisely/api/nn/neural_network_api.py +248 -0
- supervisely/api/task_api.py +26 -467
- supervisely/app/fastapi/subapp.py +1 -0
- supervisely/nn/__init__.py +2 -1
- supervisely/nn/artifacts/artifacts.py +5 -5
- supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
- supervisely/nn/experiments.py +28 -5
- supervisely/nn/inference/cache.py +178 -114
- supervisely/nn/inference/gui/gui.py +18 -35
- supervisely/nn/inference/gui/serving_gui.py +3 -1
- supervisely/nn/inference/inference.py +1421 -1265
- supervisely/nn/inference/inference_request.py +412 -0
- supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
- supervisely/nn/inference/session.py +2 -2
- supervisely/nn/inference/tracking/base_tracking.py +45 -79
- supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
- supervisely/nn/inference/tracking/mask_tracking.py +274 -250
- supervisely/nn/inference/tracking/tracker_interface.py +23 -0
- supervisely/nn/inference/uploader.py +164 -0
- supervisely/nn/model/__init__.py +0 -0
- supervisely/nn/model/model_api.py +259 -0
- supervisely/nn/model/prediction.py +311 -0
- supervisely/nn/model/prediction_session.py +632 -0
- supervisely/nn/tracking/__init__.py +1 -0
- supervisely/nn/tracking/boxmot.py +114 -0
- supervisely/nn/tracking/tracking.py +24 -0
- supervisely/nn/training/train_app.py +61 -19
- supervisely/nn/utils.py +43 -3
- supervisely/task/progress.py +12 -2
- supervisely/video/video.py +107 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/METADATA +2 -1
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/RECORD +43 -32
- supervisely/api/neural_network_api.py +0 -202
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/LICENSE +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/WHEEL +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,632 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
from os import PathLike
|
|
4
|
+
from typing import Any, Dict, Iterator, List, Literal, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import requests
|
|
8
|
+
from requests import Timeout
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
|
|
11
|
+
import supervisely.io.env as env
|
|
12
|
+
from supervisely._utils import get_valid_kwargs, logger
|
|
13
|
+
from supervisely.api.api import Api
|
|
14
|
+
from supervisely.imaging._video import ALLOWED_VIDEO_EXTENSIONS
|
|
15
|
+
from supervisely.imaging.image import SUPPORTED_IMG_EXTS, write_bytes
|
|
16
|
+
from supervisely.io.fs import (
|
|
17
|
+
dir_exists,
|
|
18
|
+
file_exists,
|
|
19
|
+
get_file_ext,
|
|
20
|
+
list_files,
|
|
21
|
+
list_files_recursively,
|
|
22
|
+
)
|
|
23
|
+
from supervisely.io.network_exceptions import process_requests_exception
|
|
24
|
+
from supervisely.nn.model.prediction import Prediction
|
|
25
|
+
from supervisely.project.project import Dataset, OpenMode, Project
|
|
26
|
+
from supervisely.project.project_meta import ProjectMeta
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def value_generator(value):
    """Endlessly yield the very same ``value`` object.

    Each ``next()`` call returns ``value`` itself (no copy is made); the
    generator never finishes on its own.
    """
    repeated = value
    while True:
        yield repeated
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class PredictionSession:
    """Client for one asynchronous inference request on a model server.

    The constructor accepts exactly one input source, starts the matching
    ``*_async`` endpoint of the server at ``url`` and blocks until the server
    reports that inference has begun. Iterating the session yields
    ``Prediction`` objects built from the results the server hands out batch by
    batch. When used as a context manager, :meth:`stop` is called on exit.
    """

    # Becomes True once "clear_inference_request" has been sent for the current
    # request; stop() and _on_inference_end() then skip further server calls.
    _request_cleared = False

    class Iterator:
        """Iterator over the raw result dicts of the session's running request.

        Results are requested in batches through
        ``session._wait_for_pending_results`` and returned one at a time. Once a
        fetch returns no results the iterator is exhausted and keeps raising
        ``StopIteration`` without polling the server again.
        """

        def __init__(self, total, session: "PredictionSession", tqdm: "tqdm" = None):
            self.total = total  # item count reported by the server when inference started
            self.session = session
            self.results_queue = []  # fetched results not yet handed to the caller
            self.tqdm = tqdm
            self._exhausted = False

        def __len__(self) -> int:
            return self.total

        def __iter__(self) -> "PredictionSession.Iterator":
            return self

        def __next__(self) -> Dict[str, Any]:
            if self._exhausted:
                raise StopIteration
            if not self.results_queue:
                self.results_queue += self.session._wait_for_pending_results(tqdm=self.tqdm)
            if not self.results_queue:
                # Stay exhausted: PredictionSession.__next__ clears the request on the
                # server when this StopIteration reaches it, so polling again would
                # address a request that has already been cleared.
                self._exhausted = True
                raise StopIteration
            return self.results_queue.pop(0)

    def __init__(
        self,
        url: str,
        input: Union[np.ndarray, str, PathLike, list] = None,
        image_id: Union[List[int], int] = None,
        video_id: Union[List[int], int] = None,
        dataset_id: Union[List[int], int] = None,
        project_id: Union[List[int], int] = None,
        api: "Api" = None,
        **kwargs: dict,
    ):
        """Validate the input, start inference on the server and wait for it to begin.

        :param url: Base URL of the model server; endpoint names are appended to it.
        :param input: numpy image(s), a path to an image or video file, a directory
            (a Supervisely project directory or a folder of images), or a list of
            image paths / numpy arrays.
        :param image_id: Supervisely image id(s); the ``image_ids`` kwarg is an alias.
        :param video_id: Supervisely video id; only one is supported (alias ``video_ids``).
        :param dataset_id: Dataset id(s) from a single project; requires ``api``
            (alias ``dataset_ids``).
        :param project_id: Project id; only one is supported (alias ``project_ids``).
        :param api: Optional ``Api``; provides the API token, the retry count and
            dataset lookups.
        :param kwargs: Extra options. Scalar values are sent to the server as
            inference settings; ``show_progress`` / ``tqdm`` control the progress bar;
            ``recursive`` controls directory listing; ``stride``, ``start_frame`` and
            ``num_frames`` are mirrored as ``step``, ``start_frame_index`` and
            ``frames_count``. The options are filtered by ``get_valid_kwargs`` before
            being passed to the ``_predict_*`` helpers, and all of them are passed on
            to ``Prediction.from_json``.
        :raises AssertionError: if not exactly one input source is given.
        :raises ValueError: if the input is empty, missing or of an unsupported type.
        """
        extra_input_args = ["image_ids", "video_ids", "dataset_ids", "project_ids"]
        candidates = [input, image_id, video_id, dataset_id, project_id]
        candidates += [kwargs.get(name, None) for name in extra_input_args]
        assert (
            sum(x is not None for x in candidates) == 1
        ), "Exactly one of input, image_ids, video_id, dataset_id, project_id or image_id must be provided."
        if isinstance(input, list) and len(input) == 0:
            raise ValueError("Input list is empty.")

        self._iterator = None
        self._base_url = url
        self.inference_request_uuid = None
        self.input = input
        self.api = api

        self.api_token = self._get_api_token()
        self._model_meta = None
        self.final_result = None

        # Mirror the public video argument names under the alternative keys; these
        # copies end up in the inference settings and in the Prediction kwargs.
        if "stride" in kwargs:
            kwargs["step"] = kwargs["stride"]
        if "start_frame" in kwargs:
            kwargs["start_frame_index"] = kwargs["start_frame"]
        if "num_frames" in kwargs:
            kwargs["frames_count"] = kwargs["num_frames"]
        # NOTE: self.kwargs and kwargs are the same dict object, so the tqdm pop and
        # the "source" key added below are visible through both names.
        self.kwargs = kwargs
        if kwargs.get("show_progress", False) and "tqdm" not in kwargs:
            kwargs["tqdm"] = tqdm()
        self.tqdm = kwargs.pop("tqdm", None)

        # Only scalar options are sent to the server as inference settings.
        self.inference_settings = {
            k: v for k, v in kwargs.items() if isinstance(v, (str, int, float))
        }

        # extra input args: accept both the singular parameters and the *_ids kwargs
        image_ids = self._set_var_from_kwargs("image_ids", kwargs, image_id)
        video_ids = self._set_var_from_kwargs("video_ids", kwargs, video_id)
        dataset_ids = self._set_var_from_kwargs("dataset_ids", kwargs, dataset_id)
        project_ids = self._set_var_from_kwargs("project_ids", kwargs, project_id)
        source = next(
            x
            for x in [
                input,
                image_id,
                video_id,
                dataset_id,
                project_id,
                image_ids,
                video_ids,
                dataset_ids,
                project_ids,
            ]
            if x is not None
        )
        self.kwargs["source"] = source
        self.prediction_kwargs_iterator = value_generator({})

        if not isinstance(input, list):
            input = [input]
        if isinstance(input[0], np.ndarray):
            kwargs = get_valid_kwargs(kwargs, self._predict_images, exclude=["images"])
            # Keep the iterator; previously it was discarded for numpy input.
            self._iterator = self._predict_images(input, **kwargs)
        elif isinstance(input[0], (str, PathLike)):
            self._iterator = self._predict_local_paths(input, kwargs)
        elif image_ids is not None:
            self._iterator = self._predict_images(image_ids, **kwargs)
        elif video_ids is not None:
            if len(video_ids) > 1:
                raise ValueError("Only one video id can be provided.")
            kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
            self._iterator = self._predict_videos(video_ids, **kwargs)
        elif dataset_ids is not None:
            kwargs = get_valid_kwargs(
                kwargs,
                self._predict_datasets,
                exclude=["dataset_ids"],
            )
            self._iterator = self._predict_datasets(dataset_ids, **kwargs)
        elif project_ids is not None:
            if len(project_ids) > 1:
                raise ValueError("Only one project id can be provided.")
            kwargs = get_valid_kwargs(
                kwargs,
                self._predict_projects,
                exclude=["project_ids"],
            )
            self._iterator = self._predict_projects(project_ids, **kwargs)
        else:
            raise ValueError(
                "Unknown input type. Supported types are: numpy array, path to a file or directory, ImageInfo, VideoInfo, ProjectInfo, DatasetInfo."
            )

    def _predict_local_paths(self, paths: list, kwargs: dict):
        """Start inference for local filesystem input and return the result iterator.

        Several paths are treated as image files. A single path may be a project
        directory, a plain directory of images, an image file or a video file.
        """
        if len(paths) > 1:
            # if the input is a list of paths, assume they are images
            for x in paths:
                if not isinstance(x, (str, PathLike)):
                    raise ValueError("Input must be a list of strings or PathLike objects.")
            # Route through _predict_images so the files are read from disk (not sent
            # as raw path strings) and kwargs are filtered for the path handler.
            return self._predict_images(paths, **kwargs)
        path = paths[0]
        if dir_exists(path):
            image_paths = self._collect_dir_images(path, kwargs.get("recursive", False))
            return self._predict_images(image_paths, **kwargs)
        if file_exists(path):
            ext = get_file_ext(path)
            if ext == "":
                raise ValueError("File has no extension.")
            if ext in SUPPORTED_IMG_EXTS:
                return self._predict_images(paths, **kwargs)
            if ext in ALLOWED_VIDEO_EXTENSIONS:
                video_kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
                return self._predict_videos(paths, **video_kwargs)
            raise ValueError(
                f"Unsupported file extension: {ext}. Supported extensions are: {SUPPORTED_IMG_EXTS + ALLOWED_VIDEO_EXTENSIONS}"
            )
        raise ValueError(f"File or directory does not exist: {path}")

    @staticmethod
    def _collect_dir_images(directory, recursive: bool = False) -> list:
        """Return the image paths of ``directory``; raise ValueError when there are none.

        If the directory opens as a Supervisely ``Project``, the items of all its
        datasets are used; otherwise image files are listed (recursively if asked).
        """
        try:
            project = Project(str(directory), mode=OpenMode.READ)
        except Exception:
            project = None  # not a readable project: fall back to a plain file listing
        if project is not None:
            image_paths = [
                image_path
                for dataset in project.datasets
                for _, image_path, _ in dataset.items()
            ]
        elif recursive:
            image_paths = list_files_recursively(directory, valid_extensions=SUPPORTED_IMG_EXTS)
        else:
            image_paths = list_files(directory, valid_extensions=SUPPORTED_IMG_EXTS)
        if len(image_paths) == 0:
            raise ValueError("Directory is empty.")
        return image_paths

    def _set_var_from_kwargs(self, key, kwargs, default):
        """Return ``kwargs[key]`` (or ``default``) wrapped in a list; None stays None."""
        value = kwargs.get(key, default)
        if value is None:
            return None
        return value if isinstance(value, list) else [value]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Stop any still-active request; never suppress an exception from the block.
        self.stop()
        return False

    def __next__(self):
        """Return the next ``Prediction``; clears the request once results run out."""
        try:
            prediction_json = next(self._iterator)
            this_kwargs = next(self.prediction_kwargs_iterator)
            return Prediction.from_json(
                prediction_json, **self.kwargs, **this_kwargs, model_meta=self.model_meta
            )
        except StopIteration:
            self._on_inference_end()
            raise
        except Exception:
            self.stop()
            raise

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._iterator)

    def _get_api_token(self):
        """Use the token of ``self.api`` if given, else look it up via ``env.api_token``."""
        if self.api is not None:
            return self.api.token
        return env.api_token(raise_not_found=False)

    def _get_json_body(self):
        """Build the JSON body shared by all endpoints (request uuid, settings, token)."""
        body = {"state": {}, "context": {}}
        if self.inference_request_uuid is not None:
            body["state"]["inference_request_uuid"] = self.inference_request_uuid
        if self.inference_settings:
            body["state"]["settings"] = self.inference_settings
        if self.api_token is not None:
            body["api_token"] = self.api_token
        return body

    @staticmethod
    def _rewind_uploads(files) -> None:
        """Seek every file-like object in a requests ``files`` spec back to position 0.

        Entries may be bare file objects or ``(filename, fileobj, ...)`` tuples;
        values without ``seek`` (bytes, str) are left alone.
        """
        if not files:
            return
        values = files.values() if isinstance(files, dict) else (value for _, value in files)
        for value in values:
            fileobj = value[1] if isinstance(value, tuple) and len(value) > 1 else value
            seek = getattr(fileobj, "seek", None)
            if callable(seek):
                seek(0)

    def _post(self, method, *args, retries=5, **kwargs) -> "requests.Response":
        """POST to ``<base url>/<method>`` and return the response, retrying on failure.

        :param method: endpoint name relative to the session URL.
        :param retries: maximum number of attempts (capped by ``api.retry_count``
            when an ``Api`` is set; at least one attempt is always made).
        :param kwargs: passed to ``requests.post``; ``timeout`` defaults to 60 s.
        :raises requests.RequestException: when the last attempt fails.
        """
        # Copy so a caller-supplied headers dict is not mutated.
        kwargs["headers"] = dict(kwargs.get("headers") or {})
        if self.api is not None:
            retries = min(self.api.retry_count, retries)
            if "x-api-key" not in kwargs["headers"]:
                kwargs["headers"]["x-api-key"] = self.api.token
        # With 0 retries the loop would never run and None would be returned.
        retries = max(retries, 1)
        url = self._base_url.rstrip("/") + "/" + method.lstrip("/")
        kwargs.setdefault("timeout", 60)
        for retry_idx in range(retries):
            response = None
            # A failed attempt may have consumed file uploads; rewind them so a retry
            # does not send empty files.
            self._rewind_uploads(kwargs.get("files"))
            try:
                logger.trace(f"POST {url}")
                response = requests.post(url, *args, **kwargs)
                if response.status_code != requests.codes.ok:  # pylint: disable=no-member
                    Api._raise_for_status(response)
                return response
            except requests.RequestException as exc:
                process_requests_exception(
                    logger,
                    exc,
                    method,
                    url,
                    verbose=True,
                    swallow_exc=True,
                    sleep_sec=5,
                    response=response,
                    retry_info={"retry_idx": retry_idx + 1, "retry_limit": retries},
                )
                if retry_idx + 1 == retries:
                    raise

    def _get_inference_progress(self):
        r = self._post("get_inference_progress", json=self._get_json_body())
        return r.json()

    def _get_inference_status(self):
        r = self._post("get_inference_status", json=self._get_json_body())
        return r.json()

    def _stop_async_inference(self):
        r = self._post("stop_inference", json=self._get_json_body())
        logger.info("Inference will be stopped on the server")
        return r.json()

    def _clear_inference_request(self):
        r = self._post("clear_inference_request", json=self._get_json_body())
        logger.info("Inference request will be cleared on the server")
        return r.json()

    def _on_inference_end(self):
        """Clear the request on the server exactly once after it has finished."""
        if self.inference_request_uuid is None or self._request_cleared:
            return
        self._clear_inference_request()
        self._request_cleared = True

    def _on_infernce_end(self):
        # Backward-compatible alias for the original (misspelled) method name.
        return self._on_inference_end()

    @property
    def model_meta(self) -> "ProjectMeta":
        """Model ``ProjectMeta`` fetched once from the ``get_model_meta`` endpoint."""
        if self._model_meta is None:
            self._model_meta = ProjectMeta.from_json(
                self._post("get_model_meta", json=self._get_json_body()).json()
            )
        return self._model_meta

    def stop(self):
        """Stop and clear the active request; no-op if there is none or it is cleared."""
        if self.inference_request_uuid is None or self._request_cleared:
            logger.debug("No active inference request to stop.")
            return
        self._stop_async_inference()
        self._on_inference_end()

    def _require_started(self):
        """Raise RuntimeError if no inference request has been started."""
        if self.inference_request_uuid is None:
            raise RuntimeError(
                "Inference is not started. Please start inference before checking the status."
            )

    def is_done(self):
        self._require_started()
        return not self._get_inference_progress()["is_inferring"]

    def progress(self):
        self._require_started()
        return self._get_inference_progress()["progress"]

    def status(self):
        self._require_started()
        return self._get_inference_status()

    def _pop_pending_results(self) -> Dict[str, Any]:
        return self._post("pop_inference_results", json=self._get_json_body()).json()

    def _update_progress(self, tqdm: "tqdm", response: Dict[str, Any]):
        """Mirror the server progress in ``response`` onto a tqdm bar (if any).

        Uses ``progress`` unless it has no message, in which case
        ``preparing_progress`` is used. The bar is refreshed only if something changed.
        """
        if tqdm is None:
            return
        json_progress = response.get("progress", None)
        if json_progress is None or json_progress.get("message") is None:
            json_progress = response.get("preparing_progress", None)
        if json_progress is None:
            return
        refresh = False
        message = json_progress.get("message", json_progress.get("status", None))
        # tqdm.set_description stores the text as "<message>: ", so compare with that
        # form too; otherwise the description is rewritten on every poll.
        if message is not None and tqdm.desc not in (message, f"{message}:", f"{message}: "):
            tqdm.set_description(message, refresh=False)
            refresh = True
        current = json_progress.get("current", None)
        if current is not None and tqdm.n != current:
            tqdm.n = current
            refresh = True
        total = json_progress.get("total", None)
        if total is not None and tqdm.total != total:
            tqdm.total = total
            refresh = True
        is_size = json_progress.get("is_size", False)
        if is_size and tqdm.unit == "it":
            tqdm.unit = "iB"
            tqdm.unit_scale = True
            tqdm.unit_divisor = 1024
            refresh = True
        if not is_size and tqdm.unit == "iB":
            tqdm.unit = "it"
            tqdm.unit_scale = False
            tqdm.unit_divisor = 1
            refresh = True
        if refresh:
            tqdm.refresh()

    def _wait_for_inference_start(
        self, delay=1, timeout=None, tqdm: "tqdm" = None
    ) -> Tuple[dict, bool]:
        """Poll the progress endpoint until the server has left its preparation stage.

        :returns: the last progress response and ``True``.
        :raises Timeout: if ``timeout`` seconds pass first (the request is stopped).
        """
        t0 = time.monotonic()  # monotonic: immune to wall-clock adjustments
        last_stage = None
        while True:
            resp = self._get_inference_progress()
            stage = resp.get("stage")
            if stage != last_stage:
                logger.info(stage)
                last_stage = stage
            has_started = (
                stage not in ["preparing", "preprocess", None]
                or bool(resp.get("result"))
                or resp["progress"]["total"] != 1
            )
            self._update_progress(tqdm, resp)
            if has_started:
                return resp, has_started
            time.sleep(delay)
            if timeout and time.monotonic() - t0 > timeout:
                self.stop()
                raise Timeout("Timeout exceeded. The server didn't start the inference")

    def _wait_for_pending_results(self, delay=1, timeout=600, tqdm: "tqdm" = None) -> List[dict]:
        """Poll until the server returns results or reports that it has finished.

        :returns: the popped results (empty once inference is finished).
        :raises RuntimeError: if the server reports an inference exception.
        :raises Timeout: if ``timeout`` seconds pass first (the request is stopped).
        """
        logger.debug("waiting pending results...")
        t0 = time.monotonic()
        while True:
            resp = self._pop_pending_results()
            self._update_progress(tqdm, resp)
            pending_results = resp["pending_results"]
            exception_json = resp["exception"]
            if exception_json:
                exception_str = f"{exception_json['type']}: {exception_json['message']}"
                raise RuntimeError(f"Inference Error: {exception_str}")
            if pending_results or resp.get("finished", False):
                return pending_results
            time.sleep(delay)
            if timeout and time.monotonic() - t0 > timeout:
                self.stop()
                raise Timeout("Timeout exceeded. Pending results not received from the server.")

    def _start_inference(self, method, **kwargs):
        """POST to an ``*_async`` endpoint, wait for the start and return a result iterator.

        :raises RuntimeError: if this session already has a request.
        """
        if self.inference_request_uuid:
            raise RuntimeError(
                "Inference is already running. Please stop it before starting a new one."
            )
        resp = self._post(method, **kwargs).json()

        self.inference_request_uuid = resp["inference_request_uuid"]
        self._request_cleared = False

        logger.info(
            "Inference has started:",
            extra={"inference_request_uuid": resp.get("inference_request_uuid")},
        )
        try:
            resp, _ = self._wait_for_inference_start(tqdm=self.tqdm)
        except BaseException:
            # Also on KeyboardInterrupt: do not leave the request running server-side.
            self.stop()
            raise
        return self.Iterator(resp["progress"]["total"], self, tqdm=self.tqdm)

    def _predict_images(self, images: List, **kwargs: dict):
        """Dispatch on the type of ``images[0]``: bytes, path, numpy array or image id.

        :raises ValueError: if ``images`` is empty or of an unsupported type.
        """
        if not images:
            raise ValueError("No images provided.")
        if isinstance(images[0], bytes):
            f = self._predict_images_bytes
        elif isinstance(images[0], (str, PathLike)):
            f = self._predict_images_paths
        elif isinstance(images[0], np.ndarray):
            f = self._predict_images_nps
        elif isinstance(images[0], int):
            f = self._predict_images_ids
        else:
            raise ValueError(f"Unsupported input type '{type(images[0])}'.")
        kwargs = get_valid_kwargs(kwargs, f, exclude=["images"])
        return f(images, **kwargs)

    def _predict_images_bytes(self, images: List[bytes], batch_size: int = None):
        """Upload in-memory encoded images (sent as PNG parts) and start batch inference."""
        files = [
            ("files", (f"image_{i}.png", image, "image/png")) for i, image in enumerate(images)
        ]
        state = self._get_json_body()["state"]
        if batch_size is not None:
            state["batch_size"] = batch_size
        uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
        return self._start_inference("inference_batch_async", files=uploads)

    def _predict_images_paths(self, images: List, batch_size: int = None):
        """Upload local image files and start batch inference on them."""
        files = []
        try:
            # Open one at a time so handles opened before a failing open() are
            # already tracked and get closed in ``finally``.
            for path in images:
                files.append(("files", open(path, "rb")))
            state = self._get_json_body()["state"]
            if batch_size is not None:
                state["batch_size"] = batch_size
            uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
            return self._start_inference("inference_batch_async", files=uploads)
        finally:
            for _, handle in files:
                handle.close()

    def _predict_images_nps(self, images: List, batch_size: int = None):
        """Encode numpy images as PNG bytes and start batch inference on them."""
        images = [write_bytes(image, ".png") for image in images]
        return self._predict_images_bytes(images, batch_size=batch_size)

    def _predict_images_ids(
        self, images: List[int], batch_size: int = None, upload_mode: str = None
    ):
        """Start inference on Supervisely images referenced by id."""
        json_body = self._get_json_body()
        state = json_body["state"]
        state["images_ids"] = images
        if batch_size is not None:
            state["batch_size"] = batch_size
        if upload_mode is not None:
            state["upload_mode"] = upload_mode
        return self._start_inference("inference_batch_ids_async", json=json_body)

    def _predict_videos(
        self,
        videos: Union[List[int], List[str], List[PathLike]],
        start_frame: int = None,
        num_frames: int = None,
        stride=None,
        end_frame=None,
        duration=None,
        direction: Literal["forward", "backward"] = None,
        tracker: Literal["bot", "deepsort"] = None,
        batch_size: int = None,
    ):
        """Start inference on one video given by Supervisely id or local path.

        Optional arguments that are not None are copied into the request state.
        """
        if len(videos) != 1:
            raise ValueError("Only one video can be processed at a time.")
        json_body = self._get_json_body()
        state = json_body["state"]
        for key, value in (
            ("start_frame", start_frame),
            ("num_frames", num_frames),
            ("stride", stride),
            ("end_frame", end_frame),
            ("duration", duration),
            ("direction", direction),
            ("tracker", tracker),
            ("batch_size", batch_size),
        ):
            if value is not None:
                state[key] = value
        if isinstance(videos[0], int):
            state["video_id"] = videos[0]
            return self._start_inference("inference_video_id_async", json=json_body)
        elif isinstance(videos[0], (str, PathLike)):
            files = []
            try:
                files = [("files", open(videos[0], "rb"))]
                uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
                return self._start_inference("inference_video_async", files=uploads)
            finally:
                for _, handle in files:
                    handle.close()
        else:
            raise ValueError(
                f"Unsupported input type '{type(videos[0])}'. Supported types are: int, str, PathLike."
            )

    def _predict_projects(
        self,
        project_ids: List[int],
        dataset_ids: List[int] = None,
        batch_size: int = None,
        upload_mode: str = None,
        iou_merge_threshold: float = None,
        cache_project_on_model: bool = None,
    ):
        """Start inference on one Supervisely project (optionally limited to datasets)."""
        if len(project_ids) != 1:
            raise ValueError("Only one project can be processed at a time.")
        json_body = self._get_json_body()
        state = json_body["state"]
        state["project_id"] = project_ids[0]
        for key, value in (
            ("dataset_ids", dataset_ids),
            ("batch_size", batch_size),
            ("upload_mode", upload_mode),
            ("iou_merge_threshold", iou_merge_threshold),
            ("cache_project_on_model", cache_project_on_model),
        ):
            if value is not None:
                state[key] = value
        return self._start_inference("inference_project_id_async", json=json_body)

    def _predict_datasets(
        self,
        dataset_ids: List[int],
        batch_size: int = None,
        upload_mode: str = None,
        iou_merge_threshold: float = None,
        cache_datasets_on_model: bool = None,
    ):
        """Start inference on datasets of one project (delegates to _predict_projects).

        :raises ValueError: if no ``api`` is set, no ids are given, a dataset cannot be
            found, or the datasets belong to different projects.
        """
        if self.api is None:
            raise ValueError("Api is required to use this method.")
        if not dataset_ids:
            raise ValueError("At least one dataset id must be provided.")
        dataset_infos = []
        for dataset_id in dataset_ids:
            info = self.api.dataset.get_info_by_id(dataset_id)
            # NOTE(review): the lookup is assumed to return None for unknown ids.
            if info is None:
                raise ValueError(f"Dataset with id {dataset_id} not found.")
            dataset_infos.append(info)
        if len({info.project_id for info in dataset_infos}) > 1:
            raise ValueError("All datasets must belong to the same project.")
        return self._predict_projects(
            [dataset_infos[0].project_id],
            dataset_ids=dataset_ids,
            batch_size=batch_size,
            upload_mode=upload_mode,
            iou_merge_threshold=iou_merge_threshold,
            cache_project_on_model=cache_datasets_on_model,
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from supervisely.nn.tracking.tracking import track
|