supervisely-6.73.357-py3-none-any.whl → supervisely-6.73.358-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. supervisely/_utils.py +12 -0
  2. supervisely/api/annotation_api.py +3 -0
  3. supervisely/api/api.py +2 -2
  4. supervisely/api/app_api.py +27 -2
  5. supervisely/api/entity_annotation/tag_api.py +0 -1
  6. supervisely/api/nn/__init__.py +0 -0
  7. supervisely/api/nn/deploy_api.py +821 -0
  8. supervisely/api/nn/neural_network_api.py +248 -0
  9. supervisely/api/task_api.py +26 -467
  10. supervisely/app/fastapi/subapp.py +1 -0
  11. supervisely/nn/__init__.py +2 -1
  12. supervisely/nn/artifacts/artifacts.py +5 -5
  13. supervisely/nn/benchmark/object_detection/metric_provider.py +3 -0
  14. supervisely/nn/experiments.py +28 -5
  15. supervisely/nn/inference/cache.py +178 -114
  16. supervisely/nn/inference/gui/gui.py +18 -35
  17. supervisely/nn/inference/gui/serving_gui.py +3 -1
  18. supervisely/nn/inference/inference.py +1421 -1265
  19. supervisely/nn/inference/inference_request.py +412 -0
  20. supervisely/nn/inference/object_detection_3d/object_detection_3d.py +31 -24
  21. supervisely/nn/inference/session.py +2 -2
  22. supervisely/nn/inference/tracking/base_tracking.py +45 -79
  23. supervisely/nn/inference/tracking/bbox_tracking.py +220 -155
  24. supervisely/nn/inference/tracking/mask_tracking.py +274 -250
  25. supervisely/nn/inference/tracking/tracker_interface.py +23 -0
  26. supervisely/nn/inference/uploader.py +164 -0
  27. supervisely/nn/model/__init__.py +0 -0
  28. supervisely/nn/model/model_api.py +259 -0
  29. supervisely/nn/model/prediction.py +311 -0
  30. supervisely/nn/model/prediction_session.py +632 -0
  31. supervisely/nn/tracking/__init__.py +1 -0
  32. supervisely/nn/tracking/boxmot.py +114 -0
  33. supervisely/nn/tracking/tracking.py +24 -0
  34. supervisely/nn/training/train_app.py +61 -19
  35. supervisely/nn/utils.py +43 -3
  36. supervisely/task/progress.py +12 -2
  37. supervisely/video/video.py +107 -1
  38. {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/METADATA +2 -1
  39. {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/RECORD +43 -32
  40. supervisely/api/neural_network_api.py +0 -202
  41. {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/LICENSE +0 -0
  42. {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/WHEEL +0 -0
  43. {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/entry_points.txt +0 -0
  44. {supervisely-6.73.357.dist-info → supervisely-6.73.358.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py  +1421 -1265

@@ -1,21 +1,22 @@
+from __future__ import annotations
+
 import argparse
 import asyncio
 import inspect
 import json
 import os
 import re
+import shutil
 import subprocess
-import sys
+import tempfile
 import threading
 import time
-import uuid
 from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict, dataclass
 from functools import partial, wraps
 from pathlib import Path
-from queue import Queue
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from urllib.request import urlopen
 
 import numpy as np
@@ -25,6 +26,7 @@ import yaml
 from fastapi import Form, HTTPException, Request, Response, UploadFile, status
 from fastapi.responses import JSONResponse
 from requests.structures import CaseInsensitiveDict
+from tqdm import tqdm
 
 import supervisely.app.development as sly_app_development
 import supervisely.imaging.image as sly_image
@@ -32,7 +34,7 @@ import supervisely.io.env as sly_env
 import supervisely.io.fs as sly_fs
 import supervisely.io.json as sly_json
 import supervisely.nn.inference.gui as GUI
-from supervisely import DatasetInfo, ProjectInfo, VideoAnnotation, batched
+from supervisely import DatasetInfo, batched
 from supervisely._utils import (
     add_callback,
     get_filename_from_headers,
@@ -49,8 +51,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api, ApiField
 from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
-from supervisely.app.content import StateJson, get_data_dir
-from supervisely.app.exceptions import DialogWindowError
+from supervisely.app.content import get_data_dir
 from supervisely.app.fastapi.subapp import (
     Application,
     call_on_autostart,
@@ -68,7 +69,13 @@ from supervisely.geometry.any_geometry import AnyGeometry
 from supervisely.imaging.color import get_predefined_colors
 from supervisely.io.fs import list_files
 from supervisely.nn.inference.cache import InferenceImageCache
-from supervisely.nn.prediction_dto import Prediction
+from supervisely.nn.inference.inference_request import (
+    InferenceRequest,
+    InferenceRequestsManager,
+)
+from supervisely.nn.inference.uploader import Uploader
+from supervisely.nn.model.model_api import Prediction
+from supervisely.nn.prediction_dto import Prediction as PredictionDTO
 from supervisely.nn.utils import (
     CheckpointInfo,
     DeployInfo,
@@ -76,13 +83,15 @@ from supervisely.nn.utils import (
     ModelSource,
     RuntimeType,
     _get_model_name,
+    get_gpu_usage,
+    get_ram_usage,
 )
 from supervisely.project import ProjectType
 from supervisely.project.download import download_to_cache, read_from_cached_project
 from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
 from supervisely.task.progress import Progress
-from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS
+from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
 
 try:
     from typing import Literal
@@ -283,6 +292,8 @@ class Inference:
             log_progress=True,
         )
 
+        self.inference_requests_manager = InferenceRequestsManager(executor=self._executor)
+
     def get_batch_size(self):
         if self.max_batch_size is not None:
             return min(self.DEFAULT_BATCH_SIZE, self.max_batch_size)
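Note: the InferenceRequestsManager created above takes over the ad-hoc self._inference_requests bookkeeping that the removed code further down relied on; synchronous entry points now hand a worker function to the manager and read results back from the finished request. A minimal sketch of that pattern, taken from the _inference hunk later in this diff (the image ids 123 and 456 are placeholders):

    settings = self._get_inference_settings({})
    results = self.inference_requests_manager.run(
        self._inference_image_ids,  # worker that fills an InferenceRequest
        self.api,
        {"batch_ids": [123, 456], "settings": settings},
    )
    anns = [Annotation.from_json(r["annotation"], self.model_meta) for r in results]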
@@ -595,10 +606,55 @@ class Inference:
     def _checkpoints_cache_dir(self):
         return os.path.join(os.path.expanduser("~"), ".cache", "supervisely", "checkpoints")
 
+    def _build_deploy_params_from_api(self, model_name: str, deploy_params: dict = None) -> dict:
+        if deploy_params is None:
+            deploy_params = {}
+        selected_model = None
+        for model in self.pretrained_models:
+            if model["meta"]["model_name"].lower() == model_name.lower():
+                selected_model = model
+                break
+        if selected_model is None:
+            raise ValueError(f"Model {model_name} not found in models.json of serving app")
+        deploy_params["model_files"] = selected_model["meta"]["model_files"]
+        deploy_params["model_info"] = selected_model
+        return deploy_params
+
+    def _build_legacy_deploy_params_from_api(self, model_name: str) -> dict:
+        selected_model = None
+        if hasattr(self, "pretrained_models_table"):
+            selected_model = self.pretrained_models_table.get_by_model_name(model_name)
+        if selected_model is None:
+            # @TODO: Improve error message
+            raise ValueError("This app doesn't support new deploy api")
+
+        self.pretrained_models_table.set_by_model_name(model_name)
+        deploy_params = self.pretrained_models_table.get_selected_model_params()
+        return deploy_params
+
+    # @TODO: method name should be better?
+    def _set_common_deploy_params(self, deploy_params: dict) -> dict:
+        load_model_params = inspect.signature(self.load_model).parameters
+        has_runtime_param = "runtime" in load_model_params
+
+        if has_runtime_param:
+            if deploy_params.get("runtime", None) is None:
+                deploy_params["runtime"] = RuntimeType.PYTORCH
+        if deploy_params.get("device", None) is None:
+            deploy_params["device"] = "cuda:0" if get_gpu_count() > 0 else "cpu"
+        return deploy_params
+
     def _download_model_files(self, deploy_params: dict, log_progress: bool = True) -> dict:
-        if deploy_params["runtime"] != RuntimeType.PYTORCH:
-            export = deploy_params["model_info"].get("export", {})
-            if export is not None:
+        if deploy_params["model_source"] == ModelSource.PRETRAINED:
+            headless = self.gui is None
+            return self._download_pretrained_model(
+                deploy_params["model_files"], log_progress, headless
+            )
+        elif deploy_params["model_source"] == ModelSource.CUSTOM:
+            if deploy_params["runtime"] != RuntimeType.PYTORCH:
+                export = deploy_params["model_info"].get("export", {})
+                if export is None:
+                    export = {}
                 export_model = export.get(deploy_params["runtime"], None)
                 if export_model is not None:
                     if sly_fs.get_file_name(export_model) == sly_fs.get_file_name(
@@ -608,13 +664,11 @@ class Inference:
                         deploy_params["model_info"]["artifacts_dir"] + export_model
                     )
                     logger.info(f"Found model checkpoint for '{deploy_params['runtime']}'")
-
-        if deploy_params["model_source"] == ModelSource.PRETRAINED:
-            return self._download_pretrained_model(deploy_params["model_files"], log_progress)
-        elif deploy_params["model_source"] == ModelSource.CUSTOM:
             return self._download_custom_model(deploy_params["model_files"], log_progress)
 
-    def _download_pretrained_model(self, model_files: dict, log_progress: bool = True):
+    def _download_pretrained_model(
+        self, model_files: dict, log_progress: bool = True, headless: bool = False
+    ):
         """
         Downloads the pretrained model data.
         """
@@ -642,26 +696,39 @@ class Inference:
                     continue
 
                 if log_progress:
-                    with self.gui.download_progress(
-                        message=f"Downloading: '{file_name}'",
-                        total=file_size,
-                        unit="bytes",
-                        unit_scale=True,
-                    ) as download_pbar:
-                        self.gui.download_progress.show()
-                        sly_fs.download(
-                            url=file_url,
-                            save_path=file_path,
-                            progress=download_pbar.update,
-                        )
+                    if not headless:
+                        with self.gui.download_progress(
+                            message=f"Downloading: '{file_name}'",
+                            total=file_size,
+                            unit="bytes",
+                            unit_scale=True,
+                        ) as download_pbar:
+                            self.gui.download_progress.show()
+                            sly_fs.download(
+                                url=file_url,
+                                save_path=file_path,
+                                progress=download_pbar.update,
+                            )
+                    else:
+                        with tqdm(
+                            total=file_size,
+                            unit="bytes",
+                            unit_scale=True,
+                        ) as download_pbar:
+                            logger.info(f"Downloading: '{file_name}'")
+                            sly_fs.download(
+                                url=file_url, save_path=file_path, progress=download_pbar.update
+                            )
                 else:
+                    logger.info(f"Downloading: '{file_name}'")
                     sly_fs.download(url=file_url, save_path=file_path)
                 local_model_files[file] = file_path
             else:
                 local_model_files[file] = file_url
 
         if log_progress:
-            self.gui.download_progress.hide()
+            if self.gui is not None:
+                self.gui.download_progress.hide()
         return local_model_files
 
     def _download_custom_model(self, model_files: dict, log_progress: bool = True):
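Note: the three helpers added above (_build_deploy_params_from_api, _build_legacy_deploy_params_from_api, _set_common_deploy_params) let a serving app resolve deploy parameters from a bare model name before calling _download_model_files. A rough sketch of how they compose; the model name is a placeholder and setting model_source by hand is an assumption, since this diff does not show the caller:

    deploy_params = self._build_deploy_params_from_api("RT-DETRv2-S")   # placeholder name from models.json
    deploy_params["model_source"] = ModelSource.PRETRAINED              # assumed to be set by the caller
    deploy_params = self._set_common_deploy_params(deploy_params)       # fills runtime/device defaults
    model_files = self._download_model_files(deploy_params)             # local paths (or original URLs) per file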
@@ -732,7 +799,7 @@ class Inference:
             self.gui.show_deployed_model_info(self)
 
     def load_custom_checkpoint(
-        self, model_files: dict, model_meta: dict, device: str = "cuda", **kwargs
+        self, model_files: dict, model_meta: dict, device: Optional[str] = None, **kwargs
     ):
         """
         Loads local custom model checkpoint.
@@ -886,7 +953,8 @@ class Inference:
         classes = None
         try:
             classes = self.get_classes()
-            num_classes = len(classes)
+            if classes is not None:
+                num_classes = len(classes)
         except NotImplementedError:
             logger.warn(f"get_classes() function not implemented for {type(self)} object.")
         except AttributeError:
@@ -1002,13 +1070,13 @@ class Inference:
             self._model_meta = self._model_meta.add_tag_meta(tag_meta)
         return tag_meta
 
-    def _create_label(self, dto: Prediction) -> Label:
+    def _create_label(self, dto: PredictionDTO) -> Label:
         raise NotImplementedError("Have to be implemented in child class")
 
     def _predictions_to_annotation(
         self,
         image_path: Union[str, np.ndarray],
-        predictions: List[Prediction],
+        predictions: List[PredictionDTO],
         classes_whitelist: Optional[List[str]] = None,
     ) -> Annotation:
         labels = []
@@ -1067,6 +1135,15 @@ class Inference:
             logger.error(f"Error in {func.__name__} function: {e}", exc_info=True)
             raise e
 
+    def api_from_request(self, request) -> Api:
+        """
+        Get API from request. If not found, use self.api.
+        """
+        api = request.state.api
+        if api is None:
+            api = self.api
+        return api
+
     def _inference_auto(
         self,
         source: List[Union[str, np.ndarray]],
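Note: api_from_request is a small convenience for custom endpoints. A minimal sketch, assuming server is the FastAPI instance behind this serving app and that middleware may have attached a per-request Api to request.state; the route itself is illustrative, not part of this diff:

    from fastapi import Request

    @server.post("/custom_endpoint")  # illustrative route
    def custom_endpoint(request: Request):
        api = self.api_from_request(request)  # per-request Api when present, otherwise self.api
        return {"server_address": api.server_address}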
@@ -1117,10 +1194,12 @@ class Inference:
         settings = self._get_inference_settings({})
 
         if isinstance(source[0], int):
-            ann_jsons = self._inference_batch_ids(
-                self.api, {"batch_ids": source, "settings": settings}
+            results = self.inference_requests_manager.run(
+                self._inference_image_ids, self.api, {"batch_ids": source, "settings": settings}
             )
-            anns = [Annotation.from_json(ann_json, self.model_meta) for ann_json in ann_jsons]
+            anns = [
+                Annotation.from_json(result["annotation"], self.model_meta) for result in results
+            ]
         else:
             anns, _ = self._inference_auto(source, settings)
         if not input_is_list:
@@ -1240,17 +1319,17 @@ class Inference:
         return anns, benchmark
 
     # pylint: disable=method-hidden
-    def predict(self, image_path: str, settings: Dict[str, Any]) -> List[Prediction]:
+    def predict(self, image_path: str, settings: Dict[str, Any]) -> List[PredictionDTO]:
         raise NotImplementedError("Have to be implemented in child class")
 
-    def predict_raw(self, image_path: str, settings: Dict[str, Any]) -> List[Prediction]:
+    def predict_raw(self, image_path: str, settings: Dict[str, Any]) -> List[PredictionDTO]:
         raise NotImplementedError(
             "Have to be implemented in child class If sliding_window_mode is 'advanced'."
         )
 
     def predict_batch(
         self, images_np: List[np.ndarray], settings: Dict[str, Any]
-    ) -> List[List[Prediction]]:
+    ) -> List[List[PredictionDTO]]:
         """Predict batch of images. `images_np` is a list of numpy arrays in RGB format
 
         If this method is not overridden in a subclass, the following fallback logic works:
@@ -1267,7 +1346,7 @@ class Inference:
 
     def predict_batch_raw(
         self, images_np: List[np.ndarray], settings: Dict[str, Any]
-    ) -> List[List[Prediction]]:
+    ) -> List[List[PredictionDTO]]:
         """Predict batch of images. `source` is a list of numpy arrays in RGB format"""
         raise NotImplementedError(
             "Have to be implemented in child class If sliding_window_mode is 'advanced'."
@@ -1275,7 +1354,7 @@ class Inference:
 
     def predict_benchmark(
         self, images_np: List[np.ndarray], settings: dict
-    ) -> Tuple[List[List[Prediction]], dict]:
+    ) -> Tuple[List[List[PredictionDTO]], dict]:
         """
         Inference a batch of images with speedtest benchmarking.
 
@@ -1318,15 +1397,24 @@ class Inference:
         )
         return is_predict_batch_overridden or is_predict_benchmark_overridden
 
+    def set_conf_auto(self, conf: float, inference_settings: dict):
+        conf_names = ["conf", "confidence", "confidence_threshold", "confidence_thresh"]
+        for name in conf_names:
+            if name in inference_settings:
+                inference_settings[name] = conf
+        return inference_settings
+
     # pylint: enable=method-hidden
     def _get_inference_settings(self, state: dict):
-        settings = state.get("settings", {})
+        settings = state.get("settings")
         if settings is None:
             settings = {}
         if "rectangle" in state.keys():
             settings["rectangle"] = state["rectangle"]
+        conf = settings.get("conf", None)
+        if conf is not None:
+            settings = self.set_conf_auto(conf, settings)
         settings["sliding_window_mode"] = self.sliding_window_mode
-
         for key, value in self.custom_inference_settings_dict.items():
             if key not in settings:
                 logger.debug(
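Note: set_conf_auto is what makes the new generic "conf" field work regardless of which confidence key a given model's settings actually use. A short worked example following the code above (values are illustrative):

    settings = {"conf": 0.25, "confidence_threshold": 0.5}
    settings = self.set_conf_auto(settings["conf"], settings)
    # -> {"conf": 0.25, "confidence_threshold": 0.25}
    # Every known alias present in the dict ("conf", "confidence",
    # "confidence_threshold", "confidence_thresh") is overwritten with the same value.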
@@ -1335,13 +1423,19 @@ class Inference:
                 settings[key] = value
         return settings
 
+    def _get_batch_size_from_state(self, state: dict):
+        batch_size = state.get("batch_size", None)
+        if batch_size is None:
+            batch_size = self.get_batch_size()
+        return batch_size
+
     @property
     def app(self) -> Application:
         return self._app
 
     def visualize(
         self,
-        predictions: List[Prediction],
+        predictions: List[PredictionDTO],
         image_path: str,
         vis_path: str,
         thickness: Optional[int] = None,
@@ -1358,194 +1452,79 @@ class Inference:
1358
1452
 
1359
1453
  def _format_output(
1360
1454
  self,
1361
- anns: List[Annotation],
1362
- slides_data: List[dict] = None,
1455
+ predictions: List[Prediction],
1363
1456
  ) -> List[dict]:
1364
- if not slides_data:
1365
- slides_data = [{} for _ in range(len(anns))]
1366
- assert len(anns) == len(slides_data)
1367
- return [{"annotation": ann.to_json(), "data": data} for ann, data in zip(anns, slides_data)]
1368
-
1369
- def _inference_image(self, state: dict, file: UploadFile):
1370
- logger.debug("Inferring image...", extra={"state": state})
1371
- settings = self._get_inference_settings(state)
1372
- image_np = sly_image.read_bytes(file.file.read())
1373
- logger.debug("Inference settings:", extra=settings)
1374
- logger.debug("Image info:", extra={"w": image_np.shape[1], "h": image_np.shape[0]})
1375
- anns, slides_data = self._inference_auto(
1376
- [image_np],
1377
- settings=settings,
1378
- )
1379
- results = self._format_output(anns, slides_data)
1380
- return results[0]
1457
+ output = [
1458
+ {
1459
+ **pred.to_json(),
1460
+ "data": pred.extra_data.get("slides_data", {}),
1461
+ }
1462
+ for pred in predictions
1463
+ ]
1464
+ return output
1381
1465
 
1382
- def _inference_batch(self, state: dict, files: List[UploadFile]):
1466
+ def _inference_images(
1467
+ self,
1468
+ images: Iterable[Union[np.ndarray, str]],
1469
+ state: dict,
1470
+ inference_request: InferenceRequest,
1471
+ ):
1383
1472
  logger.debug("Inferring batch...", extra={"state": state})
1384
1473
  settings = self._get_inference_settings(state)
1385
- images = [sly_image.read_bytes(file.file.read()) for file in files]
1386
- anns, slides_data = self._inference_auto(
1387
- images,
1388
- settings=settings,
1389
- )
1390
- return self._format_output(anns, slides_data)
1391
-
1392
- def _inference_batch_ids(self, api: Api, state: dict):
1393
- logger.debug("Inferring batch_ids...", extra={"state": state})
1394
- settings = self._get_inference_settings(state)
1395
- ids = state["batch_ids"]
1396
- infos = api.image.get_info_by_id_batch(ids)
1397
- datasets = defaultdict(list)
1398
- for info in infos:
1399
- datasets[info.dataset_id].append(info.id)
1400
- results = []
1401
- for dataset_id, ids in datasets.items():
1402
- images_np = api.image.download_nps(dataset_id, ids)
1474
+ logger.debug("Inference settings:", extra={"inference_settings": settings})
1475
+ batch_size = self._get_batch_size_from_state(state)
1476
+
1477
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(images))
1478
+ for batch in batched_iter(images, batch_size=batch_size):
1479
+ batch = [
1480
+ self.cache.get_image_path(image) if isinstance(image, str) else image
1481
+ for image in batch
1482
+ ]
1403
1483
  anns, slides_data = self._inference_auto(
1404
- source=images_np,
1484
+ batch,
1405
1485
  settings=settings,
1406
1486
  )
1407
- anns = self._exclude_duplicated_predictions(api, anns, settings, dataset_id, ids)
1408
- results.extend(self._format_output(anns, slides_data))
1409
- return results
1410
-
1411
- def _inference_image_id(self, api: Api, state: dict, async_inference_request_uuid: str = None):
1412
- logger.debug("Inferring image_id...", extra={"state": state})
1413
- settings = self._get_inference_settings(state)
1414
- upload = state.get("upload", False)
1415
- image_id = state["image_id"]
1416
- image_info = api.image.get_info_by_id(image_id)
1417
- image_np = api.image.download_np(image_id)
1418
- logger.debug("Inference settings:", extra=settings)
1419
- logger.debug(
1420
- "Image info:",
1421
- extra={"id": image_id, "w": image_info.width, "h": image_info.height},
1422
- )
1423
-
1424
- inference_request = {}
1425
- if async_inference_request_uuid is not None:
1426
- try:
1427
- inference_request = self._inference_requests[async_inference_request_uuid]
1428
- except Exception as ex:
1429
- import traceback
1430
-
1431
- logger.error(traceback.format_exc())
1432
- raise RuntimeError(
1433
- f"async_inference_request_uuid {async_inference_request_uuid} was given, "
1434
- f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
1435
- )
1436
-
1437
- anns, slides_data = self._inference_auto(
1438
- [image_np],
1439
- settings=settings,
1440
- )
1441
- ann = anns[0]
1442
-
1443
- if upload:
1444
- ds_info = api.dataset.get_info_by_id(image_info.dataset_id, raise_error=True)
1445
- output_project_id = ds_info.project_id
1446
- output_project_meta = self.cache.get_project_meta(api, output_project_id)
1447
- logger.debug("Merging project meta...")
1448
-
1449
- output_project_meta, ann, meta_changed = update_meta_and_ann(output_project_meta, ann)
1450
- if meta_changed:
1451
- output_project_meta = api.project.update_meta(
1452
- output_project_id, output_project_meta
1453
- )
1454
- self.cache.set_project_meta(output_project_id, output_project_meta)
1455
-
1456
- ann = self._exclude_duplicated_predictions(
1457
- api, anns, settings, ds_info.id, [image_id], output_project_meta
1458
- )[0]
1459
-
1460
- logger.debug(
1461
- "Uploading annotation...",
1462
- extra={
1463
- "image_id": image_id,
1464
- "dataset_id": ds_info.id,
1465
- "project_id": output_project_id,
1466
- },
1467
- )
1468
- api.annotation.upload_ann(image_id, ann)
1469
- else:
1470
- ann = self._exclude_duplicated_predictions(
1471
- api, anns, settings, image_info.dataset_id, [image_id]
1472
- )[0]
1473
-
1474
- result = self._format_output(anns, slides_data)[0]
1475
- if async_inference_request_uuid is not None and ann is not None:
1476
- inference_request["result"] = result
1477
- return result
1478
-
1479
- def _inference_image_url(self, api: Api, state: dict):
1480
- logger.debug("Inferring image_url...", extra={"state": state})
1481
- settings = self._get_inference_settings(state)
1482
- image_url = state["image_url"]
1483
- ext = sly_fs.get_file_ext(image_url)
1484
- if ext == "":
1485
- ext = ".jpg"
1486
- image_path = os.path.join(get_data_dir(), rand_str(15) + ext)
1487
- sly_fs.download(image_url, image_path)
1488
- logger.debug("Inference settings:", extra=settings)
1489
- logger.debug(f"Downloaded path: {image_path}")
1490
- anns, slides_data = self._inference_auto(
1491
- [image_path],
1492
- settings=settings,
1493
- )
1494
- sly_fs.silent_remove(image_path)
1495
- return self._format_output(anns, slides_data)[0]
1496
-
1497
- def _inference_video_id(self, api: Api, state: dict, async_inference_request_uuid: str = None):
1498
- from supervisely.nn.inference.video_inference import InferenceVideoInterface
1499
-
1500
- logger.debug("Inferring video_id...", extra={"state": state})
1501
- video_info = api.video.get_info_by_id(state["videoId"])
1502
- n_frames = state.get("framesCount", video_info.frames_count)
1487
+ predictions = [Prediction(ann, model_meta=self.model_meta) for ann in anns]
1488
+ for pred, this_slides_data in zip(predictions, slides_data):
1489
+ pred.extra_data["slides_data"] = this_slides_data
1490
+ batch_results = self._format_output(predictions)
1491
+ inference_request.add_results(batch_results)
1492
+ inference_request.done(len(batch_results))
1493
+
1494
+ def _inference_video(
1495
+ self,
1496
+ path: str,
1497
+ state: Dict,
1498
+ inference_request: InferenceRequest,
1499
+ ):
1500
+ logger.debug("Inferring video...", extra={"path": path, "state": state})
1501
+ inference_settings = self._get_inference_settings(state)
1502
+ logger.debug(f"Inference settings:", extra=inference_settings)
1503
+ batch_size = self._get_batch_size_from_state(state)
1503
1504
  start_frame_index = state.get("startFrameIndex", 0)
1504
- direction = state.get("direction", "forward")
1505
- logger.debug(
1506
- f"Video info:",
1507
- extra=dict(
1508
- w=video_info.frame_width,
1509
- h=video_info.frame_height,
1510
- start_frame_index=start_frame_index,
1511
- n_frames=n_frames,
1512
- ),
1513
- )
1505
+ step = state.get("stride", None)
1506
+ if step is None:
1507
+ step = state.get("step", None)
1508
+ if step is None:
1509
+ step = 1
1510
+ end_frame_index = state.get("endFrameIndex", None)
1511
+ duration = state.get("duration", None)
1512
+ frames_count = state.get("framesCount", None)
1514
1513
  tracking = state.get("tracker", None)
1514
+ direction = state.get("direction", "forward")
1515
+ direction = 1 if direction == "forward" else -1
1515
1516
 
1516
- preparing_progress = {"current": 0, "total": 1}
1517
- if async_inference_request_uuid is not None:
1518
- try:
1519
- inference_request = self._inference_requests[async_inference_request_uuid]
1520
- except Exception as ex:
1521
- import traceback
1522
-
1523
- logger.error(traceback.format_exc())
1524
- raise RuntimeError(
1525
- f"async_inference_request_uuid {async_inference_request_uuid} was given, "
1526
- f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
1527
- )
1528
- sly_progress: Progress = inference_request["progress"]
1529
-
1530
- sly_progress.total = n_frames
1531
- inference_request["preparing_progress"]["total"] = n_frames
1532
- preparing_progress = inference_request["preparing_progress"]
1533
-
1534
- # progress
1535
- preparing_progress["status"] = "download_video"
1536
- preparing_progress["current"] = 0
1537
- preparing_progress["total"] = int(video_info.file_meta["size"])
1538
-
1539
- def _progress_cb(chunk_size):
1540
- preparing_progress["current"] += chunk_size
1541
-
1542
- self.cache.download_video(api, video_info.id, return_images=False, progress_cb=_progress_cb)
1543
- preparing_progress["status"] = "inference"
1544
-
1545
- settings = self._get_inference_settings(state)
1546
- logger.debug(f"Inference settings:", extra=settings)
1547
-
1548
- logger.debug(f"Total frames to infer: {n_frames}")
1517
+ frames_reader = VideoFrameReader(path)
1518
+ video_height, video_witdth = frames_reader.frame_size()
1519
+ if frames_count is not None:
1520
+ n_frames = frames_count
1521
+ elif end_frame_index is not None:
1522
+ n_frames = end_frame_index - start_frame_index
1523
+ elif duration is not None:
1524
+ fps = frames_reader.fps()
1525
+ n_frames = int(duration * fps)
1526
+ else:
1527
+ n_frames = frames_reader.frames_count()
1549
1528
 
1550
1529
  if tracking == "bot":
1551
1530
  from supervisely.nn.tracker import BoTTracker
@@ -1557,444 +1536,374 @@ class Inference:
1557
1536
  tracker = DeepSortTracker(state)
1558
1537
  else:
1559
1538
  if tracking is not None:
1560
- logger.warn(f"Unknown tracking type: {tracking}. Tracking is disabled.")
1539
+ logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
1561
1540
  tracker = None
1562
1541
 
1542
+ progress_total = (n_frames + step - 1) // step
1543
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
1544
+
1563
1545
  results = []
1564
- batch_size = state.get("batch_size", None)
1565
- if batch_size is None:
1566
- batch_size = self.get_batch_size()
1567
1546
  tracks_data = {}
1568
- direction = 1 if direction == "forward" else -1
1569
1547
  for batch in batched(
1570
- range(start_frame_index, start_frame_index + direction * n_frames, direction),
1548
+ range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
1571
1549
  batch_size,
1572
1550
  ):
1573
- if (
1574
- async_inference_request_uuid is not None
1575
- and inference_request["cancel_inference"] is True
1576
- ):
1551
+ if inference_request.is_stopped():
1577
1552
  logger.debug(
1578
- f"Cancelling inference video...",
1579
- extra={"inference_request_uuid": async_inference_request_uuid},
1553
+ f"Cancelling inference...",
1554
+ extra={"inference_request_uuid": inference_request.uuid},
1580
1555
  )
1581
1556
  results = []
1582
1557
  break
1583
1558
  logger.debug(
1584
1559
  f"Inferring frames {batch[0]}-{batch[-1]}:",
1585
1560
  )
1586
- frames = self.cache.download_frames(api, video_info.id, batch, redownload_video=True)
1561
+ frames = frames_reader.read_frames(batch)
1587
1562
  anns, slides_data = self._inference_auto(
1588
1563
  source=frames,
1589
- settings=settings,
1564
+ settings=inference_settings,
1590
1565
  )
1566
+ predictions = [
1567
+ Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
1568
+ for ann, frame_index in zip(anns, batch)
1569
+ ]
1570
+ for pred, this_slides_data in zip(predictions, slides_data):
1571
+ pred.extra_data["slides_data"] = this_slides_data
1572
+ batch_results = self._format_output(predictions)
1591
1573
  if tracker is not None:
1592
1574
  for frame_index, frame, ann in zip(batch, frames, anns):
1593
1575
  tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
1594
- batch_results = self._format_output(anns, slides_data)
1595
- results.extend(batch_results)
1596
- if async_inference_request_uuid is not None:
1597
- sly_progress.iters_done(len(batch))
1598
- inference_request["pending_results"].extend(batch_results)
1576
+ inference_request.add_results(batch_results)
1577
+ inference_request.done(len(batch_results))
1599
1578
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
1600
1579
  video_ann_json = None
1601
1580
  if tracker is not None:
1581
+ inference_request.set_stage("Postprocess...", 0, 1)
1602
1582
  video_ann_json = tracker.get_annotation(
1603
- tracks_data, (video_info.frame_height, video_info.frame_width), n_frames
1583
+ tracks_data, (video_height, video_witdth), n_frames
1604
1584
  ).to_json()
1585
+ inference_request.done()
1605
1586
  result = {"ann": results, "video_ann": video_ann_json}
1606
- if async_inference_request_uuid is not None and len(results) > 0:
1607
- inference_request["result"] = result.copy()
1608
- return result
1587
+ inference_request.final_result = result.copy()
1609
1588
 
1610
- def _inference_images_ids(
1589
+ def _inference_image_ids(
1611
1590
  self,
1612
1591
  api: Api,
1613
1592
  state: dict,
1614
- images_ids: List[int],
1615
- async_inference_request_uuid: str = None,
1593
+ inference_request: InferenceRequest,
1616
1594
  ):
1617
1595
  """Inference images by ids.
1618
1596
  If "output_project_id" in state, upload images and annotations to the output project.
1619
1597
  If "output_project_id" equal to source project id, upload annotations to the source project.
1620
1598
  If "output_project_id" is None, write annotations to inference request object.
1621
1599
  """
1622
- logger.debug("Inferring images...", extra={"state": state})
1623
- batch_size = state.get("batch_size", None)
1624
- if batch_size is None:
1625
- batch_size = self.get_batch_size()
1626
- output_project_id = state.get("output_project_id", None)
1627
- images_infos = api.image.get_info_by_id_batch(images_ids)
1600
+ logger.debug("Inferring batch_ids", extra={"state": state})
1601
+ inference_settings = self._get_inference_settings(state)
1602
+ logger.debug("Inference settings:", extra={"inference_settings": inference_settings})
1603
+ batch_size = self._get_batch_size_from_state(state)
1604
+ image_ids = get_value_for_keys(
1605
+ state, ["batch_ids", "image_ids", "images_ids", "imageIds", "image_id", "imageId"]
1606
+ )
1607
+ if image_ids is None:
1608
+ raise ValueError("Image ids are not provided")
1609
+ if not isinstance(image_ids, list):
1610
+ image_ids = [image_ids]
1611
+ upload_mode = state.get("upload_mode", None)
1612
+ iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
1613
+ if upload_mode == "iou_merge" and iou_merge_threshold is None:
1614
+ iou_merge_threshold = 0.7
1615
+
1616
+ images_infos = api.image.get_info_by_id_batch(image_ids)
1628
1617
  images_infos_dict = {im_info.id: im_info for im_info in images_infos}
1618
+ inference_request.context.setdefault("image_info", {}).update(images_infos_dict)
1619
+
1629
1620
  dataset_infos_dict = {
1630
1621
  ds_id: api.dataset.get_info_by_id(ds_id)
1631
1622
  for ds_id in set([im_info.dataset_id for im_info in images_infos])
1632
1623
  }
1624
+ inference_request.context.setdefault("dataset_info", {}).update(dataset_infos_dict)
1633
1625
 
1634
- if async_inference_request_uuid is not None:
1635
- try:
1636
- inference_request = self._inference_requests[async_inference_request_uuid]
1637
- except Exception as ex:
1638
- import traceback
1639
-
1640
- logger.error(traceback.format_exc())
1641
- raise RuntimeError(
1642
- f"async_inference_request_uuid {async_inference_request_uuid} was given, "
1643
- f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
1644
- )
1645
- sly_progress: Progress = inference_request["progress"]
1646
- sly_progress.total = len(images_ids)
1647
-
1648
- def _download_images(images_ids):
1649
- with ThreadPoolExecutor(max(8, min(batch_size, 64))) as executor:
1650
- for image_id in images_ids:
1651
- executor.submit(
1652
- self.cache.download_image,
1653
- api,
1654
- image_id,
1655
- )
1656
-
1657
- # start downloading in parallel
1658
- threading.Thread(target=_download_images, args=[images_ids], daemon=True).start()
1659
-
1660
- output_project_metas_dict = {}
1661
-
1662
- def _upload_results_to_source(results: List[Dict]):
1663
- nonlocal output_project_metas_dict
1664
- for result in results:
1665
- image_id = result["image_id"]
1666
- image_info: ImageInfo = images_infos_dict[image_id]
1667
- dataset_info: DatasetInfo = dataset_infos_dict[image_info.dataset_id]
1668
- project_id = dataset_info.project_id
1669
- ann = Annotation.from_json(result["annotation"], self.model_meta)
1670
- output_project_meta = output_project_metas_dict.get(project_id, None)
1671
- if output_project_meta is None:
1672
- output_project_meta = ProjectMeta.from_json(
1673
- api.project.get_meta(output_project_id)
1674
- )
1675
- output_project_meta, ann, meta_changed = update_meta_and_ann(
1676
- output_project_meta, ann
1677
- )
1678
- output_project_metas_dict[project_id] = output_project_meta
1679
- if meta_changed:
1680
- output_project_meta = api.project.update_meta(project_id, output_project_meta)
1681
- ann = update_classes(api, ann, output_project_meta, project_id)
1682
- api.annotation.append_labels(image_id, ann.labels)
1683
- if async_inference_request_uuid is not None:
1684
- sly_progress.iters_done(1)
1685
- inference_request["pending_results"].append(
1686
- {
1687
- "annotation": None, # to less response size
1688
- "data": None, # to less response size
1689
- "image_id": image_id,
1690
- "image_name": result["image_name"],
1691
- "dataset_id": result["dataset_id"],
1692
- }
1693
- )
1694
-
1695
- def _add_results_to_request(results: List[Dict]):
1696
- if async_inference_request_uuid is None:
1697
- return
1698
- inference_request["pending_results"].extend(results)
1699
- sly_progress.iters_done(len(results))
1700
-
1701
- new_dataset_id = {}
1702
-
1703
- def _get_or_create_new_dataset(output_project_id, src_dataset_id):
1704
- """Copy dataset in output project if not exists and return its id"""
1705
- if src_dataset_id in new_dataset_id:
1706
- return new_dataset_id[src_dataset_id]
1707
- dataset_info = api.dataset.get_info_by_id(src_dataset_id)
1708
-
1709
- def _create_parent_recursively(output_project_id, src_parent_id):
1710
- """Create parent datasets recursively and return the ID of the top-level parent"""
1711
- if src_parent_id in new_dataset_id:
1712
- return new_dataset_id[src_parent_id]
1713
- src_parent_info = dataset_infos_dict.get(src_parent_id)
1714
- if src_parent_info is None:
1715
- src_parent_info = api.dataset.get_info_by_id(src_parent_id)
1716
- if src_parent_info.parent_id is not None:
1717
- parent_id = _create_parent_recursively(
1718
- output_project_id, src_parent_info.parent_id
1719
- )
1720
- else:
1721
- parent_id = None
1722
- dst_parent = api.dataset.create(
1723
- output_project_id,
1724
- src_parent_info.name,
1725
- change_name_if_conflict=True,
1726
- parent_id=parent_id,
1727
- )
1728
- new_dataset_id[src_parent_info.id] = dst_parent.id
1729
- return dst_parent.id
1730
-
1731
- parent_id = None
1732
- if dataset_info.parent_id is not None:
1733
- parent_id = _create_parent_recursively(output_project_id, dataset_info.parent_id)
1734
-
1735
- output_dataset_id = api.dataset.create(
1736
- output_project_id,
1737
- dataset_info.name,
1626
+ output_project_id = state.get("output_project_id", None)
1627
+ output_dataset_id = None
1628
+ inference_request.context.setdefault("project_meta", {})
1629
+ if output_project_id is not None:
1630
+ if upload_mode is None:
1631
+ upload_mode = "append"
1632
+ if output_project_id is None and upload_mode == "create":
1633
+ image_info = images_infos[0]
1634
+ dataset_info = dataset_infos_dict[image_info.dataset_id]
1635
+ output_project_info = api.project.create(
1636
+ dataset_info.workspace_id,
1637
+ name=f"Predictions from task #{self.task_id}",
1638
+ description=f"Auto created project from inference request {inference_request.uuid}",
1738
1639
  change_name_if_conflict=True,
1739
- parent_id=parent_id,
1740
- ).id
1741
- new_dataset_id[src_dataset_id] = output_dataset_id
1742
- return output_dataset_id
1743
-
1744
- def _copy_images_to_dst(
1745
- src_dataset_id, dst_dataset_id, image_infos, dst_names
1746
- ) -> List[ImageInfo]:
1747
- return api.image.copy_batch_optimized(
1748
- src_dataset_id,
1749
- image_infos,
1750
- dst_dataset_id,
1751
- dst_names=dst_names,
1752
- with_annotations=False,
1753
- skip_validation=True,
1754
1640
  )
1755
-
1756
- def _upload_results_to_other(results: List[Dict]):
1757
- nonlocal output_project_metas_dict
1758
- if len(results) == 0:
1759
- return
1760
- src_dataset_id = results[0]["dataset_id"]
1761
- dataset_id = _get_or_create_new_dataset(output_project_id, src_dataset_id)
1762
- src_image_infos = [images_infos_dict[result["image_id"]] for result in results]
1763
- image_names = [result["image_name"] for result in results]
1764
- image_infos = _copy_images_to_dst(
1765
- src_dataset_id, dataset_id, src_image_infos, image_names
1641
+ output_project_id = output_project_info.id
1642
+ inference_request.context.setdefault("project_info", {})[
1643
+ output_project_id
1644
+ ] = output_project_info
1645
+ output_dataset_info = api.dataset.create(
1646
+ output_project_id,
1647
+ "Predictions",
1648
+ description=f"Auto created dataset from inference request {inference_request.uuid}",
1649
+ change_name_if_conflict=True,
1766
1650
  )
1767
- image_infos.sort(key=lambda x: image_names.index(x.name))
1768
- api.logger.debug(
1769
- "Uploading results to other project...",
1770
- extra={
1771
- "src_dataset_id": src_dataset_id,
1772
- "dst_project_id": output_project_id,
1773
- "dst_dataset_id": dataset_id,
1774
- "items_count": len(image_infos),
1775
- },
1651
+ output_dataset_id = output_dataset_info.id
1652
+ inference_request.context.setdefault("dataset_info", {})[
1653
+ output_dataset_id
1654
+ ] = output_dataset_info
1655
+
1656
+ # start download to cache in background
1657
+ dataset_image_infos: Dict[int, List[ImageInfo]] = defaultdict(list)
1658
+ for image_info in images_infos:
1659
+ dataset_image_infos[image_info.dataset_id].append(image_info)
1660
+ for dataset_id, ds_image_infos in dataset_image_infos.items():
1661
+ self.cache.run_cache_task_manually(
1662
+ api, [info.id for info in ds_image_infos], dataset_id=dataset_id
1776
1663
  )
1777
- meta_changed = False
1778
- anns = []
1779
- for result in results:
1780
- ann = Annotation.from_json(result["annotation"], self.model_meta)
1781
- output_project_meta = output_project_metas_dict.get(output_project_id, None)
1782
- if output_project_meta is None:
1783
- output_project_meta = ProjectMeta.from_json(
1784
- api.project.get_meta(output_project_id)
1785
- )
1786
- output_project_meta, ann, c = update_meta_and_ann(output_project_meta, ann)
1787
- output_project_metas_dict[output_project_id] = output_project_meta
1788
- meta_changed = meta_changed or c
1789
- anns.append(ann)
1790
- if meta_changed:
1791
- api.project.update_meta(output_project_id, output_project_meta)
1792
-
1793
- # upload in batches to update progress with each batch
1794
- # api.annotation.upload_anns() uploads in same batches anyways
1795
- for batch in batched(list(zip(anns, results, image_infos))):
1796
- batch_anns, batch_results, batch_image_infos = zip(*batch)
1797
- api.annotation.upload_anns(
1798
- img_ids=[info.id for info in batch_image_infos],
1799
- anns=batch_anns,
1800
- )
1801
- if async_inference_request_uuid is not None:
1802
- sly_progress.iters_done(len(batch_results))
1803
- inference_request["pending_results"].extend(
1804
- [{**result, "annotation": None, "data": None} for result in batch_results]
1805
- )
1806
1664
 
1807
- def upload_results_to_source_or_other(results: List[Dict]):
1808
- if len(results) == 0:
1809
- return
1810
- dataset_id = results[0]["dataset_id"]
1811
- dataset_info: DatasetInfo = dataset_infos_dict[dataset_id]
1812
- project_id = dataset_info.project_id
1813
- if project_id == output_project_id:
1814
- _upload_results_to_source(results)
1815
- else:
1816
- _upload_results_to_other(results)
1817
-
1818
- if output_project_id is None:
1819
- upload_f = _add_results_to_request
1820
- else:
1821
- upload_f = upload_results_to_source_or_other
1822
-
1823
- def _upload_loop(q: Queue, stop_event: threading.Event, api: Api, upload_f: Callable):
1824
- try:
1825
- while True:
1826
- items = []
1827
- while not q.empty():
1828
- items.append(q.get_nowait())
1829
- if len(items) > 0:
1830
- ds_batches = {}
1831
- for batch in items:
1832
- if len(batch) == 0:
1833
- continue
1834
- for each in batch:
1835
- ds_batches.setdefault(each["dataset_id"], []).append(each)
1836
- for _, joined_batch in ds_batches.items():
1837
- upload_f(joined_batch)
1838
- continue
1839
- if stop_event.is_set():
1840
- self._on_inference_end(None, async_inference_request_uuid)
1841
- return
1842
- time.sleep(1)
1843
- except Exception as e:
1844
- api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
1845
- raise
1846
-
1847
- upload_queue = Queue()
1848
- stop_upload_event = threading.Event()
1849
- upload_thread = threading.Thread(
1850
- target=_upload_loop,
1851
- args=[upload_queue, stop_upload_event, api, upload_f],
1852
- daemon=True,
1665
+ _upload_predictions = partial(
1666
+ self.upload_predictions,
1667
+ api=api,
1668
+ upload_mode=upload_mode,
1669
+ context=inference_request.context,
1670
+ dst_dataset_id=output_dataset_id,
1671
+ dst_project_id=output_project_id,
1672
+ progress_cb=inference_request.done,
1673
+ iou_merge_threshold=iou_merge_threshold,
1674
+ inference_request=inference_request,
1853
1675
  )
1854
- upload_thread.start()
1855
1676
 
1856
- settings = self._get_inference_settings(state)
1857
- logger.debug(f"Inference settings:", extra=settings)
1677
+ _add_results_to_request = partial(
1678
+ self.add_results_to_request, inference_request=inference_request
1679
+ )
1858
1680
 
1859
- results = []
1860
- stop = False
1861
- try:
1862
- for image_ids_batch in batched(images_ids, batch_size=batch_size):
1863
- if stop:
1864
- break
1865
- if (
1866
- async_inference_request_uuid is not None
1867
- and inference_request["cancel_inference"] is True
1868
- ):
1681
+ if upload_mode is None:
1682
+ upload_f = _add_results_to_request
1683
+ else:
1684
+ upload_f = _upload_predictions
1685
+
1686
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(image_ids))
1687
+ with Uploader(upload_f, logger=logger) as uploader:
1688
+ for image_ids_batch in batched(image_ids, batch_size=batch_size):
1689
+ if uploader.has_exception():
1690
+ exception = uploader.exception()
1691
+ raise RuntimeError(f"Error in upload loop: {exception}") from exception
1692
+ if inference_request.is_stopped():
1869
1693
  logger.debug(
1870
1694
  f"Cancelling inference project...",
1871
- extra={"inference_request_uuid": async_inference_request_uuid},
1695
+ extra={"inference_request_uuid": inference_request.uuid},
1872
1696
  )
1873
- results = []
1874
- stop = True
1875
1697
  break
1876
1698
 
1877
1699
  images_nps = [self.cache.download_image(api, img_id) for img_id in image_ids_batch]
1878
1700
  anns, slides_data = self._inference_auto(
1879
1701
  source=images_nps,
1880
- settings=settings,
1702
+ settings=inference_settings,
1881
1703
  )
1882
- batch_results = []
1883
- for i, ann in enumerate(anns):
1884
- image_info: ImageInfo = images_infos_dict[image_ids_batch[i]]
1885
- ds_info = dataset_infos_dict[image_info.dataset_id]
1886
- meta = output_project_metas_dict.get(ds_info.project_id, None)
1887
- iou = settings.get("existing_objects_iou_thresh")
1888
- if meta is None and isinstance(iou, float) and iou > 0:
1889
- meta = ProjectMeta.from_json(api.project.get_meta(ds_info.project_id))
1890
- output_project_metas_dict[ds_info.project_id] = meta
1891
- ann = self._exclude_duplicated_predictions(
1892
- api, [ann], settings, ds_info.id, [image_info.id], meta
1893
- )[0]
1894
- batch_results.append(
1895
- {
1896
- "annotation": ann.to_json(),
1897
- "data": slides_data[i],
1898
- "image_id": image_info.id,
1899
- "image_name": image_info.name,
1900
- "dataset_id": image_info.dataset_id,
1901
- }
1704
+
1705
+ batch_predictions = []
1706
+ for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
1707
+ image_info: ImageInfo = images_infos_dict[image_id]
1708
+ dataset_info = dataset_infos_dict[image_info.dataset_id]
1709
+ prediction = Prediction(
1710
+ ann,
1711
+ model_meta=self.model_meta,
1712
+ name=image_info.name,
1713
+ image_id=image_info.id,
1714
+ dataset_id=image_info.dataset_id,
1715
+ project_id=dataset_info.project_id,
1902
1716
  )
1903
- results.extend(batch_results)
1904
- upload_queue.put(batch_results)
1905
- except Exception:
1906
- stop_upload_event.set()
1907
- upload_thread.join()
1908
- raise
1909
- if async_inference_request_uuid is not None and len(results) > 0:
1910
- inference_request["result"] = {"ann": results}
1911
- stop_upload_event.set()
1912
- upload_thread.join()
1913
- return results
1717
+ prediction.extra_data["slides_data"] = this_slides_data
1718
+ batch_predictions.append(prediction)
1914
1719
 
1915
- def _inference_project_id(
1720
+ uploader.put(batch_predictions)
1721
+
1722
+ def _inference_video_id(
1916
1723
  self,
1917
1724
  api: Api,
1918
1725
  state: dict,
1919
- project_info: ProjectInfo = None,
1920
- async_inference_request_uuid: str = None,
1726
+ inference_request: InferenceRequest,
1921
1727
  ):
1728
+ logger.debug("Inferring video_id...", extra={"state": state})
1729
+ inference_settings = self._get_inference_settings(state)
1730
+ logger.debug(f"Inference settings:", extra=inference_settings)
1731
+ batch_size = self._get_batch_size_from_state(state)
1732
+ video_id = state["videoId"]
1733
+ video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
1734
+ if video_id is None:
1735
+ raise ValueError("Video id is not provided")
1736
+ video_info = api.video.get_info_by_id(video_id)
1737
+ start_frame_index = get_value_for_keys(
1738
+ state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
1739
+ )
1740
+ if start_frame_index is None:
1741
+ start_frame_index = 0
1742
+ step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
1743
+ if step is None:
1744
+ step = 1
1745
+ end_frame_index = get_value_for_keys(
1746
+ state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
1747
+ )
1748
+ duration = state.get("duration", None)
1749
+ frames_count = get_value_for_keys(
1750
+ state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
1751
+ )
1752
+ tracking = state.get("tracker", None)
1753
+ direction = state.get("direction", "forward")
1754
+ direction = 1 if direction == "forward" else -1
1755
+
1756
+ if frames_count is not None:
1757
+ n_frames = frames_count
1758
+ elif end_frame_index is not None:
1759
+ n_frames = end_frame_index - start_frame_index
1760
+ elif duration is not None:
1761
+ fps = video_info.frames_count / video_info.duration
1762
+ n_frames = int(duration * fps)
1763
+ else:
1764
+ n_frames = video_info.frames_count
1765
+
1766
+ if tracking == "bot":
1767
+ from supervisely.nn.tracker import BoTTracker
1768
+
1769
+ tracker = BoTTracker(state)
1770
+ elif tracking == "deepsort":
1771
+ from supervisely.nn.tracker import DeepSortTracker
1772
+
1773
+ tracker = DeepSortTracker(state)
1774
+ else:
1775
+ if tracking is not None:
1776
+ logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
1777
+ tracker = None
1778
+ logger.debug(
1779
+ f"Video info:",
1780
+ extra=dict(
1781
+ w=video_info.frame_width,
1782
+ h=video_info.frame_height,
1783
+ start_frame_index=start_frame_index,
1784
+ n_frames=n_frames,
1785
+ ),
1786
+ )
1787
+
1788
+ # start downloading video in background
1789
+ self.cache.run_cache_task_manually(api, None, video_id=video_id)
1790
+
1791
+ progress_total = (n_frames + step - 1) // step
1792
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
1793
+
1794
+ tracks_data = {}
1795
+ for batch in batched(
1796
+ range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
1797
+ batch_size,
1798
+ ):
1799
+ if inference_request.is_stopped():
1800
+ logger.debug(
1801
+ f"Cancelling inference video...",
1802
+ extra={"inference_request_uuid": inference_request.uuid},
1803
+ )
1804
+ break
1805
+ logger.debug(
1806
+ f"Inferring frames {batch[0]}-{batch[-1]}:",
1807
+ )
1808
+ frames = self.cache.download_frames(api, video_info.id, batch, redownload_video=True)
1809
+ anns, slides_data = self._inference_auto(
1810
+ source=frames,
1811
+ settings=inference_settings,
1812
+ )
1813
+ predictions = [
1814
+ Prediction(
1815
+ ann,
1816
+ model_meta=self.model_meta,
1817
+ frame_index=frame_index,
1818
+ video_id=video_info.id,
1819
+ dataset_id=video_info.dataset_id,
1820
+ project_id=video_info.project_id,
1821
+ )
1822
+ for ann, frame_index in zip(anns, batch)
1823
+ ]
1824
+ for pred, this_slides_data in zip(predictions, slides_data):
1825
+ pred.extra_data["slides_data"] = this_slides_data
1826
+ batch_results = self._format_output(predictions)
1827
+ if tracker is not None:
1828
+ for frame_index, frame, ann in zip(batch, frames, anns):
1829
+ tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
1830
+ inference_request.add_results(batch_results)
1831
+ inference_request.done(len(batch_results))
1832
+ logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
1833
+ video_ann_json = None
1834
+ if tracker is not None:
1835
+ inference_request.set_stage("Postprocess...", 0, 1)
1836
+ video_ann_json = tracker.get_annotation(
1837
+ tracks_data, (video_info.frame_height, video_info.frame_width), n_frames
1838
+ ).to_json()
1839
+ inference_request.done()
1840
+ inference_request.final_result = {"video_ann": video_ann_json}
1841
+
1842
+ def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
1922
1843
  """Inference project images.
1923
1844
  If "output_project_id" in state, upload images and annotations to the output project.
1924
1845
  If "output_project_id" equal to source project id, upload annotations to the source project.
1925
1846
  If "output_project_id" is None, write annotations to inference request object.
1926
1847
  """
1927
1848
  logger.debug("Inferring project...", extra={"state": state})
1928
- if project_info is None:
1929
- project_info = api.project.get_info_by_id(state["projectId"])
1930
- dataset_ids = state.get("dataset_ids", None)
1849
+ inference_settings = self._get_inference_settings(state)
1850
+ logger.debug("Inference settings:", extra={"inference_settings": inference_settings})
1851
+ batch_size = self._get_batch_size_from_state(state)
1852
+ project_id = get_value_for_keys(state, keys=["projectId", "project_id"])
1853
+ if project_id is None:
1854
+ raise ValueError("Project id is not provided")
1855
+ project_info = api.project.get_info_by_id(project_id)
1856
+ if project_info.type != str(ProjectType.IMAGES):
1857
+ raise ValueError("Only images projects are supported.")
1858
+ upload_mode = state.get("upload_mode", None)
1859
+ iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
1860
+ if upload_mode == "iou_merge" and iou_merge_threshold is None:
1861
+ iou_merge_threshold = 0.7
1931
1862
  cache_project_on_model = state.get("cache_project_on_model", False)
1932
- batch_size = state.get("batch_size", None)
1933
- if batch_size is None:
1934
- batch_size = self.get_batch_size()
1935
1863
 
1864
+ project_info = api.project.get_info_by_id(project_id)
1865
+ inference_request.context.setdefault("project_info", {})[project_id] = project_info
1866
+ dataset_ids = state.get("dataset_ids", None)
1867
+ if dataset_ids is None:
1868
+ dataset_ids = state.get("datasetIds", None)
1936
1869
  datasets_infos = api.dataset.get_list(project_info.id, recursive=True)
1870
+ inference_request.context.setdefault("dataset_info", {}).update(
1871
+ {ds_info.id: ds_info for ds_info in datasets_infos}
1872
+ )
1937
1873
  if dataset_ids is not None:
1938
1874
  datasets_infos = [ds_info for ds_info in datasets_infos if ds_info.id in dataset_ids]
1939
1875
 
1940
- # progress
1941
- preparing_progress = {"current": 0, "total": 1}
1942
- preparing_progress["status"] = "download_info"
1943
- preparing_progress["current"] = 0
1944
- preparing_progress["total"] = len(datasets_infos)
1945
- progress_cb = None
1946
- if async_inference_request_uuid is not None:
1947
- try:
1948
- inference_request = self._inference_requests[async_inference_request_uuid]
1949
- except Exception as ex:
1950
- import traceback
1951
-
1952
- logger.error(traceback.format_exc())
1953
- raise RuntimeError(
1954
- f"async_inference_request_uuid {async_inference_request_uuid} was given, "
1955
- f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
1956
- )
1957
- sly_progress: Progress = inference_request["progress"]
1958
- sly_progress.total = sum([ds_info.items_count for ds_info in datasets_infos])
1959
-
1960
- inference_request["preparing_progress"]["total"] = len(datasets_infos)
1961
- preparing_progress = inference_request["preparing_progress"]
1962
-
1963
- if cache_project_on_model:
1964
- progress_cb = sly_progress.iters_done
1965
- preparing_progress["total"] = sly_progress.total
1966
- preparing_progress["status"] = "download_project"
1876
+ preparing_progress_total = sum([ds_info.items_count for ds_info in datasets_infos])
1877
+ inference_progress_total = preparing_progress_total
1878
+ inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, preparing_progress_total)
1967
1879
 
1968
1880
  output_project_id = state.get("output_project_id", None)
1969
- output_project_meta = None
1881
+ inference_request.context.setdefault("project_meta", {})
1970
1882
  if output_project_id is not None:
1971
- logger.debug("Merging project meta...")
1972
- output_project_meta = ProjectMeta.from_json(api.project.get_meta(output_project_id))
1973
- changed = False
1974
- for obj_class in self.model_meta.obj_classes:
1975
- if output_project_meta.obj_classes.get(obj_class.name, None) is None:
1976
- output_project_meta = output_project_meta.add_obj_class(obj_class)
1977
- changed = True
1978
- for tag_meta in self.model_meta.tag_metas:
1979
- if output_project_meta.tag_metas.get(tag_meta.name, None) is None:
1980
- output_project_meta = output_project_meta.add_tag_meta(tag_meta)
1981
- changed = True
1982
- if changed:
1983
- output_project_meta = api.project.update_meta(
1984
- output_project_id, output_project_meta
1985
- )
1883
+ if upload_mode is None:
1884
+ upload_mode = "append"
1885
+ if output_project_id is None and upload_mode == "create":
1886
+ output_project_info = api.project.create(
1887
+ project_info.workspace_id,
1888
+ name=f"Predictions from task #{self.task_id}",
1889
+ description=f"Auto created project from inference request {inference_request.uuid}",
1890
+ change_name_if_conflict=True,
1891
+ )
1892
+ output_project_id = output_project_info.id
1893
+ inference_request.context.setdefault("project_info", {})[
1894
+ output_project_id
1895
+ ] = output_project_info
1986
1896
 
1987
1897
  if cache_project_on_model:
1988
- download_to_cache(api, project_info.id, datasets_infos, progress_cb=progress_cb)
1898
+ download_to_cache(
1899
+ api, project_info.id, datasets_infos, progress_cb=inference_request.done
1900
+ )
1989
1901
 
1990
1902
  images_infos_dict = {}
1991
1903
  for dataset_info in datasets_infos:
1992
1904
  images_infos_dict[dataset_info.id] = api.image.get_list(dataset_info.id)
1993
1905
  if not cache_project_on_model:
1994
- preparing_progress["current"] += 1
1995
-
1996
- preparing_progress["status"] = "inference"
1997
- preparing_progress["current"] = 0
1906
+ inference_request.done(dataset_info.items_count)
1998
1907
 
1999
1908
  def _download_images(datasets_infos: List[DatasetInfo]):
2000
1909
  for dataset_info in datasets_infos:
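The project-inference setup above assembles its configuration from a loosely typed state dict: the project id may arrive as "projectId" or "project_id", dataset filters as "dataset_ids" or "datasetIds", and the IoU merge threshold comes from the inference settings. A hedged sketch of such a payload follows; every concrete value, and the nesting of inference settings under a "settings" key, is an illustrative assumption rather than something taken from the package.

    # Illustrative request state for project inference; key names mirror those read above.
    state = {
        "projectId": 12345,                 # "project_id" is accepted as well
        "datasetIds": [111, 222],           # optional filter; "dataset_ids" also works
        "upload_mode": "iou_merge",         # None, "append", "iou_merge" or "create"
        "output_project_id": None,          # with "create", a "Predictions from task ..." project is made
        "cache_project_on_model": True,     # pre-download the project into the model cache
        "batch_size": 16,
        "settings": {"existing_objects_iou_thresh": 0.7},  # assumption: settings nest under "settings"
    }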
@@ -2011,166 +1920,41 @@ class Inference:
2011
1920
  # start downloading in parallel
2012
1921
  threading.Thread(target=_download_images, args=[datasets_infos], daemon=True).start()
2013
1922
 
2014
- def _upload_results_to_source(results: List[Dict]):
2015
- nonlocal output_project_meta
2016
- for result in results:
2017
- image_id = result["image_id"]
2018
- ann = Annotation.from_json(result["annotation"], self.model_meta)
2019
- output_project_meta, ann, meta_changed = update_meta_and_ann(
2020
- output_project_meta, ann
2021
- )
2022
- if meta_changed:
2023
- output_project_meta = api.project.update_meta(
2024
- project_info.id, output_project_meta
2025
- )
2026
- ann = update_classes(api, ann, output_project_meta, output_project_id)
2027
- api.annotation.append_labels(image_id, ann.labels)
2028
- if async_inference_request_uuid is not None:
2029
- sly_progress.iters_done(1)
2030
- inference_request["pending_results"].append(
2031
- {
2032
- "annotation": None, # to less response size
2033
- "data": None, # to less response size
2034
- "image_id": image_id,
2035
- "image_name": result["image_name"],
2036
- "dataset_id": result["dataset_id"],
2037
- }
2038
- )
2039
-
2040
- new_dataset_id = {}
2041
-
2042
- def _get_or_create_new_dataset(output_project_id, src_dataset_id):
2043
- """Copy dataset in output project if not exists and return its id"""
2044
- if src_dataset_id in new_dataset_id:
2045
- return new_dataset_id[src_dataset_id]
2046
- dataset_info = api.dataset.get_info_by_id(src_dataset_id)
2047
- if dataset_info.parent_id is None:
2048
- output_dataset_id = api.dataset.copy(
2049
- output_project_id,
2050
- src_dataset_id,
2051
- dataset_info.name,
2052
- change_name_if_conflict=True,
2053
- ).id
2054
- else:
2055
- parent_dataset_id = _get_or_create_new_dataset(
2056
- output_project_id, dataset_info.parent_id
2057
- )
2058
- output_dataset_info = api.dataset.create(
2059
- output_project_id, dataset_info.name, parent_id=parent_dataset_id
2060
- )
2061
- api.image.copy_batch_optimized(
2062
- dataset_info.id,
2063
- images_infos_dict[dataset_info.id],
2064
- output_dataset_info.id,
2065
- with_annotations=False,
2066
- )
2067
- output_dataset_id = output_dataset_info.id
2068
- new_dataset_id[src_dataset_id] = output_dataset_id
2069
- return output_dataset_id
2070
-
2071
- def _upload_results_to_other(results: List[Dict]):
2072
- nonlocal output_project_meta
2073
- if len(results) == 0:
2074
- return
2075
- src_dataset_id = results[0]["dataset_id"]
2076
- dataset_id = _get_or_create_new_dataset(output_project_id, src_dataset_id)
2077
- image_names = [result["image_name"] for result in results]
2078
- image_infos = api.image.get_list(
2079
- dataset_id,
2080
- filters=[{"field": "name", "operator": "in", "value": image_names}],
2081
- )
2082
- meta_changed = False
2083
- anns = []
2084
- for result in results:
2085
- ann = Annotation.from_json(result["annotation"], self.model_meta)
2086
- output_project_meta, ann, c = update_meta_and_ann(output_project_meta, ann)
2087
- meta_changed = meta_changed or c
2088
- anns.append(ann)
2089
- if meta_changed:
2090
- api.project.update_meta(output_project_id, output_project_meta)
2091
-
2092
- # upload in batches to update progress with each batch
2093
- # api.annotation.upload_anns() uploads in same batches anyways
2094
- for batch in batched(list(zip(anns, results, image_infos))):
2095
- batch_anns, batch_results, batch_image_infos = zip(*batch)
2096
- api.annotation.upload_anns(
2097
- img_ids=[info.id for info in batch_image_infos],
2098
- anns=batch_anns,
2099
- )
2100
- if async_inference_request_uuid is not None:
2101
- sly_progress.iters_done(len(batch_results))
2102
- inference_request["pending_results"].extend(
2103
- [{**result, "annotation": None, "data": None} for result in batch_results]
2104
- )
2105
-
2106
- def _add_results_to_request(results: List[Dict]):
2107
- if async_inference_request_uuid is None:
2108
- return
2109
- inference_request["pending_results"].extend(results)
2110
- sly_progress.iters_done(len(results))
1923
+ _upload_predictions = partial(
1924
+ self.upload_predictions,
1925
+ api=api,
1926
+ upload_mode=upload_mode,
1927
+ context=inference_request.context,
1928
+ dst_project_id=output_project_id,
1929
+ progress_cb=inference_request.done,
1930
+ iou_merge_threshold=iou_merge_threshold,
1931
+ inference_request=inference_request,
1932
+ )
2111
1933
 
2112
- def _upload_loop(q: Queue, stop_event: threading.Event, api: Api, upload_f: Callable):
2113
- try:
2114
- while True:
2115
- items = []
2116
- while not q.empty():
2117
- items.append(q.get_nowait())
2118
- if len(items) > 0:
2119
- ds_batches = {}
2120
- for batch in items:
2121
- if len(batch) == 0:
2122
- continue
2123
- ds_batches.setdefault(batch[0].get("dataset_id"), []).extend(batch)
2124
- for _, joined_batch in ds_batches.items():
2125
- upload_f(joined_batch)
2126
- continue
2127
- if stop_event.is_set():
2128
- self._on_inference_end(None, async_inference_request_uuid)
2129
- return
2130
- time.sleep(1)
2131
- except Exception as e:
2132
- api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
2133
- raise
1934
+ _add_results_to_request = partial(
1935
+ self.add_results_to_request, inference_request=inference_request
1936
+ )
2134
1937
 
2135
- if output_project_id is None:
1938
+ if upload_mode is None:
2136
1939
  upload_f = _add_results_to_request
2137
- elif output_project_id != project_info.id:
2138
- upload_f = _upload_results_to_other
2139
1940
  else:
2140
- upload_f = _upload_results_to_source
2141
-
2142
- upload_queue = Queue()
2143
- stop_upload_event = threading.Event()
2144
- upload_thread = threading.Thread(
2145
- target=_upload_loop,
2146
- args=[upload_queue, stop_upload_event, api, upload_f],
2147
- daemon=True,
2148
- )
2149
- upload_thread.start()
1941
+ upload_f = _upload_predictions
2150
1942
 
2151
- settings = self._get_inference_settings(state)
2152
- logger.debug(f"Inference settings:", extra=settings)
2153
- results = []
2154
- data_to_return = {}
2155
- stop = False
2156
- try:
1943
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, inference_progress_total)
1944
+ with Uploader(upload_f, logger=logger) as uploader:
2157
1945
  for dataset_info in datasets_infos:
2158
- if stop:
2159
- break
2160
1946
  for images_infos_batch in batched(
2161
1947
  images_infos_dict[dataset_info.id], batch_size=batch_size
2162
1948
  ):
2163
- if (
2164
- async_inference_request_uuid is not None
2165
- and inference_request["cancel_inference"] is True
2166
- ):
1949
+ if inference_request.is_stopped():
2167
1950
  logger.debug(
2168
1951
  f"Cancelling inference project...",
2169
- extra={"inference_request_uuid": async_inference_request_uuid},
1952
+ extra={"inference_request_uuid": inference_request.uuid},
2170
1953
  )
2171
- results = []
2172
- stop = True
2173
- break
1954
+ return
1955
+ if uploader.has_exception():
1956
+ exception = uploader.exception
1957
+ raise RuntimeError(f"Error in upload loop: {exception}") from exception
2174
1958
  if cache_project_on_model:
2175
1959
  images_paths, _ = zip(
2176
1960
  *read_from_cached_project(
@@ -2189,52 +1973,36 @@ class Inference:
2189
1973
  )
2190
1974
  anns, slides_data = self._inference_auto(
2191
1975
  source=images_nps,
2192
- settings=settings,
1976
+ settings=inference_settings,
2193
1977
  )
2194
- iou = settings.get("existing_objects_iou_thresh")
2195
- if output_project_meta is None and isinstance(iou, float) and iou > 0:
2196
- output_project_meta = ProjectMeta.from_json(
2197
- api.project.get_meta(project_info.id)
1978
+ predictions = [
1979
+ Prediction(
1980
+ ann,
1981
+ model_meta=self.model_meta,
1982
+ image_id=image_info.id,
1983
+ name=image_info.name,
1984
+ dataset_id=dataset_info.id,
1985
+ project_id=dataset_info.project_id,
1986
+ image_name=image_info.name,
2198
1987
  )
2199
- anns = self._exclude_duplicated_predictions(
2200
- api,
2201
- anns,
2202
- settings,
2203
- dataset_info.id,
2204
- [ii.id for ii in images_infos_batch],
2205
- output_project_meta,
2206
- )
2207
- batch_results = []
2208
- for i, ann in enumerate(anns):
2209
- batch_results.append(
2210
- {
2211
- "annotation": ann.to_json(),
2212
- "data": slides_data[i],
2213
- "image_id": images_infos_batch[i].id,
2214
- "image_name": images_infos_batch[i].name,
2215
- "dataset_id": dataset_info.id,
2216
- }
2217
- )
2218
- results.extend(batch_results)
2219
- upload_queue.put(batch_results)
2220
- except Exception:
2221
- stop_upload_event.set()
2222
- upload_thread.join()
2223
- raise
2224
- if async_inference_request_uuid is not None and len(results) > 0:
2225
- inference_request["result"] = {"ann": results}
2226
- stop_upload_event.set()
2227
- upload_thread.join()
2228
- return results
1988
+ for ann, image_info in zip(anns, images_infos_batch)
1989
+ ]
1990
+ for pred, this_slides_data in zip(predictions, slides_data):
1991
+ pred.extra_data["slides_data"] = this_slides_data
1992
+
1993
+ uploader.put(predictions)
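The rewritten loop above delegates result uploading to the new Uploader helper instead of a hand-rolled queue and worker thread. A minimal sketch of that pattern, limited to the calls visible in this diff; the import paths and the placeholder arguments are assumptions.

    from supervisely.sly_logger import logger                   # assumption: standard sly logger
    from supervisely.nn.inference.uploader import Uploader      # assumption about the import path

    def run_with_uploader(upload_f, prediction_batches, inference_request):
        # Mirrors the loop above: queue batches for background upload, watch for errors.
        with Uploader(upload_f, logger=logger) as uploader:
            for batch in prediction_batches:               # placeholder iterable of Prediction lists
                if inference_request.is_stopped():          # cooperative cancellation
                    return
                if uploader.has_exception():                # surface background upload errors
                    raise RuntimeError("Error in upload loop") from uploader.exception
                uploader.put(batch)                         # hand the batch to the background worker
        # assumption based on usage above: leaving the block flushes pending items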
2229
1994
 
2230
1995
  def _run_speedtest(
2231
1996
  self,
2232
1997
  api: Api,
2233
1998
  state: dict,
2234
- async_inference_request_uuid: str = None,
1999
+ inference_request: InferenceRequest,
2235
2000
  ):
2236
2001
  """Run speedtest on project images."""
2237
2002
  logger.debug("Running speedtest...", extra={"state": state})
2003
+ settings = self._get_inference_settings(state)
2004
+ logger.debug(f"Inference settings:", extra=settings)
2005
+
2238
2006
  project_id = state["projectId"]
2239
2007
  batch_size = state["batch_size"]
2240
2008
  num_iterations = state["num_iterations"]
@@ -2252,49 +2020,22 @@ class Inference:
2252
2020
  if dataset_id in datasets_infos_dict
2253
2021
  ]
2254
2022
 
2255
- # progress
2256
- preparing_progress = {"current": 0, "total": 1}
2257
- if async_inference_request_uuid is not None:
2258
- try:
2259
- inference_request = self._inference_requests[async_inference_request_uuid]
2260
- except Exception as ex:
2261
- import traceback
2262
-
2263
- logger.error(traceback.format_exc())
2264
- raise RuntimeError(
2265
- f"async_inference_request_uuid {async_inference_request_uuid} was given, "
2266
- f"but there is no such uuid in 'self._inference_requests' ({len(self._inference_requests)} items)"
2267
- )
2268
- sly_progress: Progress = inference_request["progress"]
2269
- sly_progress.total = num_iterations
2270
- sly_progress.current = 0
2271
-
2272
- preparing_progress = inference_request["preparing_progress"]
2023
+ preparing_progress_total = len(datasets_infos)
2024
+ if cache_project_on_model:
2025
+ preparing_progress_total += sum(
2026
+ dataset_info.items_count for dataset_info in datasets_infos
2027
+ )
2028
+ inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, preparing_progress_total)
2273
2029
 
2274
- preparing_progress["current"] = 0
2275
- preparing_progress["total"] = len(datasets_infos)
2276
- preparing_progress["status"] = "download_info"
2277
2030
  images_infos_dict = {}
2278
2031
  for dataset_info in datasets_infos:
2279
2032
  images_infos_dict[dataset_info.id] = api.image.get_list(dataset_info.id)
2280
- if not cache_project_on_model:
2281
- preparing_progress["current"] += 1
2033
+ inference_request.done()
2282
2034
 
2283
2035
  if cache_project_on_model:
2036
+ download_to_cache(api, project_id, datasets_infos, progress_cb=inference_request.done)
2284
2037
 
2285
- def _progress_cb(count: int = 1):
2286
- preparing_progress["current"] += count
2287
-
2288
- preparing_progress["current"] = 0
2289
- preparing_progress["total"] = sum(
2290
- dataset_info.items_count for dataset_info in datasets_infos
2291
- )
2292
- preparing_progress["status"] = "download_project"
2293
- download_to_cache(api, project_id, datasets_infos, progress_cb=_progress_cb)
2294
-
2295
- preparing_progress["status"] = "warmup"
2296
- preparing_progress["current"] = 0
2297
- preparing_progress["total"] = num_warmup
2038
+ inference_request.set_stage("warmup", 0, num_warmup)
2298
2039
 
2299
2040
  images_infos: List[ImageInfo] = [
2300
2041
  image_info for infos in images_infos_dict.values() for image_info in infos
@@ -2313,44 +2054,9 @@ class Inference:
2313
2054
  # start downloading in parallel
2314
2055
  threading.Thread(target=_download_images, daemon=True).start()
2315
2056
 
2316
- def _add_results_to_request(results: List[Dict]):
2317
- if async_inference_request_uuid is None:
2318
- return
2319
- inference_request["pending_results"].append(results)
2320
- sly_progress.iters_done(1)
2321
-
2322
- def _upload_loop(q: Queue, stop_event: threading.Event, api: Api, upload_f: Callable):
2323
- try:
2324
- while True:
2325
- items = []
2326
- while not q.empty():
2327
- items.append(q.get_nowait())
2328
- if len(items) > 0:
2329
- for batch in items:
2330
- upload_f(batch)
2331
- continue
2332
- if stop_event.is_set():
2333
- self._on_inference_end(None, async_inference_request_uuid)
2334
- return
2335
- time.sleep(1)
2336
- except Exception as e:
2337
- api.logger.error("Error in upload loop: %s", str(e), exc_info=True)
2338
- raise
2339
-
2340
- upload_f = _add_results_to_request
2341
-
2342
- upload_queue = Queue()
2343
- stop_upload_event = threading.Event()
2344
- threading.Thread(
2345
- target=_upload_loop,
2346
- args=[upload_queue, stop_upload_event, api, upload_f],
2347
- daemon=True,
2348
- ).start()
2349
-
2350
- settings = self._get_inference_settings(state)
2351
- logger.debug(f"Inference settings:", extra=settings)
2352
- results = []
2353
- stop = False
2057
+ def upload_f(benchmarks: List):
2058
+ inference_request.add_results(benchmarks)
2059
+ inference_request.done(len(benchmarks))
2354
2060
 
2355
2061
  def image_batch_generator(batch_size):
2356
2062
  logger.debug(
@@ -2366,23 +2072,20 @@ class Inference:
2366
2072
  batch = []
2367
2073
 
2368
2074
  batch_generator = image_batch_generator(batch_size)
2369
- try:
2075
+
2076
+ with Uploader(upload_f=upload_f, logger=logger) as uploader:
2370
2077
  for i in range(num_iterations + num_warmup):
2371
- if stop:
2372
- break
2373
- if (
2374
- async_inference_request_uuid is not None
2375
- and inference_request["cancel_inference"] is True
2376
- ):
2078
+ if inference_request.is_stopped():
2377
2079
  logger.debug(
2378
2080
  f"Cancelling inference project...",
2379
- extra={"inference_request_uuid": async_inference_request_uuid},
2081
+ extra={"inference_request_uuid": inference_request.uuid},
2380
2082
  )
2381
- results = []
2382
- stop = True
2383
- break
2083
+ return
2084
+ if uploader.has_exception():
2085
+ exception = uploader.exception
2086
+ raise RuntimeError(f"Error in upload loop: {exception}") from exception
2384
2087
  if i == num_warmup:
2385
- preparing_progress["status"] = "inference"
2088
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, num_iterations)
2386
2089
 
2387
2090
  images_infos_batch: List[ImageInfo] = next(batch_generator)
2388
2091
 
@@ -2429,35 +2132,9 @@ class Inference:
2429
2132
  )
2430
2133
  # Collect results if warmup is done
2431
2134
  if i >= num_warmup:
2432
- results.append(benchmark)
2433
- upload_queue.put(benchmark)
2135
+ uploader.put([benchmark])
2434
2136
  else:
2435
- preparing_progress["current"] += 1
2436
- except Exception:
2437
- stop_upload_event.set()
2438
- raise
2439
- if async_inference_request_uuid is not None and len(results) > 0:
2440
- inference_request["result"] = results
2441
- stop_upload_event.set()
2442
- return results
2443
-
2444
- def _on_inference_start(self, inference_request_uuid):
2445
- inference_request = {
2446
- "progress": Progress("Inferring model...", total_cnt=1),
2447
- "is_inferring": True,
2448
- "cancel_inference": False,
2449
- "result": None,
2450
- "pending_results": [],
2451
- "preparing_progress": {"current": 0, "total": 1},
2452
- "exception": None,
2453
- }
2454
- self._inference_requests[inference_request_uuid] = inference_request
2455
-
2456
- def _on_inference_end(self, future, inference_request_uuid):
2457
- logger.debug("callback: on_inference_end()")
2458
- inference_request = self._inference_requests.get(inference_request_uuid)
2459
- if inference_request is not None:
2460
- inference_request["is_inferring"] = False
2137
+ inference_request.done()
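Both project inference and the speedtest above report progress through the InferenceRequest stage API rather than the old preparing_progress dict. A condensed sketch of the calls used in this file; the import path, the totals, and the wrapper function are assumptions.

    from supervisely.nn.inference.inference_request import InferenceRequest  # assumed import path

    def report_progress(inference_request, n_items, results):
        # Stage names and method signatures are the ones used above; totals are illustrative.
        inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, n_items)
        inference_request.done(n_items)                        # preparation finished
        inference_request.set_stage("warmup", 0, 3)            # speedtest-only warmup stage
        inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, n_items)
        inference_request.add_results(results)                 # expose results to pop_inference_results
        inference_request.done(len(results))                   # advance the inference progress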
2461
2138
 
2462
2139
  def _check_serve_before_call(self, func):
2463
2140
  @wraps(func)
@@ -2481,6 +2158,24 @@ class Inference:
2481
2158
  def is_model_deployed(self):
2482
2159
  return self._model_served
2483
2160
 
2161
+ def _on_inference_start(self, inference_request_uuid):
2162
+ inference_request = {
2163
+ "progress": Progress("Inferring model...", total_cnt=1),
2164
+ "is_inferring": True,
2165
+ "cancel_inference": False,
2166
+ "result": None,
2167
+ "pending_results": [],
2168
+ "preparing_progress": {"current": 0, "total": 1},
2169
+ "exception": None,
2170
+ }
2171
+ self._inference_requests[inference_request_uuid] = inference_request
2172
+
2173
+ def _on_inference_end(self, future, inference_request_uuid):
2174
+ logger.debug("callback: on_inference_end()")
2175
+ inference_request = self._inference_requests.get(inference_request_uuid)
2176
+ if inference_request is not None:
2177
+ inference_request["is_inferring"] = False
2178
+
2484
2179
  def schedule_task(self, func, *args, **kwargs):
2485
2180
  inference_request_uuid = kwargs.get("inference_request_uuid", None)
2486
2181
  if inference_request_uuid is None:
@@ -2523,6 +2218,228 @@ class Inference:
2523
2218
  self.gui._success_label.hide()
2524
2219
  raise e
2525
2220
 
2221
+ def validate_inference_state(self, state: Union[Dict, str], log_error=True):
2222
+ try:
2223
+ if isinstance(state, str):
2224
+ try:
2225
+ state = json.loads(state)
2226
+ except (json.decoder.JSONDecodeError, TypeError) as e:
2227
+ raise HTTPException(
2228
+ status_code=status.HTTP_400_BAD_REQUEST,
2229
+ detail=f"Cannot decode settings: {e}",
2230
+ )
2231
+ if not isinstance(state, dict):
2232
+ raise HTTPException(
2233
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Settings is not json object"
2234
+ )
2235
+ batch_size = state.get("batch_size", None)
2236
+ if batch_size is None:
2237
+ batch_size = self.get_batch_size()
2238
+ if self.max_batch_size is not None and batch_size > self.max_batch_size:
2239
+ raise HTTPException(
2240
+ status_code=status.HTTP_400_BAD_REQUEST,
2241
+ detail=f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2242
+ )
2243
+ except Exception as e:
2244
+ if log_error:
2245
+ logger.error(f"Error validating request state: {e}", exc_info=True)
2246
+ raise
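validate_inference_state above accepts either a JSON string or an already-parsed dict and rejects oversized batches before any work is scheduled. A hedged sketch of the behavior it enforces; the served instance m and its max_batch_size are hypothetical.

    def check_payloads(m):
        # Hypothetical served instance `m` with max_batch_size == 16.
        m.validate_inference_state('{"batch_size": 8}')    # ok: JSON strings are decoded
        m.validate_inference_state({"batch_size": 8})      # ok: dicts pass through
        try:
            m.validate_inference_state({"batch_size": 64})  # larger than max_batch_size
        except Exception:
            pass  # HTTPException(400, "Batch size should be less than or equal to 16 for this model.")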
2247
+
2248
+ def upload_predictions(
2249
+ self,
2250
+ predictions: List[Prediction],
2251
+ api: Api,
2252
+ upload_mode: str,
2253
+ context: Dict = None,
2254
+ dst_dataset_id: int = None,
2255
+ dst_project_id: int = None,
2256
+ progress_cb=None,
2257
+ iou_merge_threshold: float = None,
2258
+ inference_request: InferenceRequest = None,
2259
+ ):
2260
+ ds_predictions: Dict[int, List[Prediction]] = defaultdict(list)
2261
+ for prediction in predictions:
2262
+ ds_predictions[prediction.dataset_id].append(prediction)
2263
+
2264
+ def _new_name(image_info: ImageInfo):
2265
+ name = Path(image_info.name)
2266
+ stem = name.stem
2267
+ parent = name.parent
2268
+ suffix = name.suffix
2269
+ return str(parent / f"{stem}(dataset_id:{image_info.dataset_id}){suffix}")
2270
+
2271
+ def _get_or_create_dataset(src_dataset_id, dst_project_id):
2272
+ if src_dataset_id is None:
2273
+ return None
2274
+ created_dataset_id = context.setdefault("created_dataset", {}).get(src_dataset_id, None)
2275
+ if created_dataset_id is not None:
2276
+ return created_dataset_id
2277
+ src_dataset_info: DatasetInfo = context.setdefault("dataset_info", {}).get(
2278
+ src_dataset_id
2279
+ )
2280
+ if src_dataset_info is None:
2281
+ src_dataset_info = api.dataset.get_info_by_id(src_dataset_id)
2282
+ context["dataset_info"][src_dataset_id] = src_dataset_info
2283
+ src_parent_id = src_dataset_info.parent_id
2284
+ dst_parent_id = _get_or_create_dataset(src_parent_id, dst_project_id)
2285
+ created_dataset = api.dataset.create(
2286
+ dst_project_id,
2287
+ src_dataset_info.name,
2288
+ description=f"Auto created dataset from inference request {inference_request.uuid if inference_request is not None else ''}",
2289
+ change_name_if_conflict=True,
2290
+ parent_id=dst_parent_id,
2291
+ )
2292
+ context["dataset_info"][created_dataset.id] = created_dataset
2293
+ context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
2294
+ return created_dataset.id
2295
+
2296
+ created_names = []
2297
+ if context is None:
2298
+ context = {}
2299
+ for dataset_id, preds in ds_predictions.items():
2300
+ if dst_project_id is not None:
2301
+ # upload to the destination project
2302
+ dst_dataset_id = _get_or_create_dataset(
2303
+ src_dataset_id=dataset_id, dst_project_id=dst_project_id
2304
+ )
2305
+ if dst_dataset_id is not None:
2306
+ # upload to the destination dataset
2307
+ dataset_info = context.setdefault("dataset_info", {}).get(dst_dataset_id, None)
2308
+ if dataset_info is None:
2309
+ dataset_info = api.dataset.get_info_by_id(dst_dataset_id)
2310
+ context["dataset_info"][dst_dataset_id] = dataset_info
2311
+ project_id = dataset_info.project_id
2312
+ project_meta = context.setdefault("project_meta", {}).get(project_id, None)
2313
+ if project_meta is None:
2314
+ project_meta = ProjectMeta.from_json(api.project.get_meta(project_id))
2315
+ context["project_meta"][project_id] = project_meta
2316
+
2317
+ meta_changed = False
2318
+ for pred in preds:
2319
+ ann = pred.annotation
2320
+ project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
2321
+ meta_changed = meta_changed or meta_changed_
2322
+ pred.annotation = ann
2323
+ prediction.model_meta = project_meta
2324
+
2325
+ if meta_changed:
2326
+ project_meta = api.project.update_meta(project_id, project_meta)
2327
+ context["project_meta"][project_id] = project_meta
2328
+
2329
+ anns = _exclude_duplicated_predictions(
2330
+ api,
2331
+ [pred.annotation for pred in preds],
2332
+ dataset_id,
2333
+ [pred.image_id for pred in preds],
2334
+ iou=iou_merge_threshold,
2335
+ meta=project_meta,
2336
+ )
2337
+ for pred, ann in zip(preds, anns):
2338
+ pred.annotation = ann
2339
+
2340
+ context.setdefault("image_info", {})
2341
+ missing = [
2342
+ pred.image_id for pred in preds if pred.image_id not in context["image_info"]
2343
+ ]
2344
+ if missing:
2345
+ context["image_info"].update(
2346
+ {
2347
+ image_info.id: image_info
2348
+ for image_info in api.image.get_info_by_id_batch(missing)
2349
+ }
2350
+ )
2351
+ image_infos: List[ImageInfo] = [
2352
+ context["image_info"][pred.image_id] for pred in preds
2353
+ ]
2354
+ dst_names = [
2355
+ _new_name(image_info) if image_info.name in created_names else image_info.name
2356
+ for image_info in image_infos
2357
+ ]
2358
+ dst_image_infos = api.image.copy_batch_optimized(
2359
+ dataset_id,
2360
+ image_infos,
2361
+ dst_dataset_id,
2362
+ dst_names=dst_names,
2363
+ with_annotations=False,
2364
+ save_source_date=False,
2365
+ )
2366
+ created_names.extend([image_info.name for image_info in dst_image_infos])
2367
+ api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
2368
+ else:
2369
+ # upload to the source dataset
2370
+ ds_info = context.setdefault("dataset_info", {}).get(dataset_id, None)
2371
+ if ds_info is None:
2372
+ ds_info = api.dataset.get_info_by_id(dataset_id)
2373
+ context["dataset_info"][dataset_id] = ds_info
2374
+ project_id = ds_info.project_id
2375
+
2376
+ project_meta = context.setdefault("project_meta", {}).get(project_id, None)
2377
+ if project_meta is None:
2378
+ project_meta = ProjectMeta.from_json(api.project.get_meta(project_id))
2379
+ context["project_meta"][project_id] = project_meta
2380
+
2381
+ meta_changed = False
2382
+ for pred in preds:
2383
+ ann = pred.annotation
2384
+ project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
2385
+ meta_changed = meta_changed or meta_changed_
2386
+ pred.annotation = ann
2387
+ prediction.model_meta = project_meta
2388
+
2389
+ if meta_changed:
2390
+ project_meta = api.project.update_meta(project_id, project_meta)
2391
+ context["project_meta"][project_id] = project_meta
2392
+
2393
+ anns = _exclude_duplicated_predictions(
2394
+ api,
2395
+ [pred.annotation for pred in preds],
2396
+ dataset_id,
2397
+ [pred.image_id for pred in preds],
2398
+ iou=iou_merge_threshold,
2399
+ meta=project_meta,
2400
+ )
2401
+ for pred, ann in zip(preds, anns):
2402
+ pred.annotation = ann
2403
+
2404
+ if upload_mode in ["iou_merge", "append"]:
2405
+ context.setdefault("annotation", {})
2406
+ missing = []
2407
+ for pred in preds:
2408
+ if pred.image_id not in context["annotation"]:
2409
+ missing.append(pred.image_id)
2410
+ for image_id, ann_info in zip(
2411
+ missing, api.annotation.download_batch(dataset_id, missing)
2412
+ ):
2413
+ context["annotation"][image_id] = Annotation.from_json(
2414
+ ann_info.annotation, project_meta
2415
+ )
2416
+ for pred in preds:
2417
+ pred.annotation = context["annotation"][pred.image_id].merge(
2418
+ pred.annotation
2419
+ )
2420
+
2421
+ api.annotation.upload_anns(
2422
+ [pred.image_id for pred in preds],
2423
+ [pred.annotation for pred in preds],
2424
+ )
2425
+
2426
+ if progress_cb is not None:
2427
+ progress_cb(len(preds))
2428
+
2429
+ if inference_request is not None:
2430
+ results = self._format_output(predictions)
2431
+ for result in results:
2432
+ result["annotation"] = None
2433
+ result["data"] = None
2434
+ inference_request.add_results(results)
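upload_predictions above memoizes every lookup in the mutable context dict so that consecutive batches reuse project metas, dataset infos, and downloaded annotations instead of re-querying the API. The keys it maintains, shown as an illustrative skeleton with empty values.

    # Shape of the `context` cache populated above; the mappings start empty and fill lazily.
    context = {
        "project_info": {},      # project_id -> ProjectInfo
        "project_meta": {},      # project_id -> ProjectMeta, refreshed whenever the meta changes
        "dataset_info": {},      # dataset_id -> DatasetInfo
        "created_dataset": {},   # src dataset_id -> dataset id created in the destination project
        "image_info": {},        # image_id -> ImageInfo for destination copies
        "annotation": {},        # image_id -> Annotation, used by "append"/"iou_merge" merging
    }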
2435
+
2436
+ def add_results_to_request(
2437
+ self, predictions: List[Prediction], inference_request: InferenceRequest
2438
+ ):
2439
+ results = self._format_output(predictions)
2440
+ inference_request.add_results(results)
2441
+ inference_request.done(len(results))
2442
+
2526
2443
  def serve(self):
2527
2444
  if not self._use_gui and not self._is_local_deploy:
2528
2445
  Progress("Deploying model ...", 1)
@@ -2583,28 +2500,46 @@ class Inference:
2583
2500
  server = self._app.get_server()
2584
2501
  self._app.set_ready_check_function(self.is_model_deployed)
2585
2502
 
2586
- @call_on_autostart()
2587
- def autostart_func():
2588
- gpu_count = get_gpu_count()
2589
- if gpu_count > 1:
2590
- # run autostart after 5 min
2591
- def delayed_autostart():
2592
- logger.debug("Found more than one GPU, autostart will be delayed.")
2593
- time.sleep(self._autostart_delay_time)
2594
- if not self._model_served:
2595
- logger.debug("Deploying the model via autostart...")
2596
- self.gui.deploy_with_current_params()
2597
-
2598
- self._executor.submit(delayed_autostart)
2599
- else:
2600
- # run autostart immediately
2601
- self.gui.deploy_with_current_params()
2503
+ if self.api is not None:
2504
+
2505
+ @call_on_autostart()
2506
+ def autostart_func():
2507
+ gpu_count = get_gpu_count()
2508
+ if gpu_count > 1:
2509
+ # run autostart after 5 min
2510
+ def delayed_autostart():
2511
+ logger.debug("Found more than one GPU, autostart will be delayed.")
2512
+ time.sleep(self._autostart_delay_time)
2513
+ if not self._model_served:
2514
+ logger.debug("Deploying the model via autostart...")
2515
+ self.gui.deploy_with_current_params()
2516
+
2517
+ self._executor.submit(delayed_autostart)
2518
+ else:
2519
+ # run autostart immediately
2520
+ self.gui.deploy_with_current_params()
2602
2521
 
2603
2522
  if not self._use_gui:
2604
2523
  Progress("Model deployed", 1).iter_done_report()
2605
2524
  else:
2606
2525
  autostart_func()
2607
2526
 
2527
+ @server.exception_handler(HTTPException)
2528
+ def http_exception_handler(request: Request, exc: HTTPException):
2529
+ response_content = {
2530
+ "detail": exc.detail,
2531
+ "success": False,
2532
+ }
2533
+ if isinstance(exc.detail, dict):
2534
+ if "message" in exc.detail:
2535
+ response_content["message"] = exc.detail["message"]
2536
+ if "success" in exc.detail:
2537
+ response_content["success"] = exc.detail["success"]
2538
+ elif isinstance(exc.detail, str):
2539
+ response_content["message"] = exc.detail
2540
+
2541
+ return JSONResponse(status_code=exc.status_code, content=response_content)
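The new exception handler above normalizes error responses so that clients always receive a JSON body with detail, success, and, when available, message. Illustrative bodies for the two detail shapes; the messages are placeholders and the status code is whatever the endpoint raised.

    # detail passed as a string, e.g. HTTPException(400, detail="Cannot decode settings: ...")
    {"detail": "Cannot decode settings: ...", "success": False, "message": "Cannot decode settings: ..."}

    # detail passed as a dict, e.g. HTTPException(400, detail={"message": "...", "success": False})
    {"detail": {"message": "...", "success": False}, "success": False, "message": "..."}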
2542
+
2608
2543
  self.cache.add_cache_endpoint(server)
2609
2544
  self.cache.add_cache_files_endpoint(server)
2610
2545
 
@@ -2617,311 +2552,353 @@ class Inference:
2617
2552
  def get_custom_inference_settings():
2618
2553
  return {"settings": self.custom_inference_settings}
2619
2554
 
2555
+ @server.post("/get_model_meta")
2620
2556
  @server.post("/get_output_classes_and_tags")
2621
2557
  def get_output_classes_and_tags():
2622
2558
  return self.model_meta.to_json()
2623
2559
 
2624
2560
  @server.post("/inference_image_id")
2625
2561
  def inference_image_id(request: Request):
2626
- logger.debug(f"'inference_image_id' request in json format:{request.state.state}")
2627
- return self._inference_image_id(request.state.api, request.state.state)
2562
+ state = request.state.state
2563
+ logger.debug("Received a request to '/inference_image_id'", extra={"state": state})
2564
+ self.validate_inference_state(state)
2565
+ api = self.api_from_request(request)
2566
+ return self.inference_requests_manager.run(self._inference_image_ids, api, state)[0]
2567
+
2568
+ @server.post("/inference_image_id_async")
2569
+ def inference_image_id_async(request: Request):
2570
+ state = request.state.state
2571
+ logger.debug(
2572
+ "Received a request to 'inference_image_id_async'",
2573
+ extra={"state": state},
2574
+ )
2575
+ self.validate_inference_state(state)
2576
+ api = self.api_from_request(request)
2577
+ inference_request, _ = self.inference_requests_manager.schedule_task(
2578
+ self._inference_image_ids,
2579
+ api,
2580
+ state,
2581
+ )
2582
+ return {
2583
+ "message": "Scheduled inference task.",
2584
+ "inference_request_uuid": inference_request.uuid,
2585
+ }
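A client-side sketch of scheduling inference through the async endpoint above. Wrapping the payload in a "state" key and the image_id field name are assumptions about how the serving middleware fills request.state.state; the URL and ids are placeholders.

    import requests

    resp = requests.post(
        "http://localhost:8000/inference_image_id_async",    # hypothetical serving app address
        json={"state": {"image_id": 123456, "settings": {}}},
    )
    request_uuid = resp.json()["inference_request_uuid"]      # used later for polling and cleanup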
2586
+
2587
+ @server.post("/inference_image")
2588
+ def inference_image(
2589
+ files: List[UploadFile], settings: str = Form("{}"), state: str = Form("{}")
2590
+ ):
2591
+ if state == "{}" or not state:
2592
+ state = settings
2593
+ state = str(state)
2594
+ logger.debug("Received a request to 'inference_image'", extra={"state": state})
2595
+ self.validate_inference_state(state)
2596
+ state = json.loads(state)
2597
+ if len(files) != 1:
2598
+ raise HTTPException(
2599
+ status_code=status.HTTP_400_BAD_REQUEST,
2600
+ detail=f"Only one file expected but got {len(files)}",
2601
+ )
2602
+ try:
2603
+ file = files[0]
2604
+ inference_request = self.inference_requests_manager.create()
2605
+ inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, file.size)
2606
+
2607
+ img_bytes = b""
2608
+ while buf := file.read(64 * 1024 * 1024):
2609
+ img_bytes += buf
2610
+ inference_request.done(len(buf))
2611
+
2612
+ image = sly_image.read_bytes(img_bytes)
2613
+ inference_request, future = self.inference_requests_manager.schedule_task(
2614
+ self._inference_images, [image], state, inference_request=inference_request
2615
+ )
2616
+ future.result()
2617
+ return inference_request.pop_pending_results()[0]
2618
+ except sly_image.UnsupportedImageFormat:
2619
+ raise HTTPException(
2620
+ status_code=status.HTTP_400_BAD_REQUEST,
2621
+ detail=f"File has unsupported format. Supported formats: {sly_image.SUPPORTED_IMG_EXTS}",
2622
+ )
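The synchronous /inference_image endpoint above takes the image as a multipart upload plus a form-encoded state. A hedged client sketch using requests; the URL, file name, and settings are placeholders, while the field names follow the endpoint signature.

    import json
    import requests

    with open("image.jpg", "rb") as f:                        # hypothetical local image
        resp = requests.post(
            "http://localhost:8000/inference_image",
            files=[("files", ("image.jpg", f, "image/jpeg"))],
            data={"state": json.dumps({"settings": {}})},
        )
    prediction = resp.json()                                  # a single annotation result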
2628
2623
 
2629
2624
  @server.post("/inference_image_url")
2630
2625
  def inference_image_url(request: Request):
2631
- logger.debug(f"'inference_image_url' request in json format:{request.state.state}")
2632
- return self._inference_image_url(request.state.api, request.state.state)
2626
+ state = request.state.state
2627
+ logger.debug("Received a request to 'inference_image_url'", extra={"state": state})
2628
+ self.validate_inference_state(state)
2629
+ image_url = state["image_url"]
2630
+ ext = sly_fs.get_file_ext(image_url)
2631
+ if ext == "":
2632
+ ext = ".jpg"
2633
+ with requests.get(image_url, stream=True) as response:
2634
+ response.raise_for_status()
2635
+ response.raw.decode_content = True
2636
+ image = self.cache.add_image_to_cache(image_url, response.raw, ext=ext)
2637
+ return self.inference_requests_manager.run(self._inference_images, [image], state)[0]
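For the URL-based endpoint above, the image reference travels inside the state itself. A minimal payload sketch; "image_url" is the only key read explicitly above, and the URL and empty settings are placeholders.

    # Minimal state for /inference_image_url.
    state = {"image_url": "https://example.com/image.jpg", "settings": {}}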
2633
2638
 
2634
2639
  @server.post("/inference_batch_ids")
2635
- def inference_batch_ids(response: Response, request: Request):
2636
- # check batch size
2637
- batch_size = len(request.state.state["batch_ids"])
2638
- if self.max_batch_size is not None and batch_size > self.max_batch_size:
2639
- response.status_code = status.HTTP_400_BAD_REQUEST
2640
- return {
2641
- "message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2642
- "success": False,
2643
- }
2644
- logger.debug(f"'inference_batch_ids' request in json format:{request.state.state}")
2645
- return self._inference_batch_ids(request.state.api, request.state.state)
2640
+ def inference_batch_ids(request: Request):
2641
+ state = request.state.state
2642
+ logger.debug("Received a request to 'inference_batch_ids'", extra={"state": state})
2643
+ self.validate_inference_state(state)
2644
+ api = self.api_from_request(request)
2645
+ return self.inference_requests_manager.run(self._inference_image_ids, api, state)
2646
2646
 
2647
2647
  @server.post("/inference_batch_ids_async")
2648
- def inference_batch_ids_async(response: Response, request: Request):
2648
+ def inference_batch_ids_async(request: Request):
2649
+ state = request.state.state
2649
2650
  logger.debug(
2650
- f"'inference_batch_ids_async' request in json format:{request.state.state}"
2651
+ f"Received a request to 'inference_batch_ids_async'", extra={"state": state}
2651
2652
  )
2652
- images_ids = request.state.state["images_ids"]
2653
- # check batch size
2654
- batch_size = request.state.state.get("batch_size", None)
2655
- if batch_size is None:
2656
- batch_size = self.get_batch_size()
2657
- if self.max_batch_size is not None and batch_size > self.max_batch_size:
2658
- response.status_code = status.HTTP_400_BAD_REQUEST
2659
- return {
2660
- "message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2661
- "success": False,
2662
- }
2663
- inference_request_uuid = uuid.uuid5(
2664
- namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
2665
- ).hex
2666
- self._on_inference_start(inference_request_uuid)
2667
- future = self._executor.submit(
2668
- self._handle_error_in_async,
2669
- inference_request_uuid,
2670
- self._inference_images_ids,
2671
- request.state.api,
2672
- request.state.state,
2673
- images_ids,
2674
- inference_request_uuid,
2675
- )
2676
- end_callback = partial(
2677
- self._on_inference_end, inference_request_uuid=inference_request_uuid
2678
- )
2679
- future.add_done_callback(end_callback)
2680
- logger.debug(
2681
- "Inference has scheduled from 'inference_batch_ids_async' endpoint",
2682
- extra={"inference_request_uuid": inference_request_uuid},
2653
+ self.validate_inference_state(state)
2654
+ api = self.api_from_request(request)
2655
+ inference_request, _ = self.inference_requests_manager.schedule_task(
2656
+ self._inference_image_ids, api, state
2683
2657
  )
2684
2658
  return {
2685
- "message": "Inference has started.",
2686
- "inference_request_uuid": inference_request_uuid,
2659
+ "message": "Scheduled inference task.",
2660
+ "inference_request_uuid": inference_request.uuid,
2687
2661
  }
2688
2662
 
2689
- @server.post("/inference_video_id")
2690
- def inference_video_id(response: Response, request: Request):
2691
- logger.debug(f"'inference_video_id' request in json format:{request.state.state}")
2692
- # check batch size
2693
- batch_size = request.state.state.get("batch_size", None)
2694
- if batch_size is None:
2695
- batch_size = self.get_batch_size()
2696
- if self.max_batch_size is not None and batch_size > self.max_batch_size:
2697
- response.status_code = status.HTTP_400_BAD_REQUEST
2698
- return {
2699
- "message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2700
- "success": False,
2701
- }
2702
- return self._inference_video_id(request.state.api, request.state.state)
2703
-
2704
- @server.post("/inference_image")
2705
- def inference_image(
2706
- response: Response, files: List[UploadFile], settings: str = Form("{}")
2663
+ @server.post("/inference_batch")
2664
+ def inference_batch(
2665
+ response: Response,
2666
+ files: List[UploadFile],
2667
+ settings: str = Form("{}"),
2668
+ state: str = Form("{}"),
2707
2669
  ):
2708
- if len(files) != 1:
2709
- response.status_code = status.HTTP_400_BAD_REQUEST
2710
- return f"Only one file expected but got {len(files)}"
2670
+ if state == "{}" or not state:
2671
+ state = settings
2672
+ state = str(state)
2673
+ logger.debug("Received a request to 'inference_batch'", extra={"state": state})
2674
+ self.validate_inference_state(state)
2675
+ state = json.loads(state)
2676
+ if len(files) == 0:
2677
+ raise HTTPException(
2678
+ status_code=status.HTTP_400_BAD_REQUEST,
2679
+ detail=f"At least one file is expected but got {len(files)}",
2680
+ )
2711
2681
  try:
2712
- state = json.loads(settings)
2713
- if type(state) != dict:
2714
- response.status_code = status.HTTP_400_BAD_REQUEST
2715
- return "Settings is not json object"
2716
- return self._inference_image(state, files[0])
2717
- except (json.decoder.JSONDecodeError, TypeError) as e:
2718
- response.status_code = status.HTTP_400_BAD_REQUEST
2719
- return f"Cannot decode settings: {e}"
2682
+ inference_request = self.inference_requests_manager.create()
2683
+ inference_request.set_stage(
2684
+ InferenceRequest.Stage.PREPARING, 0, sum([file.size for file in files])
2685
+ )
2686
+
2687
+ names = []
2688
+ for file in files:
2689
+ name = file.filename
2690
+ if name is None or name == "":
2691
+ name = rand_str(10)
2692
+ ext = Path(name).suffix
2693
+ img_bytes = b""
2694
+ while buf := file.file.read(64 * 1024 * 1024):
2695
+ img_bytes += buf
2696
+ inference_request.done(len(buf))
2697
+ self.cache.add_image_to_cache(name, img_bytes, ext=ext)
2698
+ names.append(name)
2699
+
2700
+ inference_request, future = self.inference_requests_manager.schedule_task(
2701
+ self._inference_images, names, state, inference_request=inference_request
2702
+ )
2703
+ future.result()
2704
+ return inference_request.pop_pending_results()
2720
2705
  except sly_image.UnsupportedImageFormat:
2721
2706
  response.status_code = status.HTTP_400_BAD_REQUEST
2722
2707
  return f"File has unsupported format. Supported formats: {sly_image.SUPPORTED_IMG_EXTS}"
2723
2708
 
2724
- @server.post("/inference_batch")
2725
- def inference_batch(
2726
- response: Response, files: List[UploadFile], settings: str = Form("{}")
2709
+ @server.post("/inference_batch_async")
2710
+ def inference_batch_async(
2711
+ response: Response,
2712
+ files: List[UploadFile],
2713
+ settings: str = Form("{}"),
2714
+ state: str = Form("{}"),
2727
2715
  ):
2716
+ if state == "{}" or not state:
2717
+ state = settings
2718
+ state = str(state)
2719
+ logger.debug("Received a request to 'inference_batch'", extra={"state": state})
2720
+ self.validate_inference_state(state)
2721
+ state = json.loads(state)
2722
+ if len(files) == 0:
2723
+ raise HTTPException(
2724
+ status_code=status.HTTP_400_BAD_REQUEST,
2725
+ detail=f"At least one file is expected but got {len(files)}",
2726
+ )
2728
2727
  try:
2729
- state = json.loads(settings)
2730
- if type(state) != dict:
2731
- response.status_code = status.HTTP_400_BAD_REQUEST
2732
- return "Settings is not json object"
2733
- # check batch size
2734
- batch_size = len(files)
2735
- if self.max_batch_size is not None and batch_size > self.max_batch_size:
2736
- response.status_code = status.HTTP_400_BAD_REQUEST
2737
- return {
2738
- "message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2739
- "success": False,
2740
- }
2741
- return self._inference_batch(state, files)
2742
- except (json.decoder.JSONDecodeError, TypeError) as e:
2743
- response.status_code = status.HTTP_400_BAD_REQUEST
2744
- return f"Cannot decode settings: {e}"
2728
+ inference_request = self.inference_requests_manager.create()
2729
+ inference_request.set_stage(
2730
+ InferenceRequest.Stage.PREPARING, 0, sum([file.size for file in files])
2731
+ )
2732
+
2733
+ names = []
2734
+ for file in files:
2735
+ name = file.filename
2736
+ if name is None or name == "":
2737
+ name = rand_str(10)
2738
+ ext = Path(name).suffix
2739
+ img_bytes = b""
2740
+ while buf := file.file.read(64 * 1024 * 1024):
2741
+ img_bytes += buf
2742
+ inference_request.done(len(buf))
2743
+ self.cache.add_image_to_cache(name, img_bytes, ext=ext)
2744
+ names.append(name)
2745
+
2746
+ inference_request, _ = self.inference_requests_manager.schedule_task(
2747
+ self._inference_images, names, state, inference_request=inference_request
2748
+ )
2749
+ return {
2750
+ "message": "Scheduled inference task.",
2751
+ "inference_request_uuid": inference_request.uuid,
2752
+ }
2745
2753
  except sly_image.UnsupportedImageFormat:
2746
2754
  response.status_code = status.HTTP_400_BAD_REQUEST
2747
2755
  return f"File has unsupported format. Supported formats: {sly_image.SUPPORTED_IMG_EXTS}"
2748
2756
 
2749
- @server.post("/inference_image_id_async")
2750
- def inference_image_id_async(request: Request):
2751
- logger.debug(f"'inference_image_id_async' request in json format:{request.state.state}")
2752
- inference_request_uuid = uuid.uuid5(
2753
- namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
2754
- ).hex
2755
- self._on_inference_start(inference_request_uuid)
2756
- future = self._executor.submit(
2757
- self._handle_error_in_async,
2758
- inference_request_uuid,
2759
- self._inference_image_id,
2760
- request.state.api,
2761
- request.state.state,
2762
- inference_request_uuid,
2757
+ @server.post("/inference_video_id")
2758
+ def inference_video_id(request: Request):
2759
+ state = request.state.state
2760
+ logger.debug(f"Received a request to 'inference_video_id'", extra={"state": state})
2761
+ self.validate_inference_state(state)
2762
+ api = self.api_from_request(request)
2763
+ inference_request, future = self.inference_requests_manager.schedule_task(
2764
+ self._inference_video_id, api, state
2763
2765
  )
2764
- end_callback = partial(
2765
- self._on_inference_end, inference_request_uuid=inference_request_uuid
2766
+ future.result()
2767
+ results = {"ann": inference_request.pop_pending_results()}
2768
+ final_result = inference_request.final_result
2769
+ if final_result is not None:
2770
+ results.update(final_result)
2771
+ return results
2772
+
2773
+ @server.post("/inference_video_async")
2774
+ def inference_video_async(
2775
+ files: List[UploadFile],
2776
+ settings: str = Form("{}"),
2777
+ state: str = Form("{}"),
2778
+ ):
2779
+ if state == "{}" or not state:
2780
+ state = settings
2781
+ state = str(state)
2782
+ logger.debug("Received a request to 'inference_video_async'", extra={"state": state})
2783
+ self.validate_inference_state(state)
2784
+ state = json.loads(state)
2785
+
2786
+ file = files[0]
2787
+ video_name = files[0].filename
2788
+ video_source = files[0].file
2789
+ file_size = file.size
2790
+
2791
+ inference_request = self.inference_requests_manager.create()
2792
+ inference_request.set_stage(InferenceRequest.Stage.PREPARING, 0, file_size)
2793
+
2794
+ video_source.read = progress_wrapper(
2795
+ video_source.read, inference_request.progress.iters_done_report
2766
2796
  )
2767
- future.add_done_callback(end_callback)
2768
- logger.debug(
2769
- "Inference has scheduled from 'inference_image_id_async' endpoint",
2770
- extra={"inference_request_uuid": inference_request_uuid},
2797
+
2798
+ if self.cache.is_persistent:
2799
+ self.cache.add_video_to_cache(video_name, video_source)
2800
+ video_path = self.cache.get_video_path(video_name)
2801
+ else:
2802
+ video_path = os.path.join(tempfile.gettempdir(), video_name)
2803
+ with open(video_path, "wb") as video_file:
2804
+ shutil.copyfileobj(
2805
+ video_source, open(video_path, "wb"), length=(64 * 1024 * 1024)
2806
+ )
2807
+
2808
+ inference_request, _ = self.inference_requests_manager.schedule_task(
2809
+ self._inference_video,
2810
+ path=video_path,
2811
+ state=state,
2812
+ inference_request=inference_request,
2771
2813
  )
2814
+
2772
2815
  return {
2773
- "message": "Inference has started.",
2774
- "inference_request_uuid": inference_request_uuid,
2816
+ "message": "Scheduled inference task.",
2817
+ "inference_request_uuid": inference_request.uuid,
2775
2818
  }
2776
2819
 
2777
2820
  @server.post("/inference_video_id_async")
2778
2821
  def inference_video_id_async(response: Response, request: Request):
2779
- logger.debug(f"'inference_video_id_async' request in json format:{request.state.state}")
2780
- # check batch size
2781
- batch_size = request.state.state.get("batch_size", None)
2782
- if batch_size is None:
2783
- batch_size = self.get_batch_size()
2784
- if self.max_batch_size is not None and batch_size > self.max_batch_size:
2785
- response.status_code = status.HTTP_400_BAD_REQUEST
2786
- return {
2787
- "message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2788
- "success": False,
2789
- }
2790
- inference_request_uuid = uuid.uuid5(
2791
- namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
2792
- ).hex
2793
- self._on_inference_start(inference_request_uuid)
2794
- future = self._executor.submit(
2795
- self._handle_error_in_async,
2796
- inference_request_uuid,
2797
- self._inference_video_id,
2798
- request.state.api,
2799
- request.state.state,
2800
- inference_request_uuid,
2801
- )
2802
- end_callback = partial(
2803
- self._on_inference_end, inference_request_uuid=inference_request_uuid
2804
- )
2805
- future.add_done_callback(end_callback)
2806
- logger.debug(
2807
- "Inference has scheduled from 'inference_video_id_async' endpoint",
2808
- extra={"inference_request_uuid": inference_request_uuid},
2822
+ state = request.state.state
2823
+ logger.debug("Received a request to 'inference_video_id_async'", extra={"state": state})
2824
+ self.validate_inference_state(state)
2825
+ api = self.api_from_request(request)
2826
+ inference_request, _ = self.inference_requests_manager.schedule_task(
2827
+ self._inference_video_id, api, state
2809
2828
  )
2810
2829
  return {
2811
2830
  "message": "Inference has started.",
2812
- "inference_request_uuid": inference_request_uuid,
2831
+ "inference_request_uuid": inference_request.uuid,
2813
2832
  }
2814
2833
 
2815
2834
  @server.post("/inference_project_id_async")
2816
2835
  def inference_project_id_async(response: Response, request: Request):
2836
+ state = request.state.state
2817
2837
  logger.debug(
2818
- f"'inference_project_id_async' request in json format:{request.state.state}"
2838
+ "Received a request to 'inference_project_id_async'", extra={"state": state}
2819
2839
  )
2820
- project_id = request.state.state["projectId"]
2821
- project_info = request.state.api.project.get_info_by_id(project_id)
2822
- if project_info.type != str(ProjectType.IMAGES):
2823
- raise ValueError("Only images projects are supported.")
2824
- # check batch size
2825
- batch_size = request.state.state.get("batch_size", None)
2826
- if batch_size is None:
2827
- batch_size = self.get_batch_size()
2828
- if self.max_batch_size is not None and batch_size > self.max_batch_size:
2829
- response.status_code = status.HTTP_400_BAD_REQUEST
2830
- return {
2831
- "message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2832
- "success": False,
2833
- }
2834
- inference_request_uuid = uuid.uuid5(
2835
- namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
2836
- ).hex
2837
- self._on_inference_start(inference_request_uuid)
2838
- future = self._executor.submit(
2839
- self._handle_error_in_async,
2840
- inference_request_uuid,
2841
- self._inference_project_id,
2842
- request.state.api,
2843
- request.state.state,
2844
- project_info,
2845
- inference_request_uuid,
2846
- )
2847
- logger.debug(
2848
- "Inference has scheduled from 'inference_project_id_async' endpoint",
2849
- extra={"inference_request_uuid": inference_request_uuid},
2840
+ self.validate_inference_state(state)
2841
+ api = self.api_from_request(request)
2842
+ inference_request, _ = self.inference_requests_manager.schedule_task(
2843
+ self._inference_project_id, api, state
2850
2844
  )
2851
2845
  return {
2852
2846
  "message": "Inference has started.",
2853
- "inference_request_uuid": inference_request_uuid,
2847
+ "inference_request_uuid": inference_request.uuid,
2854
2848
  }
2855
2849
 
2856
2850
  @server.post("/run_speedtest")
2857
2851
  def run_speedtest(response: Response, request: Request):
2858
- logger.debug(f"'run_speedtest' request in json format:{request.state.state}")
2859
- project_id = request.state.state["projectId"]
2860
- project_info = request.state.api.project.get_info_by_id(project_id)
2861
- if project_info.type != str(ProjectType.IMAGES):
2862
- response.status_code = status.HTTP_400_BAD_REQUEST
2863
- response.body = {"message": "Only images projects are supported."}
2864
- raise ValueError("Only images projects are supported.")
2865
- batch_size = request.state.state["batch_size"]
2852
+ state = request.state.state
2853
+ logger.debug(f"'run_speedtest' request in json format:{state}")
2854
+
2855
+ batch_size = state["batch_size"]
2866
2856
  if batch_size > 1 and not self.is_batch_inference_supported():
2867
2857
  response.status_code = status.HTTP_501_NOT_IMPLEMENTED
2868
2858
  return {
2869
2859
  "message": "Batch inference is not implemented for this model.",
2870
2860
  "success": False,
2871
2861
  }
2872
- # check batch size
2873
- if self.max_batch_size is not None and batch_size > self.max_batch_size:
2862
+
2863
+ self.validate_inference_state(state)
2864
+ api = self.api_from_request(request)
2865
+
2866
+ project_id = state["projectId"]
2867
+ project_info = api.project.get_info_by_id(project_id)
2868
+ if project_info.type != str(ProjectType.IMAGES):
2874
2869
  response.status_code = status.HTTP_400_BAD_REQUEST
2875
- return {
2876
- "message": f"Batch size should be less than or equal to {self.max_batch_size} for this model.",
2877
- "success": False,
2878
- }
2879
- inference_request_uuid = uuid.uuid5(
2880
- namespace=uuid.NAMESPACE_URL, name=f"{time.time()}"
2881
- ).hex
2882
- self._on_inference_start(inference_request_uuid)
2883
- future = self._executor.submit(
2884
- self._handle_error_in_async,
2885
- inference_request_uuid,
2886
- self._run_speedtest,
2887
- request.state.api,
2888
- request.state.state,
2889
- inference_request_uuid,
2890
- )
2891
- logger.debug(
2892
- "Speedtest has scheduled from 'run_speedtest' endpoint",
2893
- extra={"inference_request_uuid": inference_request_uuid},
2870
+ response.body = {"message": "Only images projects are supported."}
2871
+ raise ValueError("Only images projects are supported.")
2872
+
2873
+ inference_request, _ = self.inference_requests_manager.schedule_task(
2874
+ self._run_speedtest, api, state
2894
2875
  )
2895
2876
  return {
2896
2877
  "message": "Inference has started.",
2897
- "inference_request_uuid": inference_request_uuid,
2878
+ "inference_request_uuid": inference_request.uuid,
2898
2879
  }
2899
2880
 
2900
2881
  @server.post(f"/get_inference_progress")
2901
2882
  def get_inference_progress(response: Response, request: Request):
2902
- inference_request_uuid = request.state.state.get("inference_request_uuid")
2883
+ state = request.state.state
2884
+ logger.debug("Received a request to '/get_inference_progress'", extra={"state": state})
2885
+ inference_request_uuid = state.get("inference_request_uuid")
2903
2886
  if inference_request_uuid is None:
2904
2887
  response.status_code = status.HTTP_400_BAD_REQUEST
2905
2888
  return {"message": "Error: 'inference_request_uuid' is required."}
2906
2889
 
2907
- inference_request = self._inference_requests[inference_request_uuid].copy()
2908
- inference_request["progress"] = _convert_sly_progress_to_dict(
2909
- inference_request["progress"]
2910
- )
2911
-
2912
- # Logging
2890
+ inference_request = self.inference_requests_manager.get(inference_request_uuid)
2913
2891
  log_extra = _get_log_extra_for_inference_request(
2914
- inference_request_uuid, inference_request
2892
+ inference_request.uuid, inference_request
2915
2893
  )
2894
+ data = {**inference_request.to_json(), **log_extra}
2895
+ if inference_request.stage != InferenceRequest.Stage.INFERENCE:
2896
+ data["progress"] = {"current": 0, "total": 1}
2916
2897
  logger.debug(
2917
2898
  f"Sending inference progress with uuid:",
2918
- extra=log_extra,
2899
+ extra=data,
2919
2900
  )
2920
-
2921
- # Ger rid of `pending_results` to less response size
2922
- inference_request["pending_results"] = []
2923
- inference_request.pop("lock", None)
2924
- return inference_request
2901
+ return data
2925
2902
 
2926
2903
  @server.post(f"/pop_inference_results")
2927
2904
  def pop_inference_results(response: Response, request: Request):
@@ -2930,23 +2907,34 @@ class Inference:
                 response.status_code = status.HTTP_400_BAD_REQUEST
                 return {"message": "Error: 'inference_request_uuid' is required."}
 
-            # Copy results
-            inference_request = self._inference_requests[inference_request_uuid].copy()
-            inference_request["pending_results"] = inference_request["pending_results"].copy()
+            if inference_request_uuid in self._inference_requests:
+                inference_request = self._inference_requests[inference_request_uuid].copy()
+                inference_request["pending_results"] = inference_request["pending_results"].copy()
 
-            # Clear the queue `pending_results`
-            self._inference_requests[inference_request_uuid]["pending_results"].clear()
+                # Clear the queue `pending_results`
+                self._inference_requests[inference_request_uuid]["pending_results"].clear()
 
-            inference_request["progress"] = _convert_sly_progress_to_dict(
-                inference_request["progress"]
-            )
+                inference_request["progress"] = _convert_sly_progress_to_dict(
+                    inference_request["progress"]
+                )
+                log_extra = _get_log_extra_for_inference_request(
+                    inference_request_uuid, inference_request
+                )
+                logger.debug(f"Sending inference delta results with uuid:", extra=log_extra)
+                return inference_request
 
-            # Logging
+            inference_request = self.inference_requests_manager.get(inference_request_uuid)
             log_extra = _get_log_extra_for_inference_request(
-                inference_request_uuid, inference_request
+                inference_request.uuid, inference_request
             )
+            data = {
+                **inference_request.to_json(),
+                **log_extra,
+                "pending_results": inference_request.pop_pending_results(),
+            }
+
             logger.debug(f"Sending inference delta results with uuid:", extra=log_extra)
-            return inference_request
+            return data
 
         @server.post(f"/get_inference_result")
         def get_inference_result(response: Response, request: Request):
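The updated `pop_inference_results` handler drains `pending_results` through `pop_pending_results()` for manager-tracked requests while keeping the legacy dict-based path. A hedged sketch of a client loop that consumes those delta results, under the same assumptions as the previous example:

```python
import time

import requests

SERVER = "http://localhost:8000"  # assumed local serving app, not part of the diff


def iter_delta_results(request_uuid: str, poll_interval: float = 2.0):
    """Yield batches of pending results until the request finishes and the queue is empty."""
    while True:
        resp = requests.post(
            f"{SERVER}/pop_inference_results",
            json={"state": {"inference_request_uuid": request_uuid}},
        )
        resp.raise_for_status()
        data = resp.json()
        batch = data.get("pending_results", [])
        if batch:
            yield batch
        elif not data.get("is_inferring", False):
            break
        time.sleep(poll_interval)
```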
@@ -2955,22 +2943,34 @@ class Inference:
                 response.status_code = status.HTTP_400_BAD_REQUEST
                 return {"message": "Error: 'inference_request_uuid' is required."}
 
-            inference_request = self._inference_requests[inference_request_uuid].copy()
+            if inference_request_uuid in self._inference_requests:
+                inference_request = self._inference_requests[inference_request_uuid].copy()
 
-            inference_request["progress"] = _convert_sly_progress_to_dict(
-                inference_request["progress"]
-            )
+                inference_request["progress"] = _convert_sly_progress_to_dict(
+                    inference_request["progress"]
+                )
 
-            # Logging
+                # Logging
+                log_extra = _get_log_extra_for_inference_request(
+                    inference_request_uuid, inference_request
+                )
+
+                logger.debug(
+                    f"Sending inference result with uuid:",
+                    extra=log_extra,
+                )
+                return inference_request["result"]
+
+            inference_request = self.inference_requests_manager.get(inference_request_uuid)
             log_extra = _get_log_extra_for_inference_request(
-                inference_request_uuid, inference_request
+                inference_request.uuid, inference_request
             )
             logger.debug(
                 f"Sending inference result with uuid:",
                 extra=log_extra,
             )
 
-            return inference_request["result"]
+            return inference_request.final_result
 
         @server.post(f"/stop_inference")
         def stop_inference(response: Response, request: Request):
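After polling reports that the request has finished, `/get_inference_result` returns the stored `result` on the legacy path or `final_result` on the manager path. A small illustrative call, under the same assumptions as above:

```python
import requests

SERVER = "http://localhost:8000"  # assumed local serving app, not part of the diff
request_uuid = "..."  # uuid returned by the endpoint that started inference

resp = requests.post(
    f"{SERVER}/get_inference_result",
    json={"state": {"inference_request_uuid": request_uuid}},
)
resp.raise_for_status()
result = resp.json()  # legacy path returns the stored "result"; new path returns final_result
```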
@@ -2981,8 +2981,12 @@ class Inference:
                     "message": "Error: 'inference_request_uuid' is required.",
                     "success": False,
                 }
-            inference_request = self._inference_requests[inference_request_uuid]
-            inference_request["cancel_inference"] = True
+            if inference_request_uuid in self._inference_requests:
+                inference_request = self._inference_requests[inference_request_uuid]
+                inference_request["cancel_inference"] = True
+            else:
+                inference_request = self.inference_requests_manager.get(inference_request_uuid)
+                inference_request.stop()
             return {"message": "Inference will be stopped.", "success": True}
 
         @server.post(f"/clear_inference_request")
@@ -2994,7 +2998,10 @@ class Inference:
                     "message": "Error: 'inference_request_uuid' is required.",
                     "success": False,
                 }
-            del self._inference_requests[inference_request_uuid]
+            if inference_request_uuid in self._inference_requests:
+                del self._inference_requests[inference_request_uuid]
+            else:
+                self.inference_requests_manager.remove_after(inference_request_uuid, 60)
             logger.debug("Removed an inference request:", extra={"uuid": inference_request_uuid})
             return {"success": True}
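Both `stop_inference` and `clear_inference_request` now branch between the legacy `_inference_requests` dict and the new `inference_requests_manager`. A hedged client-side sketch of cancelling a request and then releasing its state (helper name and URL are illustrative):

```python
import requests

SERVER = "http://localhost:8000"  # assumed local serving app, not part of the diff


def cancel_and_cleanup(request_uuid: str) -> None:
    """Ask the app to stop a running request, then release its stored state."""
    payload = {"state": {"inference_request_uuid": request_uuid}}
    stopped = requests.post(f"{SERVER}/stop_inference", json=payload).json()
    if stopped.get("success"):
        # Legacy requests are deleted immediately; manager-tracked ones are
        # scheduled for removal via remove_after(), as in the handler above.
        requests.post(f"{SERVER}/clear_inference_request", json=payload)
```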
 
@@ -3005,8 +3012,13 @@ class Inference:
                 response.status_code = status.HTTP_400_BAD_REQUEST
                 return {"message": "Error: 'inference_request_uuid' is required."}
 
-            inference_request = self._inference_requests[inference_request_uuid].copy()
-            return inference_request["preparing_progress"]
+            if inference_request_uuid in self._inference_requests:
+                inference_request = self._inference_requests[inference_request_uuid].copy()
+                return inference_request["preparing_progress"]
+            inference_request = self.inference_requests_manager.get(inference_request_uuid)
+            return _get_log_extra_for_inference_request(inference_request.uuid, inference_request)[
+                "preparing_progress"
+            ]
 
         @server.post("/get_deploy_settings")
         def _get_deploy_settings(response: Response, request: Request):
  def _get_deploy_settings(response: Response, request: Request):
@@ -3052,22 +3064,84 @@ class Inference:
3052
3064
  self.shutdown_model()
3053
3065
  state = request.state.state
3054
3066
  deploy_params = state["deploy_params"]
3067
+ model_name = state.get("model_name", None)
3055
3068
  if isinstance(self.gui, GUI.ServingGUITemplate):
3069
+ if deploy_params["model_source"] == ModelSource.PRETRAINED and model_name:
3070
+ deploy_params = self._build_deploy_params_from_api(
3071
+ model_name, deploy_params
3072
+ )
3056
3073
  model_files = self._download_model_files(deploy_params)
3057
3074
  deploy_params["model_files"] = model_files
3075
+ deploy_params = self._set_common_deploy_params(deploy_params)
3058
3076
  self._load_model_headless(**deploy_params)
3059
3077
  elif isinstance(self.gui, GUI.ServingGUI):
3078
+ if deploy_params["model_source"] == ModelSource.PRETRAINED and model_name:
3079
+ deploy_params = self._build_legacy_deploy_params_from_api(model_name)
3080
+ deploy_params = self._set_common_deploy_params(deploy_params)
3060
3081
  self._load_model(deploy_params)
3082
+ elif self.gui is None and self.api is None:
3083
+ if deploy_params["model_source"] == ModelSource.PRETRAINED and model_name:
3084
+ deploy_params = self._build_deploy_params_from_api(
3085
+ model_name, deploy_params
3086
+ )
3087
+ model_files = self._download_model_files(deploy_params)
3088
+ deploy_params["model_files"] = model_files
3089
+
3090
+ deploy_params = self._set_common_deploy_params(deploy_params)
3091
+ self._load_model_headless(**deploy_params)
3092
+ logger.info(
3093
+ f"Model has been successfully loaded on {deploy_params['device']} device"
3094
+ )
3095
+ return {"result": "model was successfully deployed"}
3061
3096
 
3062
- self.set_params_to_gui(deploy_params)
3063
- # update to set correct device
3064
- device = deploy_params.get("device", "cpu")
3065
- self.gui.set_deployed(device)
3097
+ else:
3098
+ raise ValueError("Unknown GUI type")
3099
+ if self.gui is not None:
3100
+ self.set_params_to_gui(deploy_params)
3101
+ # update to set correct device
3102
+ device = deploy_params.get("device", "cpu")
3103
+ self.gui.set_deployed(device)
3066
3104
  return {"result": "model was successfully deployed"}
3067
3105
  except Exception as e:
3068
- self.gui._success_label.hide()
3106
+ if self.gui is not None:
3107
+ self.gui._success_label.hide()
3069
3108
  raise e
3070
3109
 
3110
+ @server.post("/list_pretrained_models")
3111
+ def _list_pretrained_models():
3112
+ if isinstance(self.gui, GUI.ServingGUITemplate):
3113
+ return [
3114
+ _get_model_name(model) for model in self._gui.pretrained_models_table._models
3115
+ ]
3116
+ elif hasattr(self, "pretrained_models"):
3117
+ return [_get_model_name(model) for model in self.pretrained_models]
3118
+ else:
3119
+ if hasattr(self, "pretrained_models_table"):
3120
+ return [
3121
+ _get_model_name(model)
3122
+ for model in self.pretrained_models_table._models # pylint: disable=no-member
3123
+ ]
3124
+ else:
3125
+ raise HTTPException(
3126
+ status_code=400,
3127
+ detail="Pretrained models table is not available in this app.",
3128
+ )
3129
+
3130
+ @server.post("/list_pretrained_model_infos")
3131
+ def _list_pretrained_model_infos():
3132
+ if isinstance(self.gui, GUI.ServingGUITemplate):
3133
+ return self._gui.pretrained_models_table._models
3134
+ elif hasattr(self, "pretrained_models"):
3135
+ return self.pretrained_models
3136
+ else:
3137
+ if hasattr(self, "pretrained_models_table"):
3138
+ return self.pretrained_models_table._models
3139
+ else:
3140
+ raise HTTPException(
3141
+ status_code=400,
3142
+ detail="Pretrained models table is not available in this app.",
3143
+ )
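The two new listing endpoints expose either plain model names or the raw rows of the pretrained-models table. An illustrative call, assuming the same local serving app as in the earlier examples:

```python
import requests

SERVER = "http://localhost:8000"  # assumed local serving app, not part of the diff

names = requests.post(f"{SERVER}/list_pretrained_models", json={}).json()
infos = requests.post(f"{SERVER}/list_pretrained_model_infos", json={}).json()
print(names[:3])                                           # a few checkpoint names
print(infos[0] if infos else "no model infos available")   # full table row for the first model
```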
3144
+
3071
3145
  @server.post("/is_deployed")
3072
3146
  def _is_deployed(response: Response, request: Request):
3073
3147
  return {
@@ -3080,6 +3154,37 @@ class Inference:
3080
3154
  def _get_deploy_info():
3081
3155
  return asdict(self._get_deploy_info())
3082
3156
 
3157
+ @server.post("/get_inference_status")
3158
+ def _get_inference_status(request: Request, response: Response):
3159
+ state = request.state.state
3160
+ inference_request_uuid = state.get("inference_request_uuid")
3161
+ if inference_request_uuid is None:
3162
+ response.status_code = status.HTTP_400_BAD_REQUEST
3163
+ return {"message": "Error: 'inference_request_uuid' is required."}
3164
+ inference_request = self.inference_requests_manager.get(inference_request_uuid)
3165
+ if inference_request is None:
3166
+ response.status_code = status.HTTP_404_NOT_FOUND
3167
+ return {"message": "Error: 'inference_request_uuid' is not found."}
3168
+ return inference_request.status()
3169
+
3170
+ @server.post("/get_status")
3171
+ def _get_status(request: Request):
3172
+ progress = self.inference_requests_manager.global_progress.to_json()
3173
+ ram_allocated, ram_total = get_ram_usage()
3174
+ gpu_allocated, gpu_total = get_gpu_usage()
3175
+ return {
3176
+ "is_deployed": self.is_model_deployed(),
3177
+ "progress": progress,
3178
+ "gpu_memory": {
3179
+ "allocated": gpu_allocated,
3180
+ "total": gpu_total,
3181
+ },
3182
+ "ram_memory": {
3183
+ "allocated": ram_allocated,
3184
+ "total": ram_total,
3185
+ },
3186
+ }
3187
+
3083
3188
  # Local deploy without predict args
3084
3189
  if self._is_local_deploy:
3085
3190
  self._run_server()
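The new `/get_status` endpoint aggregates the manager's global progress with RAM and GPU usage. A small monitoring sketch under the same assumptions as the earlier examples:

```python
import requests

SERVER = "http://localhost:8000"  # assumed local serving app, not part of the diff

info = requests.post(f"{SERVER}/get_status", json={}).json()
gpu, ram = info["gpu_memory"], info["ram_memory"]
print(
    f"deployed={info['is_deployed']} "
    f"gpu={gpu['allocated']}/{gpu['total']} "
    f"ram={ram['allocated']}/{ram['total']}"
)
```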
@@ -3433,7 +3538,7 @@ class Inference:
             change_name_if_conflict=True,
         )
         state["output_project_id"] = output_project.id
-        results = self._inference_project_id(api=self.api, state=state)
+        results = self.inference_requests_manager.run(self._inference_project_id, api, state)
 
         dataset_infos = api.dataset.get_list(project_id)
         datasets_map = {dataset_info.id: dataset_info.name for dataset_info in dataset_infos}
@@ -3617,136 +3722,157 @@ class Inference:
                 f"Checkpoint {checkpoint_url} not found in Team Files. Cannot set workflow input"
             )
 
-    def _exclude_duplicated_predictions(
-        self,
-        api: Api,
-        pred_anns: List[Annotation],
-        settings: dict,
-        dataset_id: int,
-        gt_image_ids: List[int],
-        meta: Optional[ProjectMeta] = None,
-    ):
-        """
-        Filter out predictions that significantly overlap with ground truth (GT) objects.
-
-        This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
-        - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
-        - Gets ProjectMeta object if not provided
-        - Downloads GT annotations for the specified image IDs
-        - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
-
-        :param api: Supervisely API object
-        :type api: Api
-        :param pred_anns: List of Annotation objects containing predictions
-        :type pred_anns: List[Annotation]
-        :param settings: Inference settings
-        :type settings: dict
-        :param dataset_id: ID of the dataset containing the images
-        :type dataset_id: int
-        :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
-        :type gt_image_ids: List[int]
-        :param meta: ProjectMeta object
-        :type meta: Optional[ProjectMeta]
-        :return: List of Annotation objects containing filtered predictions
-        :rtype: List[Annotation]
-
-        Notes:
-        ------
-        - Requires PyTorch and torchvision for IoU calculations
-        - This method is useful for identifying new objects that aren't already annotated in the ground truth
-        """
-        iou = settings.get("existing_objects_iou_thresh")
-        if isinstance(iou, float) and 0 < iou <= 1:
-            if meta is None:
-                ds = api.dataset.get_info_by_id(dataset_id)
-                meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
-            gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
-            gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
-            for i in range(0, len(pred_anns)):
-                before = len(pred_anns[i].labels)
-                with Timer() as timer:
-                    pred_anns[i] = self._filter_duplicated_predictions_from_ann(
-                        gt_anns[i], pred_anns[i], iou
-                    )
-                after = len(pred_anns[i].labels)
-                logger.debug(
-                    f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
-                )
-        return pred_anns
-
-    def _filter_duplicated_predictions_from_ann(
-        self, gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
-    ) -> Annotation:
-        """
-        Filter out predictions that significantly overlap with ground truth annotations.
-
-        This function compares each prediction with ground truth annotations of the same class
-        and removes predictions that have an IoU (Intersection over Union) greater than or equal
-        to the specified threshold with any ground truth annotation. This is useful for identifying
-        new objects that aren't already annotated in the ground truth.
-
-        :param gt_ann: Annotation object containing ground truth labels
-        :type gt_ann: Annotation
-        :param pred_ann: Annotation object containing prediction labels to be filtered
-        :type pred_ann: Annotation
-        :param iou_threshold: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
-            ground truth box of the same class will be removed
-        :type iou_threshold: float
-        :return: A new annotation object containing only predictions that don't significantly
-            overlap with ground truth annotations
-        :rtype: Annotation
-
-
-        Notes:
-        ------
-        - Predictions with classes not present in ground truth will be kept
-        - Requires PyTorch and torchvision for IoU calculations
-        """
 
-        try:
-            import torch
-            from torchvision.ops import box_iou
-
-        except ImportError:
-            raise ImportError("Please install PyTorch and torchvision to use this feature.")
+def _exclude_duplicated_predictions(
+    api: Api,
+    pred_anns: List[Annotation],
+    dataset_id: int,
+    gt_image_ids: List[int],
+    iou: float = None,
+    meta: Optional[ProjectMeta] = None,
+):
+    """
+    Filter out predictions that significantly overlap with ground truth (GT) objects.
+
+    This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
+    - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
+    - Gets ProjectMeta object if not provided
+    - Downloads GT annotations for the specified image IDs
+    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+
+    :param api: Supervisely API object
+    :type api: Api
+    :param pred_anns: List of Annotation objects containing predictions
+    :type pred_anns: List[Annotation]
+    :param dataset_id: ID of the dataset containing the images
+    :type dataset_id: int
+    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
+    :type gt_image_ids: List[int]
+    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
+        ground truth box of the same class will be removed. None if no filtering is needed
+    :type iou: Optional[float]
+    :param meta: ProjectMeta object
+    :type meta: Optional[ProjectMeta]
+    :return: List of Annotation objects containing filtered predictions
+    :rtype: List[Annotation]
+
+    Notes:
+    ------
+    - Requires PyTorch and torchvision for IoU calculations
+    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    """
+    if isinstance(iou, float) and 0 < iou <= 1:
+        if meta is None:
+            ds = api.dataset.get_info_by_id(dataset_id)
+            meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
+        gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
+        gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
+        for i in range(0, len(pred_anns)):
+            before = len(pred_anns[i].labels)
+            with Timer() as timer:
+                pred_anns[i] = _filter_duplicated_predictions_from_ann(
+                    gt_anns[i], pred_anns[i], iou
+                )
+            after = len(pred_anns[i].labels)
+            logger.debug(
+                f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
+            )
+    return pred_anns
 
-        def _to_tensor(geom):
-            return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
 
-        new_labels = []
-        pred_cls_bboxes = defaultdict(list)
-        for label in pred_ann.labels:
-            pred_cls_bboxes[label.obj_class.name].append(label)
+def _filter_duplicated_predictions_from_ann(
+    gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
+) -> Annotation:
+    """
+    Filter out predictions that significantly overlap with ground truth annotations.
+
+    This function compares each prediction with ground truth annotations of the same class
+    and removes predictions that have an IoU (Intersection over Union) greater than or equal
+    to the specified threshold with any ground truth annotation. This is useful for identifying
+    new objects that aren't already annotated in the ground truth.
+
+    :param gt_ann: Annotation object containing ground truth labels
+    :type gt_ann: Annotation
+    :param pred_ann: Annotation object containing prediction labels to be filtered
+    :type pred_ann: Annotation
+    :param iou_threshold: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
+        ground truth box of the same class will be removed
+    :type iou_threshold: float
+    :return: A new annotation object containing only predictions that don't significantly
+        overlap with ground truth annotations
+    :rtype: Annotation
+
+
+    Notes:
+    ------
+    - Predictions with classes not present in ground truth will be kept
+    - Requires PyTorch and torchvision for IoU calculations
+    """
 
-        gt_cls_bboxes = defaultdict(list)
-        for label in gt_ann.labels:
-            if label.obj_class.name not in pred_cls_bboxes:
-                continue
-            gt_cls_bboxes[label.obj_class.name].append(label)
+    try:
+        import torch
+        from torchvision.ops import box_iou
 
-        for name, pred in pred_cls_bboxes.items():
-            gt = gt_cls_bboxes[name]
-            if len(gt) == 0:
-                new_labels.extend(pred)
-                continue
-            pred_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in pred]).float()
-            gt_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in gt]).float()
-            iou_matrix = box_iou(pred_bboxes, gt_bboxes)
-            iou_matrix = iou_matrix.cpu().numpy()
-            keep_indices = np.where(np.all(iou_matrix < iou_threshold, axis=1))[0]
-            new_labels.extend([pred[i] for i in keep_indices])
+    except ImportError:
+        raise ImportError("Please install PyTorch and torchvision to use this feature.")
 
-        return pred_ann.clone(labels=new_labels)
+    def _to_tensor(geom):
+        return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
 
+    new_labels = []
+    pred_cls_bboxes = defaultdict(list)
+    for label in pred_ann.labels:
+        pred_cls_bboxes[label.obj_class.name].append(label)
+
+    gt_cls_bboxes = defaultdict(list)
+    for label in gt_ann.labels:
+        if label.obj_class.name not in pred_cls_bboxes:
+            continue
+        gt_cls_bboxes[label.obj_class.name].append(label)
+
+    for name, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name]
+        if len(gt) == 0:
+            new_labels.extend(pred)
+            continue
+        pred_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in pred]).float()
+        gt_bboxes = torch.stack([_to_tensor(l.geometry.to_bbox()) for l in gt]).float()
+        iou_matrix = box_iou(pred_bboxes, gt_bboxes)
+        iou_matrix = iou_matrix.cpu().numpy()
+        keep_indices = np.where(np.all(iou_matrix < iou_threshold, axis=1))[0]
+        new_labels.extend([pred[i] for i in keep_indices])
+
+    return pred_ann.clone(labels=new_labels)
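The docstrings above state the filtering rule: a prediction is dropped when its bounding box overlaps a ground-truth box of the same class with IoU at or above the threshold. A self-contained numeric sketch of that rule using `torchvision.ops.box_iou` (box coordinates are made up for illustration; this is not code from the package):

```python
import numpy as np
import torch
from torchvision.ops import box_iou

iou_threshold = 0.5

# One GT box and two predicted boxes of the same class, in xyxy format.
gt_boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0]])
pred_boxes = torch.tensor(
    [
        [12.0, 11.0, 49.0, 52.0],      # heavy overlap with GT -> dropped
        [100.0, 100.0, 140.0, 140.0],  # no overlap -> kept
    ]
)

iou_matrix = box_iou(pred_boxes, gt_boxes).cpu().numpy()  # shape (num_pred, num_gt)
keep = np.where(np.all(iou_matrix < iou_threshold, axis=1))[0]
print(iou_matrix.round(2), "keep indices:", keep.tolist())  # expect only index 1 kept
```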
+
+
+def _get_log_extra_for_inference_request(
+    inference_request_uuid, inference_request: Union[InferenceRequest, dict]
+):
+    if isinstance(inference_request, dict):
+        log_extra = {
+            "uuid": inference_request_uuid,
+            "progress": inference_request["progress"],
+            "is_inferring": inference_request["is_inferring"],
+            "cancel_inference": inference_request["cancel_inference"],
+            "has_result": inference_request["result"] is not None,
+            "pending_results": len(inference_request["pending_results"]),
+        }
+        return log_extra
 
-def _get_log_extra_for_inference_request(inference_request_uuid, inference_request: dict):
+    progress = inference_request.progress_json()
+    del progress["message"]
     log_extra = {
-        "uuid": inference_request_uuid,
-        "progress": inference_request["progress"],
-        "is_inferring": inference_request["is_inferring"],
-        "cancel_inference": inference_request["cancel_inference"],
-        "has_result": inference_request["result"] is not None,
-        "pending_results": len(inference_request["pending_results"]),
+        "uuid": inference_request.uuid,
+        "progress": progress,
+        "is_inferring": inference_request.is_inferring(),
+        "stopped": inference_request.is_stopped(),
+        "finished": inference_request.is_finished(),
+        "cancel_inference": inference_request.is_stopped(),
+        "has_result": inference_request.final_result is not None,
+        "pending_results": inference_request.pending_num(),
+        "exception": inference_request.exception_json(),
+        "result": inference_request._final_result,
+        "preparing_progress": progress,
     }
     return log_extra
 
@@ -4059,3 +4185,33 @@ def get_hardware_info(device: str) -> str:
     except Exception as e:
         logger.error("Error while getting hardware info", exc_info=True)
         return "Unknown"
+
+
+def progress_wrapper(func, progress_cb):
+    @wraps(func)
+    def wrapped_func(*args, **kwargs):
+        result = func(*args, **kwargs)
+        progress_cb(len(result))
+        return result
+
+    return wrapped_func
+
+
+def batched_iter(iterable, batch_size):
+    batch = []
+    for item in iterable:
+        batch.append(item)
+        if len(batch) == batch_size:
+            yield batch
+            batch = []
+    if batch:
+        yield batch
+
+
+def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
+    for key in keys:
+        if key in data:
+            if ignore_none and data[key] is None:
+                continue
+            return data[key]
+    return None
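Quick illustrative calls for the two small helpers added above, assuming they land as module-level functions in supervisely/nn/inference/inference.py as this hunk suggests (outputs shown as comments):

```python
# Assumed import path for the helpers introduced in this diff.
from supervisely.nn.inference.inference import batched_iter, get_value_for_keys

print(list(batched_iter(range(7), batch_size=3)))
# [[0, 1, 2], [3, 4, 5], [6]]

settings = {"conf": None, "confidence": 0.4}
print(get_value_for_keys(settings, ["conf", "confidence"], ignore_none=True))
# 0.4
```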