supervisely 6.73.356-py3-none-any.whl → 6.73.358-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. supervisely/_utils.py +12 -0
  2. supervisely/api/annotation_api.py +3 -0
  3. supervisely/api/api.py +2 -2
  4. supervisely/api/app_api.py +27 -2
  5. supervisely/api/entity_annotation/tag_api.py +0 -1
  6. supervisely/api/labeling_job_api.py +4 -1
  7. supervisely/api/nn/__init__.py +0 -0
  8. supervisely/api/nn/deploy_api.py +821 -0
  9. supervisely/api/nn/neural_network_api.py +248 -0
  10. supervisely/api/task_api.py +26 -467
  11. supervisely/app/fastapi/subapp.py +1 -0
  12. supervisely/nn/__init__.py +2 -1
  13. supervisely/nn/artifacts/artifacts.py +5 -5
  14. supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
  15. supervisely/nn/experiments.py +28 -5
  16. supervisely/nn/inference/cache.py +178 -114
  17. supervisely/nn/inference/gui/gui.py +18 -35
  18. supervisely/nn/inference/gui/serving_gui.py +3 -1
  19. supervisely/nn/inference/inference.py +1421 -1265
  20. supervisely/nn/inference/inference_request.py +412 -0
  21. supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
  22. supervisely/nn/inference/session.py +2 -2
  23. supervisely/nn/inference/tracking/base_tracking.py +45 -79
  24. supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
  25. supervisely/nn/inference/tracking/mask_tracking.py +274 -250
  26. supervisely/nn/inference/tracking/tracker_interface.py +23 -0
  27. supervisely/nn/inference/uploader.py +164 -0
  28. supervisely/nn/model/__init__.py +0 -0
  29. supervisely/nn/model/model_api.py +259 -0
  30. supervisely/nn/model/prediction.py +311 -0
  31. supervisely/nn/model/prediction_session.py +632 -0
  32. supervisely/nn/tracking/__init__.py +1 -0
  33. supervisely/nn/tracking/boxmot.py +114 -0
  34. supervisely/nn/tracking/tracking.py +24 -0
  35. supervisely/nn/training/train_app.py +61 -19
  36. supervisely/nn/utils.py +43 -3
  37. supervisely/task/progress.py +12 -2
  38. supervisely/video/video.py +107 -1
  39. supervisely/volume_annotation/volume_figure.py +8 -2
  40. {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/METADATA +2 -1
  41. {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/RECORD +45 -34
  42. supervisely/api/neural_network_api.py +0 -202
  43. {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/LICENSE +0 -0
  44. {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/WHEEL +0 -0
  45. {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/entry_points.txt +0 -0
  46. {supervisely-6.73.356.dist-info → supervisely-6.73.358.dist-info}/top_level.txt +0 -0
supervisely/nn/model/prediction_session.py (new file)
@@ -0,0 +1,632 @@
+ import json
+ import time
+ from os import PathLike
+ from typing import Any, Dict, Iterator, List, Literal, Tuple, Union
+
+ import numpy as np
+ import requests
+ from requests import Timeout
+ from tqdm.auto import tqdm
+
+ import supervisely.io.env as env
+ from supervisely._utils import get_valid_kwargs, logger
+ from supervisely.api.api import Api
+ from supervisely.imaging._video import ALLOWED_VIDEO_EXTENSIONS
+ from supervisely.imaging.image import SUPPORTED_IMG_EXTS, write_bytes
+ from supervisely.io.fs import (
+     dir_exists,
+     file_exists,
+     get_file_ext,
+     list_files,
+     list_files_recursively,
+ )
+ from supervisely.io.network_exceptions import process_requests_exception
+ from supervisely.nn.model.prediction import Prediction
+ from supervisely.project.project import Dataset, OpenMode, Project
+ from supervisely.project.project_meta import ProjectMeta
+
+
+ def value_generator(value):
+     while True:
+         yield value
+
+
+ class PredictionSession:
+
+     class Iterator:
+         def __init__(self, total, session: "PredictionSession", tqdm: tqdm = None):
+             self.total = total
+             self.session = session
+             self.results_queue = []
+             self.tqdm = tqdm
+
+         def __len__(self) -> int:
+             return self.total
+
+         def __iter__(self) -> Iterator:
+             return self
+
+         def __next__(self) -> Dict[str, Any]:
+             if not self.results_queue:
+                 pending_results = self.session._wait_for_pending_results(tqdm=self.tqdm)
+                 self.results_queue += pending_results
+             if not self.results_queue:
+                 raise StopIteration
+             pred = self.results_queue.pop(0)
+             return pred
+
+     def __init__(
+         self,
+         url: str,
+         input: Union[np.ndarray, str, PathLike, list] = None,
+         image_id: Union[List[int], int] = None,
+         video_id: Union[List[int], int] = None,
+         dataset_id: Union[List[int], int] = None,
+         project_id: Union[List[int], int] = None,
+         api: "Api" = None,
+         **kwargs: dict,
+     ):
+         extra_input_args = ["image_ids", "video_ids", "dataset_ids", "project_ids"]
+         assert (
+             sum(
+                 [
+                     x is not None
+                     for x in [
+                         input,
+                         image_id,
+                         video_id,
+                         dataset_id,
+                         project_id,
+                         *[kwargs.get(extra_input, None) for extra_input in extra_input_args],
+                     ]
+                 ]
+             )
+             == 1
+         ), "Exactly one of input, image_id, video_id, dataset_id, or project_id must be provided."
+
+         self._iterator = None
+         self._base_url = url
+         self.inference_request_uuid = None
+         self.input = input
+         self.api = api
+
+         self.api_token = self._get_api_token()
+         self._model_meta = None
+         self.final_result = None
+
+         if "stride" in kwargs:
+             kwargs["step"] = kwargs["stride"]
+         if "start_frame" in kwargs:
+             kwargs["start_frame_index"] = kwargs["start_frame"]
+         if "num_frames" in kwargs:
+             kwargs["frames_count"] = kwargs["num_frames"]
+         self.kwargs = kwargs
+         if kwargs.get("show_progress", False) and "tqdm" not in kwargs:
+             kwargs["tqdm"] = tqdm()
+         self.tqdm = kwargs.pop("tqdm", None)
+
+         self.inference_settings = {
+             k: v for k, v in kwargs.items() if isinstance(v, (str, int, float))
+         }
+
+         # extra input args
+         image_ids = self._set_var_from_kwargs("image_ids", kwargs, image_id)
+         video_ids = self._set_var_from_kwargs("video_ids", kwargs, video_id)
+         dataset_ids = self._set_var_from_kwargs("dataset_ids", kwargs, dataset_id)
+         project_ids = self._set_var_from_kwargs("project_ids", kwargs, project_id)
+         source = next(
+             x
+             for x in [
+                 input,
+                 image_id,
+                 video_id,
+                 dataset_id,
+                 project_id,
+                 image_ids,
+                 video_ids,
+                 dataset_ids,
+                 project_ids,
+             ]
+             if x is not None
+         )
+         self.kwargs["source"] = source
+         self.prediction_kwargs_iterator = value_generator({})
+
+         if not isinstance(input, list):
+             input = [input]
+         if isinstance(input[0], np.ndarray):
+             # input is numpy array
+             kwargs = get_valid_kwargs(kwargs, self._predict_images, exclude=["images"])
+             self._iterator = self._predict_images(input, **kwargs)
+         elif isinstance(input[0], (str, PathLike)):
+             if len(input) > 1:
+                 # if the input is a list of paths, assume they are images
+                 for x in input:
+                     if not isinstance(x, (str, PathLike)):
+                         raise ValueError("Input must be a list of strings or PathLike objects.")
+                 self._iterator = self._predict_images(input, **kwargs)
+             else:
+                 if dir_exists(input[0]):
+                     try:
+                         project = Project(str(input[0]), mode=OpenMode.READ)
+                     except Exception:
+                         project = None
+                     image_paths = []
+                     if project is not None:
+                         for dataset in project.datasets:
+                             dataset: Dataset
+                             for _, image_path, _ in dataset.items():
+                                 image_paths.append(image_path)
+                     else:
+                         # if the input is a directory, assume it contains images
+                         recursive = kwargs.get("recursive", False)
+                         if recursive:
+                             image_paths = list_files_recursively(
+                                 input[0], valid_extensions=SUPPORTED_IMG_EXTS
+                             )
+                         else:
+                             image_paths = list_files(input[0], valid_extensions=SUPPORTED_IMG_EXTS)
+                     if len(image_paths) == 0:
+                         raise ValueError("Directory is empty.")
+                     self._iterator = self._predict_images(image_paths, **kwargs)
+                 elif file_exists(input[0]):
+                     ext = get_file_ext(input[0])
+                     if ext == "":
+                         raise ValueError("File has no extension.")
+                     if ext in SUPPORTED_IMG_EXTS:
+                         self._iterator = self._predict_images(input, **kwargs)
+                     elif ext in ALLOWED_VIDEO_EXTENSIONS:
+                         kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
+                         self._iterator = self._predict_videos(input, **kwargs)
+                     else:
+                         raise ValueError(
+                             f"Unsupported file extension: {ext}. Supported extensions are: {SUPPORTED_IMG_EXTS + ALLOWED_VIDEO_EXTENSIONS}"
+                         )
+                 else:
+                     raise ValueError(f"File or directory does not exist: {input[0]}")
+         elif image_ids is not None:
+             self._iterator = self._predict_images(image_ids, **kwargs)
+         elif video_ids is not None:
+             if len(video_ids) > 1:
+                 raise ValueError("Only one video id can be provided.")
+             kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
+             self._iterator = self._predict_videos(video_ids, **kwargs)
+         elif dataset_ids is not None:
+             kwargs = get_valid_kwargs(
+                 kwargs,
+                 self._predict_datasets,
+                 exclude=["dataset_ids"],
+             )
+             self._iterator = self._predict_datasets(dataset_ids, **kwargs)
+         elif project_ids is not None:
+             if len(project_ids) > 1:
+                 raise ValueError("Only one project id can be provided.")
+             kwargs = get_valid_kwargs(
+                 kwargs,
+                 self._predict_projects,
+                 exclude=["project_ids"],
+             )
+             self._iterator = self._predict_projects(project_ids, **kwargs)
+         else:
+             raise ValueError(
+                 "Unknown input type. Supported types are: numpy array, path to a file or directory, ImageInfo, VideoInfo, ProjectInfo, DatasetInfo."
+             )
+
+     def _set_var_from_kwargs(self, key, kwargs, default):
+         value = kwargs.get(key, default)
+         if value is None:
+             return None
+         if not isinstance(value, list):
+             value = [value]
+         return value
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.stop()
+         if exc_type is not None:
+             return False
+
+     def __next__(self):
+         try:
+             prediction_json = self._iterator.__next__()
+             this_kwargs = next(self.prediction_kwargs_iterator)
+             prediction = Prediction.from_json(
+                 prediction_json, **self.kwargs, **this_kwargs, model_meta=self.model_meta
+             )
+             return prediction
+         except StopIteration:
+             self._on_inference_end()
+             raise
+         except Exception:
+             self.stop()
+             raise
+
+     def next(self):
+         return self.__next__()
+
+     def __iter__(self):
+         return self
+
+     def __len__(self):
+         return len(self._iterator)
+
+     def _get_api_token(self):
+         if self.api is not None:
+             return self.api.token
+         return env.api_token(raise_not_found=False)
+
+     def _get_json_body(self):
+         body = {"state": {}, "context": {}}
+         if self.inference_request_uuid is not None:
+             body["state"]["inference_request_uuid"] = self.inference_request_uuid
+         if self.inference_settings:
+             body["state"]["settings"] = self.inference_settings
+         if self.api_token is not None:
+             body["api_token"] = self.api_token
+         return body
+
+     def _post(self, method, *args, retries=5, **kwargs) -> requests.Response:
+         if kwargs.get("headers") is None:
+             kwargs["headers"] = {}
+         if self.api is not None:
+             retries = min(self.api.retry_count, retries)
+             if "x-api-key" not in kwargs["headers"]:
+                 kwargs["headers"]["x-api-key"] = self.api.token
+         url = self._base_url.rstrip("/") + "/" + method.lstrip("/")
+         if "timeout" not in kwargs:
+             kwargs["timeout"] = 60
+         for retry_idx in range(retries):
+             response = None
+             try:
+                 logger.trace(f"POST {url}")
+                 response = requests.post(url, *args, **kwargs)
+                 if response.status_code != requests.codes.ok:  # pylint: disable=no-member
+                     Api._raise_for_status(response)
+                 return response
+             except requests.RequestException as exc:
+                 process_requests_exception(
+                     logger,
+                     exc,
+                     method,
+                     url,
+                     verbose=True,
+                     swallow_exc=True,
+                     sleep_sec=5,
+                     response=response,
+                     retry_info={"retry_idx": retry_idx + 1, "retry_limit": retries},
+                 )
+                 if retry_idx + 1 == retries:
+                     raise exc
+
+     def _get_inference_progress(self):
+         method = "get_inference_progress"
+         r = self._post(method, json=self._get_json_body())
+         return r.json()
+
+     def _get_inference_status(self):
+         method = "get_inference_status"
+         r = self._post(method, json=self._get_json_body())
+         return r.json()
+
+     def _stop_async_inference(self):
+         method = "stop_inference"
+         r = self._post(
+             method,
+             json=self._get_json_body(),
+         )
+         logger.info("Inference will be stopped on the server")
+         return r.json()
+
+     def _clear_inference_request(self):
+         method = "clear_inference_request"
+         r = self._post(
+             method,
+             json=self._get_json_body(),
+         )
+         logger.info("Inference request will be cleared on the server")
+         return r.json()
+
+     def _on_inference_end(self):
+         if self.inference_request_uuid is None:
+             return
+         self._clear_inference_request()
+
+     @property
+     def model_meta(self) -> ProjectMeta:
+         if self._model_meta is None:
+             self._model_meta = ProjectMeta.from_json(
+                 self._post("get_model_meta", json=self._get_json_body()).json()
+             )
+         return self._model_meta
+
+     def stop(self):
+         if self.inference_request_uuid is None:
+             logger.debug("No active inference request to stop.")
+             return
+         self._stop_async_inference()
+         self._on_inference_end()
+
+     def is_done(self):
+         if self.inference_request_uuid is None:
+             raise RuntimeError(
+                 "Inference is not started. Please start inference before checking the status."
+             )
+         return not self._get_inference_progress()["is_inferring"]
+
+     def progress(self):
+         if self.inference_request_uuid is None:
+             raise RuntimeError(
+                 "Inference is not started. Please start inference before checking the status."
+             )
+         return self._get_inference_progress()["progress"]
+
+     def status(self):
+         if self.inference_request_uuid is None:
+             raise RuntimeError(
+                 "Inference is not started. Please start inference before checking the status."
+             )
+         return self._get_inference_status()
+
+     def _pop_pending_results(self) -> Dict[str, Any]:
+         method = "pop_inference_results"
+         json_body = self._get_json_body()
+         return self._post(method, json=json_body).json()
+
+     def _update_progress(self, tqdm: tqdm, response: Dict[str, Any]):
+         if tqdm is None:
+             return
+         json_progress = response.get("progress", None)
+         if json_progress is None or json_progress.get("message") is None:
+             json_progress = response.get("preparing_progress", None)
+         if json_progress is None:
+             return
+         refresh = False
+         message = json_progress.get("message", json_progress.get("status", None))
+         if message is not None and tqdm.desc not in [message, f"{message}:"]:
+             tqdm.set_description(message, refresh=False)
+             refresh = True
+         current = json_progress.get("current", None)
+         if current is not None and tqdm.n != current:
+             tqdm.n = current
+             refresh = True
+         total = json_progress.get("total", None)
+         if total is not None and tqdm.total != total:
+             tqdm.total = total
+             refresh = True
+         is_size = json_progress.get("is_size", False)
+         if is_size and tqdm.unit == "it":
+             tqdm.unit = "iB"
+             tqdm.unit_scale = True
+             tqdm.unit_divisor = 1024
+             refresh = True
+         if not is_size and tqdm.unit == "iB":
+             tqdm.unit = "it"
+             tqdm.unit_scale = False
+             tqdm.unit_divisor = 1
+             refresh = True
+         if refresh:
+             tqdm.refresh()
+
+     def _wait_for_inference_start(
+         self, delay=1, timeout=None, tqdm: tqdm = None
+     ) -> Tuple[dict, bool]:
+         has_started = False
+         timeout_exceeded = False
+         t0 = time.time()
+         last_stage = None
+         while not has_started and not timeout_exceeded:
+             resp = self._get_inference_progress()
+             stage = resp.get("stage")
+             if stage != last_stage:
+                 logger.info(stage)
+                 last_stage = stage
+             has_started = stage not in ["preparing", "preprocess", None]
+             has_started = has_started or bool(resp.get("result")) or resp["progress"]["total"] != 1
+             self._update_progress(tqdm, resp)
+             if not has_started:
+                 time.sleep(delay)
+             timeout_exceeded = timeout and time.time() - t0 > timeout
+             if timeout_exceeded:
+                 self.stop()
+                 raise Timeout("Timeout exceeded. The server didn't start the inference")
+         return resp, has_started
+
+     def _wait_for_pending_results(self, delay=1, timeout=600, tqdm: tqdm = None) -> List[dict]:
+         logger.debug("waiting for pending results...")
+         has_results = False
+         timeout_exceeded = False
+         t0 = time.monotonic()
+         while not has_results and not timeout_exceeded:
+             resp = self._pop_pending_results()
+             self._update_progress(tqdm, resp)
+             pending_results = resp["pending_results"]
+             exception_json = resp["exception"]
+             if exception_json:
+                 exception_str = f"{exception_json['type']}: {exception_json['message']}"
+                 raise RuntimeError(f"Inference Error: {exception_str}")
+             has_results = bool(pending_results)
+             if resp.get("finished", False):
+                 break
+             if not has_results:
+                 time.sleep(delay)
+             timeout_exceeded = timeout and time.monotonic() - t0 > timeout
+             if timeout_exceeded:
+                 self.stop()
+                 raise Timeout("Timeout exceeded. Pending results not received from the server.")
+         return pending_results
+
+     def _start_inference(self, method, **kwargs):
+         if self.inference_request_uuid:
+             raise RuntimeError(
+                 "Inference is already running. Please stop it before starting a new one."
+             )
+         resp = self._post(method, **kwargs).json()
+
+         self.inference_request_uuid = resp["inference_request_uuid"]
+
+         logger.info(
+             "Inference has started:",
+             extra={"inference_request_uuid": resp.get("inference_request_uuid")},
+         )
+         try:
+             resp, has_started = self._wait_for_inference_start(tqdm=self.tqdm)
+         except:
+             self.stop()
+             raise
+         frame_iterator = self.Iterator(resp["progress"]["total"], self, tqdm=self.tqdm)
+         return frame_iterator
+
+     def _predict_images(self, images: List, **kwargs: dict):
+         if isinstance(images[0], bytes):
+             f = self._predict_images_bytes
+         elif isinstance(images[0], (str, PathLike)):
+             f = self._predict_images_paths
+         elif isinstance(images[0], np.ndarray):
+             f = self._predict_images_nps
+         elif isinstance(images[0], int):
+             f = self._predict_images_ids
+         else:
+             raise ValueError(f"Unsupported input type '{type(images[0])}'.")
+         kwargs = get_valid_kwargs(kwargs, f, exclude=["images"])
+         return f(images, **kwargs)
+
+     def _predict_images_bytes(self, images: List[bytes], batch_size: int = None):
+         files = [
+             ("files", (f"image_{i}.png", image, "image/png")) for i, image in enumerate(images)
+         ]
+         state = self._get_json_body()["state"]
+         if batch_size is not None:
+             state["batch_size"] = batch_size
+         method = "inference_batch_async"
+         uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
+         return self._start_inference(method, files=uploads)
+
+     def _predict_images_paths(self, images: List, batch_size: int = None):
+         files = []
+         try:
+             files = [("files", open(f, "rb")) for f in images]
+             state = self._get_json_body()["state"]
+             if batch_size is not None:
+                 state["batch_size"] = batch_size
+             method = "inference_batch_async"
+             uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
+             return self._start_inference(method, files=uploads)
+         finally:
+             for _, f in files:
+                 f.close()
+
+     def _predict_images_nps(self, images: List, batch_size: int = None):
+         images = [write_bytes(image, ".png") for image in images]
+         return self._predict_images_bytes(images, batch_size=batch_size)
+
+     def _predict_images_ids(
+         self, images: List[int], batch_size: int = None, upload_mode: str = None
+     ):
+         method = "inference_batch_ids_async"
+         json_body = self._get_json_body()
+         state = json_body["state"]
+         state["images_ids"] = images
+         if batch_size is not None:
+             state["batch_size"] = batch_size
+         if upload_mode is not None:
+             state["upload_mode"] = upload_mode
+         return self._start_inference(method, json=json_body)
+
+     def _predict_videos(
+         self,
+         videos: Union[List[int], List[str], List[PathLike]],
+         start_frame: int = None,
+         num_frames: int = None,
+         stride=None,
+         end_frame=None,
+         duration=None,
+         direction: Literal["forward", "backward"] = None,
+         tracker: Literal["bot", "deepsort"] = None,
+         batch_size: int = None,
+     ):
+         if len(videos) != 1:
+             raise ValueError("Only one video can be processed at a time.")
+         json_body = self._get_json_body()
+         state = json_body["state"]
+         for key, value in (
+             ("start_frame", start_frame),
+             ("num_frames", num_frames),
+             ("stride", stride),
+             ("end_frame", end_frame),
+             ("duration", duration),
+             ("direction", direction),
+             ("tracker", tracker),
+             ("batch_size", batch_size),
+         ):
+             if value is not None:
+                 state[key] = value
+         if isinstance(videos[0], int):
+             method = "inference_video_id_async"
+             state["video_id"] = videos[0]
+             return self._start_inference(method, json=json_body)
+         elif isinstance(videos[0], (str, PathLike)):
+             files = []
+             try:
+                 method = "inference_video_async"
+                 files = [("files", open(videos[0], "rb"))]
+                 uploads = files + [("state", (None, json.dumps(state), "text/plain"))]
+                 return self._start_inference(method, files=uploads)
+             finally:
+                 for _, f in files:
+                     f.close()
+         else:
+             raise ValueError(
+                 f"Unsupported input type '{type(videos[0])}'. Supported types are: int, str, PathLike."
+             )
+
+     def _predict_projects(
+         self,
+         project_ids: List[int],
+         dataset_ids: List[int] = None,
+         batch_size: int = None,
+         upload_mode: str = None,
+         iou_merge_threshold: float = None,
+         cache_project_on_model: bool = None,
+     ):
+         if len(project_ids) != 1:
+             raise ValueError("Only one project can be processed at a time.")
+         method = "inference_project_id_async"
+         json_body = self._get_json_body()
+         state = json_body["state"]
+         state["project_id"] = project_ids[0]
+         if dataset_ids is not None:
+             state["dataset_ids"] = dataset_ids
+         if batch_size is not None:
+             state["batch_size"] = batch_size
+         if upload_mode is not None:
+             state["upload_mode"] = upload_mode
+         if iou_merge_threshold is not None:
+             state["iou_merge_threshold"] = iou_merge_threshold
+         if cache_project_on_model is not None:
+             state["cache_project_on_model"] = cache_project_on_model
+
+         return self._start_inference(method, json=json_body)
+
+     def _predict_datasets(
+         self,
+         dataset_ids: List[int],
+         batch_size: int = None,
+         upload_mode: str = None,
+         iou_merge_threshold: float = None,
+         cache_datasets_on_model: bool = None,
+     ):
+         if self.api is None:
+             raise ValueError("Api is required to use this method.")
+         dataset_infos = [self.api.dataset.get_info_by_id(dataset_id) for dataset_id in dataset_ids]
+         if len(set([info.project_id for info in dataset_infos])) > 1:
+             raise ValueError("All datasets must belong to the same project.")
+         return self._predict_projects(
+             [dataset_infos[0].project_id],
+             dataset_ids=dataset_ids,
+             batch_size=batch_size,
+             upload_mode=upload_mode,
+             iou_merge_threshold=iou_merge_threshold,
+             cache_project_on_model=cache_datasets_on_model,
+         )
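
For orientation, a minimal usage sketch of the class added above (not part of the diff; the server URL and image path are placeholders, and show_progress is the optional keyword handled in __init__):

    from supervisely.nn.model.prediction_session import PredictionSession

    # Hypothetical local inference server and image path, for illustration only.
    with PredictionSession("http://localhost:8000", input="image.jpg", show_progress=True) as session:
        for prediction in session:  # each item is a supervisely.nn.model.prediction.Prediction
            print(prediction)

The context manager calls stop() on exit, so an interrupted loop still cleans up the inference request on the server.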
supervisely/nn/tracking/__init__.py (new file)
@@ -0,0 +1 @@
+ from supervisely.nn.tracking.tracking import track
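
This one-line __init__.py re-exports track at the package level, so both import paths below resolve to the same function (a sketch; the signature of track is not shown in this diff, so it is not called here):

    # Both imports refer to the same function after the re-export above.
    from supervisely.nn.tracking import track
    from supervisely.nn.tracking.tracking import track as track_direct

    assert track is track_direct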