supervisely-6.73.438-py3-none-any.whl → supervisely-6.73.513-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. supervisely/__init__.py +137 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/annotation.py +8 -2
  4. supervisely/annotation/json_geometries_map.py +14 -11
  5. supervisely/annotation/label.py +80 -3
  6. supervisely/api/annotation_api.py +14 -11
  7. supervisely/api/api.py +59 -38
  8. supervisely/api/app_api.py +11 -2
  9. supervisely/api/dataset_api.py +74 -12
  10. supervisely/api/entities_collection_api.py +10 -0
  11. supervisely/api/entity_annotation/figure_api.py +52 -4
  12. supervisely/api/entity_annotation/object_api.py +3 -3
  13. supervisely/api/entity_annotation/tag_api.py +63 -12
  14. supervisely/api/guides_api.py +210 -0
  15. supervisely/api/image_api.py +72 -1
  16. supervisely/api/labeling_job_api.py +83 -1
  17. supervisely/api/labeling_queue_api.py +33 -7
  18. supervisely/api/module_api.py +9 -0
  19. supervisely/api/project_api.py +71 -26
  20. supervisely/api/storage_api.py +3 -1
  21. supervisely/api/task_api.py +13 -2
  22. supervisely/api/team_api.py +4 -3
  23. supervisely/api/video/video_annotation_api.py +119 -3
  24. supervisely/api/video/video_api.py +65 -14
  25. supervisely/api/video/video_figure_api.py +24 -11
  26. supervisely/app/__init__.py +1 -1
  27. supervisely/app/content.py +23 -7
  28. supervisely/app/development/development.py +18 -2
  29. supervisely/app/fastapi/__init__.py +1 -0
  30. supervisely/app/fastapi/custom_static_files.py +1 -1
  31. supervisely/app/fastapi/multi_user.py +105 -0
  32. supervisely/app/fastapi/subapp.py +88 -42
  33. supervisely/app/fastapi/websocket.py +77 -9
  34. supervisely/app/singleton.py +21 -0
  35. supervisely/app/v1/app_service.py +18 -2
  36. supervisely/app/v1/constants.py +7 -1
  37. supervisely/app/widgets/__init__.py +6 -0
  38. supervisely/app/widgets/activity_feed/__init__.py +0 -0
  39. supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
  40. supervisely/app/widgets/activity_feed/style.css +78 -0
  41. supervisely/app/widgets/activity_feed/template.html +22 -0
  42. supervisely/app/widgets/card/card.py +20 -0
  43. supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
  44. supervisely/app/widgets/classes_list_selector/template.html +60 -93
  45. supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
  46. supervisely/app/widgets/classes_table/classes_table.py +1 -0
  47. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  48. supervisely/app/widgets/dialog/dialog.py +12 -0
  49. supervisely/app/widgets/dialog/template.html +2 -1
  50. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
  51. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  52. supervisely/app/widgets/fast_table/fast_table.py +184 -60
  53. supervisely/app/widgets/fast_table/template.html +1 -1
  54. supervisely/app/widgets/heatmap/__init__.py +0 -0
  55. supervisely/app/widgets/heatmap/heatmap.py +564 -0
  56. supervisely/app/widgets/heatmap/script.js +533 -0
  57. supervisely/app/widgets/heatmap/style.css +233 -0
  58. supervisely/app/widgets/heatmap/template.html +21 -0
  59. supervisely/app/widgets/modal/__init__.py +0 -0
  60. supervisely/app/widgets/modal/modal.py +198 -0
  61. supervisely/app/widgets/modal/template.html +10 -0
  62. supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
  63. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  64. supervisely/app/widgets/radio_tabs/template.html +1 -0
  65. supervisely/app/widgets/select/select.py +6 -3
  66. supervisely/app/widgets/select_class/__init__.py +0 -0
  67. supervisely/app/widgets/select_class/select_class.py +363 -0
  68. supervisely/app/widgets/select_class/template.html +50 -0
  69. supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
  70. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
  71. supervisely/app/widgets/select_tag/__init__.py +0 -0
  72. supervisely/app/widgets/select_tag/select_tag.py +352 -0
  73. supervisely/app/widgets/select_tag/template.html +64 -0
  74. supervisely/app/widgets/select_team/select_team.py +37 -4
  75. supervisely/app/widgets/select_team/template.html +4 -5
  76. supervisely/app/widgets/select_user/__init__.py +0 -0
  77. supervisely/app/widgets/select_user/select_user.py +270 -0
  78. supervisely/app/widgets/select_user/template.html +13 -0
  79. supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
  80. supervisely/app/widgets/select_workspace/template.html +9 -12
  81. supervisely/app/widgets/table/table.py +68 -13
  82. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  83. supervisely/aug/aug.py +6 -2
  84. supervisely/convert/base_converter.py +1 -0
  85. supervisely/convert/converter.py +2 -2
  86. supervisely/convert/image/csv/csv_converter.py +24 -15
  87. supervisely/convert/image/image_converter.py +3 -1
  88. supervisely/convert/image/image_helper.py +48 -4
  89. supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
  90. supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
  91. supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
  92. supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
  93. supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
  94. supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
  95. supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
  96. supervisely/convert/pointcloud/las/las_converter.py +13 -1
  97. supervisely/convert/pointcloud/las/las_helper.py +110 -11
  98. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
  99. supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
  100. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
  101. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
  102. supervisely/convert/video/__init__.py +1 -0
  103. supervisely/convert/video/multi_view/__init__.py +0 -0
  104. supervisely/convert/video/multi_view/multi_view.py +543 -0
  105. supervisely/convert/video/sly/sly_video_converter.py +359 -3
  106. supervisely/convert/video/video_converter.py +24 -4
  107. supervisely/convert/volume/dicom/dicom_converter.py +13 -5
  108. supervisely/convert/volume/dicom/dicom_helper.py +30 -18
  109. supervisely/geometry/constants.py +1 -0
  110. supervisely/geometry/geometry.py +4 -0
  111. supervisely/geometry/helpers.py +5 -1
  112. supervisely/geometry/oriented_bbox.py +676 -0
  113. supervisely/geometry/polyline_3d.py +110 -0
  114. supervisely/geometry/rectangle.py +2 -1
  115. supervisely/io/env.py +76 -1
  116. supervisely/io/fs.py +21 -0
  117. supervisely/nn/benchmark/base_evaluator.py +104 -11
  118. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
  119. supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
  120. supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
  121. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
  122. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
  123. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
  124. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
  125. supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
  126. supervisely/nn/inference/cache.py +43 -18
  127. supervisely/nn/inference/gui/serving_gui_template.py +5 -2
  128. supervisely/nn/inference/inference.py +916 -222
  129. supervisely/nn/inference/inference_request.py +55 -10
  130. supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
  131. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  132. supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
  133. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  134. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  135. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  136. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  137. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  138. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  139. supervisely/nn/inference/session.py +43 -35
  140. supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
  141. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  142. supervisely/nn/inference/tracking/tracker_interface.py +10 -1
  143. supervisely/nn/inference/uploader.py +139 -12
  144. supervisely/nn/live_training/__init__.py +7 -0
  145. supervisely/nn/live_training/api_server.py +111 -0
  146. supervisely/nn/live_training/artifacts_utils.py +243 -0
  147. supervisely/nn/live_training/checkpoint_utils.py +229 -0
  148. supervisely/nn/live_training/dynamic_sampler.py +44 -0
  149. supervisely/nn/live_training/helpers.py +14 -0
  150. supervisely/nn/live_training/incremental_dataset.py +146 -0
  151. supervisely/nn/live_training/live_training.py +497 -0
  152. supervisely/nn/live_training/loss_plateau_detector.py +111 -0
  153. supervisely/nn/live_training/request_queue.py +52 -0
  154. supervisely/nn/model/model_api.py +9 -0
  155. supervisely/nn/model/prediction.py +2 -1
  156. supervisely/nn/model/prediction_session.py +26 -14
  157. supervisely/nn/prediction_dto.py +19 -1
  158. supervisely/nn/tracker/base_tracker.py +11 -1
  159. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  160. supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
  161. supervisely/nn/tracker/botsort_tracker.py +94 -65
  162. supervisely/nn/tracker/utils.py +4 -5
  163. supervisely/nn/tracker/visualize.py +93 -93
  164. supervisely/nn/training/gui/classes_selector.py +16 -1
  165. supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
  166. supervisely/nn/training/train_app.py +46 -31
  167. supervisely/project/data_version.py +115 -51
  168. supervisely/project/download.py +1 -1
  169. supervisely/project/pointcloud_episode_project.py +37 -8
  170. supervisely/project/pointcloud_project.py +30 -2
  171. supervisely/project/project.py +14 -2
  172. supervisely/project/project_meta.py +27 -1
  173. supervisely/project/project_settings.py +32 -18
  174. supervisely/project/versioning/__init__.py +1 -0
  175. supervisely/project/versioning/common.py +20 -0
  176. supervisely/project/versioning/schema_fields.py +35 -0
  177. supervisely/project/versioning/video_schema.py +221 -0
  178. supervisely/project/versioning/volume_schema.py +87 -0
  179. supervisely/project/video_project.py +717 -15
  180. supervisely/project/volume_project.py +623 -5
  181. supervisely/template/experiment/experiment.html.jinja +4 -4
  182. supervisely/template/experiment/experiment_generator.py +14 -21
  183. supervisely/template/live_training/__init__.py +0 -0
  184. supervisely/template/live_training/header.html.jinja +96 -0
  185. supervisely/template/live_training/live_training.html.jinja +51 -0
  186. supervisely/template/live_training/live_training_generator.py +464 -0
  187. supervisely/template/live_training/sly-style.css +402 -0
  188. supervisely/template/live_training/template.html.jinja +18 -0
  189. supervisely/versions.json +28 -26
  190. supervisely/video/sampling.py +39 -20
  191. supervisely/video/video.py +41 -12
  192. supervisely/video_annotation/video_figure.py +38 -4
  193. supervisely/video_annotation/video_object.py +29 -4
  194. supervisely/volume/stl_converter.py +2 -0
  195. supervisely/worker_api/agent_rpc.py +24 -1
  196. supervisely/worker_api/rpc_servicer.py +31 -7
  197. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
  198. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
  199. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
  200. supervisely_lib/__init__.py +6 -1
  201. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
  202. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
  203. {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py

@@ -5,12 +5,14 @@ import asyncio
 import inspect
 import json
 import os
+import queue
 import re
 import shutil
 import subprocess
 import tempfile
 import threading
 import time
+import uuid
 from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict, dataclass
@@ -45,13 +47,14 @@ from supervisely._utils import (
     rand_str,
 )
 from supervisely.annotation.annotation import Annotation
-from supervisely.annotation.label import Label
+from supervisely.annotation.label import Label, LabelingStatus
 from supervisely.annotation.obj_class import ObjClass
 from supervisely.annotation.tag_collection import TagCollection
 from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api, ApiField
 from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
+from supervisely.api.video.video_api import VideoInfo
 from supervisely.app.content import get_data_dir
 from supervisely.app.fastapi.subapp import (
     Application,
@@ -67,6 +70,7 @@ from supervisely.decorators.inference import (
     process_images_batch_sliding_window,
 )
 from supervisely.geometry.any_geometry import AnyGeometry
+from supervisely.geometry.geometry import Geometry
 from supervisely.imaging.color import get_predefined_colors
 from supervisely.io.fs import list_files
 from supervisely.nn.experiments import ExperimentInfo
@@ -75,7 +79,7 @@ from supervisely.nn.inference.inference_request import (
     InferenceRequest,
     InferenceRequestsManager,
 )
-from supervisely.nn.inference.uploader import Uploader
+from supervisely.nn.inference.uploader import Downloader, Uploader
 from supervisely.nn.model.model_api import ModelAPI, Prediction
 from supervisely.nn.prediction_dto import Prediction as PredictionDTO
 from supervisely.nn.utils import (
@@ -94,6 +98,17 @@ from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
 from supervisely.task.progress import Progress
 from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
+from supervisely.video_annotation.frame import Frame
+from supervisely.video_annotation.frame_collection import FrameCollection
+from supervisely.video_annotation.key_id_map import KeyIdMap
+from supervisely.video_annotation.video_annotation import VideoAnnotation
+from supervisely.video_annotation.video_figure import VideoFigure
+from supervisely.video_annotation.video_object import VideoObject
+from supervisely.video_annotation.video_object_collection import (
+    VideoObject,
+    VideoObjectCollection,
+)
+from supervisely.video_annotation.video_tag_collection import VideoTagCollection
 
 try:
     from typing import Literal
@@ -140,6 +155,7 @@ class Inference:
     """Default batch size for inference"""
     INFERENCE_SETTINGS: str = None
     """Path to file with custom inference settings"""
+    DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
 
     def __init__(
         self,
@@ -193,7 +209,6 @@
         self._task_id = None
         self._sliding_window_mode = sliding_window_mode
         self._autostart_delay_time = 5 * 60  # 5 min
-        self._tracker = None
         self._hardware: str = None
         if custom_inference_settings is None:
             if self.INFERENCE_SETTINGS is not None:
@@ -427,7 +442,7 @@
 
             device = "cuda" if torch.cuda.is_available() else "cpu"
         except Exception as e:
-            logger.warn(
+            logger.warning(
                 f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
             )
             device = "cpu"
@@ -734,15 +749,15 @@
         for model in self.pretrained_models:
             model_meta = model.get("meta")
             if model_meta is not None:
-                model_name = model_meta.get("model_name")
-                if model_name is not None:
-                    if model_name.lower() == model_name.lower():
+                this_model_name = model_meta.get("model_name")
+                if this_model_name is not None:
+                    if this_model_name.lower() == model_name.lower():
                         selected_model = model
                         break
             else:
-                model_name = model.get("model_name")
-                if model_name is not None:
-                    if model_name.lower() == model_name.lower():
+                this_model_name = model.get("model_name")
+                if this_model_name is not None:
+                    if this_model_name.lower() == model_name.lower():
                         selected_model = model
                         break
 
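Note: the change above fixes a classic shadowing bug: the loop reused the name of the `model_name` argument, so `model_name.lower() == model_name.lower()` compared a value with itself and was always True. A minimal standalone illustration (the `find_model` helper below is hypothetical, not part of the package):

    def find_model(models: list, model_name: str):
        for model in models:
            # Before the fix, this variable was also called `model_name`,
            # shadowing the argument and making the comparison always True.
            this_model_name = model.get("model_name")
            if this_model_name is not None and this_model_name.lower() == model_name.lower():
                return model
        return None

    print(find_model([{"model_name": "RT-DETRv2"}], "rt-detrv2"))  # case-insensitive match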
@@ -863,6 +878,50 @@
             self.gui.download_progress.hide()
         return local_model_files
 
+    def _fallback_download_custom_model_pt(self, deploy_params: dict):
+        """
+        Downloads the PyTorch checkpoint from Team Files if TensorRT is failed to load.
+        """
+        team_id = sly_env.team_id()
+
+        checkpoint_name = sly_fs.get_file_name(deploy_params["model_files"]["checkpoint"])
+        artifacts_dir = deploy_params["model_info"]["artifacts_dir"]
+        checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")
+        checkpoint_ext = sly_fs.get_file_ext(deploy_params["model_info"]["checkpoints"][0])
+
+        pt_checkpoint_name = f"{checkpoint_name}{checkpoint_ext}"
+        remote_checkpoint_path = os.path.join(checkpoints_dir, pt_checkpoint_name)
+        local_checkpoint_path = os.path.join(self.model_dir, pt_checkpoint_name)
+
+        file_info = self.api.file.get_info_by_path(team_id, remote_checkpoint_path)
+        file_size = file_info.sizeb
+        if self.gui is not None:
+            with self.gui.download_progress(
+                message=f"Fallback. Downloading PyTorch checkpoint: '{pt_checkpoint_name}'",
+                total=file_size,
+                unit="bytes",
+                unit_scale=True,
+            ) as download_pbar:
+                self.gui.download_progress.show()
+                self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path, progress_cb=download_pbar.update)
+            self.gui.download_progress.hide()
+        else:
+            self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path)
+
+        return local_checkpoint_path
+
+    def _remove_exported_checkpoints(self, checkpoint_path: str):
+        """
+        Removes the exported checkpoints for provided PyTorch checkpoint path.
+        """
+        checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+        onnx_path = checkpoint_path.replace(checkpoint_ext, ".onnx")
+        engine_path = checkpoint_path.replace(checkpoint_ext, ".engine")
+        if os.path.exists(onnx_path):
+            sly_fs.silent_remove(onnx_path)
+        if os.path.exists(engine_path):
+            sly_fs.silent_remove(engine_path)
+
     def _download_custom_model(self, model_files: dict, log_progress: bool = True):
         """
         Downloads the custom model data.
@@ -1060,7 +1119,41 @@
         self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
         self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
         self._hardware = get_hardware_info(self.device)
-        self.load_model(**deploy_params)
+
+        model_files = deploy_params.get("model_files", None)
+        if model_files is not None:
+            checkpoint_path = deploy_params["model_files"]["checkpoint"]
+            checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+            if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
+                try:
+                    self.load_model(**deploy_params)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}"
+                    )
+                    checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
+                    deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                    logger.info("Exporting PyTorch model to TensorRT...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_tensorrt(deploy_params)
+                    deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                    self.load_model(**deploy_params)
+            if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
+                if self.runtime == RuntimeType.ONNXRUNTIME:
+                    logger.info("Exporting PyTorch model to ONNX...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_onnx(deploy_params)
+                elif self.runtime == RuntimeType.TENSORRT:
+                    logger.info("Exporting PyTorch model to TensorRT...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_tensorrt(deploy_params)
+                deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                self.load_model(**deploy_params)
+            else:
+                self.load_model(**deploy_params)
+        else:
+            self.load_model(**deploy_params)
+
         self._model_served = True
         self._deploy_params = deploy_params
         if self._task_id is not None and is_production():
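Note: the new deploy logic above dispatches on the requested runtime and the checkpoint extension: a `.engine` checkpoint that fails to load under TensorRT triggers a fallback that re-downloads the PyTorch checkpoint and re-exports it, while `.pt`/`.pth` checkpoints are exported to ONNX or TensorRT when a non-PyTorch runtime is requested. A sketch of that decision table, using stand-in string constants rather than the package's `RuntimeType` enum:

    import os

    # Stand-ins for RuntimeType values (an assumption for illustration only).
    PYTORCH, ONNXRUNTIME, TENSORRT = "PyTorch", "ONNXRuntime", "TensorRT"

    def export_step(checkpoint_path: str, runtime: str) -> str:
        ext = os.path.splitext(checkpoint_path)[1]
        if runtime == TENSORRT and ext == ".engine":
            return "load .engine; on failure re-download .pt and export to TensorRT"
        if ext in (".pt", ".pth") and runtime != PYTORCH:
            return "export to ONNX" if runtime == ONNXRUNTIME else "export to TensorRT"
        return "load checkpoint as-is"

    print(export_step("model.pt", TENSORRT))      # export to TensorRT
    print(export_step("model.engine", TENSORRT))  # load .engine; on failure ...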
@@ -1269,18 +1362,19 @@
 
     def get_classes(self) -> List[str]:
         return self.classes
-
+
     def _tracker_init(self, tracker: str, tracker_settings: dict):
         # Check if tracking is supported for this model
         info = self.get_info()
         tracking_support = info.get("tracking_on_videos_support", False)
-
+
         if not tracking_support:
             logger.debug("Tracking is not supported for this model")
             return None
-
+
         if tracker == "botsort":
             from supervisely.nn.tracker import BotSortTracker
+
             device = tracker_settings.get("device", self.device)
             logger.debug(f"Initializing BotSort tracker with device: {device}")
             return BotSortTracker(settings=tracker_settings, device=device)
@@ -1289,7 +1383,6 @@
             logger.warning(f"Unknown tracking type: {tracker}. Tracking is disabled.")
             return None
 
-
     def get_info(self) -> Dict[str, Any]:
         num_classes = None
         classes = None
@@ -1298,15 +1391,15 @@
             if classes is not None:
                 num_classes = len(classes)
         except NotImplementedError:
-            logger.warn(f"get_classes() function not implemented for {type(self)} object.")
+            logger.warning(f"get_classes() function not implemented for {type(self)} object.")
         except AttributeError:
-            logger.warn("Probably, get_classes() function not working without model deploy.")
+            logger.warning("Probably, get_classes() function not working without model deploy.")
         except Exception as exc:
-            logger.warn("Unknown exception. Please, contact support")
+            logger.warning("Unknown exception. Please, contact support")
             logger.exception(exc)
 
         if num_classes is None:
-            logger.warn(f"get_classes() function return {classes}; skip classes processing.")
+            logger.warning(f"get_classes() function return {classes}; skip classes processing.")
 
         return {
             "app_name": get_name_from_env(default="Neural Network Serving"),
@@ -1324,6 +1417,42 @@
 
     # pylint: enable=method-hidden
 
+    def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
+        """
+        Get default parameters for all available tracking algorithms.
+
+        Returns:
+            {"botsort": {"track_high_thresh": 0.6, ...}}
+            Empty dict if tracking not supported.
+        """
+        info = self.get_info()
+        trackers_params = {}
+
+        tracking_support = info.get("tracking_on_videos_support")
+        if not tracking_support:
+            return trackers_params
+
+        tracking_algorithms = info.get("tracking_algorithms", [])
+
+        for tracker_name in tracking_algorithms:
+            try:
+                if tracker_name == "botsort":
+                    from supervisely.nn.tracker import BotSortTracker
+
+                    trackers_params[tracker_name] = BotSortTracker.get_default_params()
+                # Add other trackers here as elif blocks
+                else:
+                    logger.debug(f"Tracker '{tracker_name}' not implemented")
+            except Exception as e:
+                logger.warning(f"Failed to get params for '{tracker_name}': {e}")
+
+        INTERNAL_FIELDS = {"device", "fps"}
+        for tracker_name, params in trackers_params.items():
+            trackers_params[tracker_name] = {
+                k: v for k, v in params.items() if k not in INTERNAL_FIELDS
+            }
+        return trackers_params
+
     def get_human_readable_info(self, replace_none_with: Optional[str] = None):
         hr_info = {}
         info = self.get_info()
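Note: together with the `/get_tracking_settings` route registered further below in `serve()`, this lets clients discover a model's tracker defaults before starting a run. A hedged client-side sketch (the URL and port are placeholders, not defined by this diff):

    import requests

    # Returns e.g. {"botsort": {"track_high_thresh": 0.6, ...}}, or {} if the
    # served model does not support tracking on videos.
    resp = requests.post("http://localhost:8000/get_tracking_settings")
    print(resp.json())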
@@ -1439,8 +1568,12 @@
                 # for example empty mask
                 continue
             if isinstance(label, list):
+                for lb in label:
+                    lb.status = LabelingStatus.AUTO
                 labels.extend(label)
                 continue
+
+            label.status = LabelingStatus.AUTO
             labels.append(label)
 
         # create annotation with correct image resolution
@@ -1871,8 +2004,8 @@
         else:
             n_frames = frames_reader.frames_count()
 
-        self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
-
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+
         progress_total = (n_frames + step - 1) // step
         inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
 
@@ -1896,32 +2029,30 @@
                 source=frames,
                 settings=inference_settings,
             )
-
-            if self._tracker is not None:
-                anns = self._apply_tracker_to_anns(frames, anns)
-
+
+            if inference_request.tracker is not None:
+                anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
+
             predictions = [
                 Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
                 for ann, frame_index in zip(anns, batch)
             ]
-
+
             for pred, this_slides_data in zip(predictions, slides_data):
                 pred.extra_data["slides_data"] = this_slides_data
             batch_results = self._format_output(predictions)
-
+
             inference_request.add_results(batch_results)
             inference_request.done(len(batch_results))
             logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
         video_ann_json = None
-        if self._tracker is not None:
+        if inference_request.tracker is not None:
             inference_request.set_stage("Postprocess...", 0, 1)
-
-            video_ann_json = self._tracker.video_annotation.to_json()
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
             inference_request.done()
         result = {"ann": results, "video_ann": video_ann_json}
         inference_request.final_result = result.copy()
         return video_ann_json
-
 
     def _inference_image_ids(
         self,
@@ -1949,7 +2080,7 @@
         upload_mode = state.get("upload_mode", None)
         iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
         if upload_mode == "iou_merge" and iou_merge_threshold is None:
-            iou_merge_threshold = 0.7
+            iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD  # TODO: change to 0.9
 
         images_infos = api.image.get_info_by_id_batch(image_ids)
         images_infos_dict = {im_info.id: im_info for im_info in images_infos}
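Note: `existing_objects_iou_thresh` decides when a prediction counts as a duplicate of an existing object in `iou_merge` mode; its default now comes from `DEFAULT_IOU_MERGE_THRESHOLD` (0.9). For reference, a minimal box-IoU sketch of the quantity this threshold is compared against (not the package's implementation, which appears as `_filter_duplicated_predictions_from_ann_cpu` at the end of this diff):

    def box_iou(a, b):
        """IoU of two boxes given as (left, top, right, bottom)."""
        ix = max(0, min(a[2], b[2]) - max(a[0], b[0]))
        iy = max(0, min(a[3], b[3]) - max(a[1], b[1]))
        inter = ix * iy
        union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
        return inter / union if union else 0.0

    print(box_iou((0, 0, 10, 10), (1, 1, 11, 11)))  # ~0.68, below the 0.9 default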
@@ -1991,14 +2122,9 @@
             output_dataset_id
         ] = output_dataset_info
 
-        # start download to cache in background
-        dataset_image_infos: Dict[int, List[ImageInfo]] = defaultdict(list)
-        for image_info in images_infos:
-            dataset_image_infos[image_info.dataset_id].append(image_info)
-        for dataset_id, ds_image_infos in dataset_image_infos.items():
-            self.cache.run_cache_task_manually(
-                api, [info.id for info in ds_image_infos], dataset_id=dataset_id
-            )
+        def download_f(item: int):
+            self.cache.download_image(api, item)
+            return item
 
         _upload_predictions = partial(
             self.upload_predictions,
@@ -2014,7 +2140,9 @@
         )
 
         _add_results_to_request = partial(
-            self.add_results_to_request, inference_request=inference_request
+            self.add_results_to_request,
+            inference_request=inference_request,
+            progress_cb=inference_request.done,
         )
 
         if upload_mode is None:
@@ -2023,40 +2151,60 @@
             upload_f = _upload_predictions
 
         inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(image_ids))
+        download_workers = max(8, min(batch_size, 64))
         with Uploader(upload_f, logger=logger) as uploader:
-            for image_ids_batch in batched(image_ids, batch_size=batch_size):
-                if uploader.has_exception():
-                    exception = uploader.exception
-                    raise exception
-                if inference_request.is_stopped():
-                    logger.debug(
-                        f"Cancelling inference project...",
-                        extra={"inference_request_uuid": inference_request.uuid},
-                    )
-                    break
-
-                images_nps = [self.cache.download_image(api, img_id) for img_id in image_ids_batch]
-                anns, slides_data = self._inference_auto(
-                    source=images_nps,
-                    settings=inference_settings,
-                )
+            with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
+                for image_id in image_ids:
+                    downloader.put(image_id)
+                downloader.next(100)
+                for image_ids_batch in batched(image_ids, batch_size=batch_size):
+                    if uploader.has_exception():
+                        exception = uploader.exception
+                        raise exception
+                    if inference_request.is_stopped():
+                        logger.debug(
+                            f"Cancelling inference...",
+                            extra={"inference_request_uuid": inference_request.uuid},
+                        )
+                        break
+                    if inference_request.is_paused():
+                        logger.info("Inference request is paused. Waiting...")
+                        while inference_request.is_paused():
+                            if (
+                                inference_request.paused_for()
+                                > inference_request.PAUSE_SLEEP_MAX_WAIT
+                            ):
+                                logger.info(
+                                    "Inference request has been paused for too long. Cancelling..."
+                                )
+                                raise RuntimeError("Inference request cancelled due to long pause.")
+                            time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
 
-                batch_predictions = []
-                for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
-                    image_info: ImageInfo = images_infos_dict[image_id]
-                    dataset_info = dataset_infos_dict[image_info.dataset_id]
-                    prediction = Prediction(
-                        ann,
-                        model_meta=self.model_meta,
-                        name=image_info.name,
-                        image_id=image_info.id,
-                        dataset_id=image_info.dataset_id,
-                        project_id=dataset_info.project_id,
+                    images_nps = [
+                        self.cache.download_image(api, img_id) for img_id in image_ids_batch
+                    ]
+                    downloader.next(len(image_ids_batch))
+                    anns, slides_data = self._inference_auto(
+                        source=images_nps,
+                        settings=inference_settings,
                     )
-                    prediction.extra_data["slides_data"] = this_slides_data
-                    batch_predictions.append(prediction)
 
-                uploader.put(batch_predictions)
+                    batch_predictions = []
+                    for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
+                        image_info: ImageInfo = images_infos_dict[image_id]
+                        dataset_info = dataset_infos_dict[image_info.dataset_id]
+                        prediction = Prediction(
+                            ann,
+                            model_meta=self.model_meta,
+                            name=image_info.name,
+                            image_id=image_info.id,
+                            dataset_id=image_info.dataset_id,
+                            project_id=dataset_info.project_id,
+                        )
+                        prediction.extra_data["slides_data"] = this_slides_data
+                        batch_predictions.append(prediction)
+
+                    uploader.put(batch_predictions)
 
     def _inference_video_id(
         self,
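Note: the rewritten loop above replaces the ad-hoc background cache task with a `Downloader` that prefetches images while the model is busy: every id is enqueued with `put()`, an initial window of 100 downloads is released, and each consumed batch releases the next one via `next()`. `Downloader` itself lives in `supervisely/nn/inference/uploader.py`, which is outside this excerpt, so the following is only a stdlib approximation of the same bounded-prefetch idea (its semantics are an assumption):

    from collections import deque
    from concurrent.futures import ThreadPoolExecutor

    def bounded_prefetch(items, download_one, max_workers=8, window=100):
        """Yield download results in order, keeping at most `window` in flight."""
        pending = deque()
        it = iter(items)
        with ThreadPoolExecutor(max_workers) as pool:
            for _ in range(window):  # fill the initial window, like downloader.next(100)
                try:
                    pending.append(pool.submit(download_one, next(it)))
                except StopIteration:
                    break
            while pending:
                result = pending.popleft().result()
                try:  # top the window back up as each item is consumed
                    pending.append(pool.submit(download_one, next(it)))
                except StopIteration:
                    pass
                yield result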
@@ -2071,7 +2219,7 @@
         video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
         if video_id is None:
             raise ValueError("Video id is not provided")
-        video_info = api.video.get_info_by_id(video_id)
+        video_info = api.video.get_info_by_id(video_id, force_metadata_for_links=True)
         start_frame_index = get_value_for_keys(
             state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
         )
@@ -2101,8 +2249,8 @@
         else:
             n_frames = video_info.frames_count
 
-        self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
-
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+
         logger.debug(
             f"Video info:",
             extra=dict(
@@ -2137,10 +2285,10 @@
                 source=frames,
                 settings=inference_settings,
             )
-
-            if self._tracker is not None:
-                anns = self._apply_tracker_to_anns(frames, anns)
-
+
+            if inference_request.tracker is not None:
+                anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
+
             predictions = [
                 Prediction(
                     ann,
@@ -2148,21 +2296,181 @@
                     frame_index=frame_index,
                     video_id=video_info.id,
                     dataset_id=video_info.dataset_id,
-                    project_id=video_info.project_id,
-                )
+                    project_id=video_info.project_id,
+                )
                 for ann, frame_index in zip(anns, batch)
             ]
             for pred, this_slides_data in zip(predictions, slides_data):
                 pred.extra_data["slides_data"] = this_slides_data
             batch_results = self._format_output(predictions)
-
+
             inference_request.add_results(batch_results)
             inference_request.done(len(batch_results))
             logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
         video_ann_json = None
-        if self._tracker is not None:
+        if inference_request.tracker is not None:
+            inference_request.set_stage("Postprocess...", 0, progress_total)
+
+            video_ann_json = inference_request.tracker.create_video_annotation(
+                video_info.frames_count,
+                start_frame_index,
+                step=step,
+                progress_cb=inference_request.done,
+            ).to_json()
+        inference_request.final_result = {"video_ann": video_ann_json}
+        return video_ann_json
+
+    def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
+        logger.debug("Inferring video_id...", extra={"state": state})
+        inference_settings = self._get_inference_settings(state)
+        logger.debug(f"Inference settings:", extra=inference_settings)
+        batch_size = self._get_batch_size_from_state(state)
+        video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
+        if video_id is None:
+            raise ValueError("Video id is not provided")
+        video_info = api.video.get_info_by_id(video_id)
+        start_frame_index = get_value_for_keys(
+            state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
+        )
+        if start_frame_index is None:
+            start_frame_index = 0
+        step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
+        if step is None:
+            step = 1
+        end_frame_index = get_value_for_keys(
+            state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
+        )
+        duration = state.get("duration", None)
+        frames_count = get_value_for_keys(
+            state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
+        )
+        tracking = state.get("tracker", None)
+        direction = state.get("direction", "forward")
+        direction = 1 if direction == "forward" else -1
+        track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
+
+        if frames_count is not None:
+            n_frames = frames_count
+        elif end_frame_index is not None:
+            n_frames = end_frame_index - start_frame_index
+        elif duration is not None:
+            fps = video_info.frames_count / video_info.duration
+            n_frames = int(duration * fps)
+        else:
+            n_frames = video_info.frames_count
+
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+
+        logger.debug(
+            f"Video info:",
+            extra=dict(
+                w=video_info.frame_width,
+                h=video_info.frame_height,
+                start_frame_index=start_frame_index,
+                n_frames=n_frames,
+            ),
+        )
+
+        # start downloading video in background
+        self.cache.run_cache_task_manually(api, None, video_id=video_id)
+
+        progress_total = (n_frames + step - 1) // step
+        inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
+
+        _upload_f = partial(
+            self.upload_predictions_to_video,
+            api=api,
+            video_info=video_info,
+            track_id=track_id,
+            context=inference_request.context,
+            progress_cb=inference_request.done,
+            inference_request=inference_request,
+        )
+
+        _range = (start_frame_index, start_frame_index + direction * n_frames)
+        if _range[0] > _range[1]:
+            _range = (_range[1], _range[0])
+
+        def _notify_f(predictions: List[Prediction]):
+            logger.debug(
+                "Notifying tracking progress...",
+                extra={
+                    "track_id": track_id,
+                    "range": _range,
+                    "current": inference_request.progress.current,
+                    "total": inference_request.progress.total,
+                },
+            )
+            stopped = self.api.video.notify_progress(
+                track_id=track_id,
+                video_id=video_info.id,
+                frame_start=_range[0],
+                frame_end=_range[1],
+                current=inference_request.progress.current,
+                total=inference_request.progress.total,
+            )
+            if stopped:
+                inference_request.stop()
+                logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
+
+        def _exception_handler(e: Exception):
+            self.api.video.notify_tracking_error(
+                track_id=track_id,
+                error=str(type(e)),
+                message=str(e),
+            )
+            raise e
+
+        with Uploader(
+            upload_f=_upload_f,
+            notify_f=_notify_f,
+            exception_handler=_exception_handler,
+            logger=logger,
+        ) as uploader:
+            for batch in batched(
+                range(
+                    start_frame_index, start_frame_index + direction * n_frames, direction * step
+                ),
+                batch_size,
+            ):
+                if inference_request.is_stopped():
+                    logger.debug(
+                        f"Cancelling inference video...",
+                        extra={"inference_request_uuid": inference_request.uuid},
+                    )
+                    break
+                logger.debug(
+                    f"Inferring frames {batch[0]}-{batch[-1]}:",
+                )
+                frames = self.cache.download_frames(
+                    api, video_info.id, batch, redownload_video=True
+                )
+                anns, slides_data = self._inference_auto(
+                    source=frames,
+                    settings=inference_settings,
+                )
+
+                if inference_request.tracker is not None:
+                    anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
+
+                predictions = [
+                    Prediction(
+                        ann,
+                        model_meta=self.model_meta,
+                        frame_index=frame_index,
+                        video_id=video_info.id,
+                        dataset_id=video_info.dataset_id,
+                        project_id=video_info.project_id,
+                    )
+                    for ann, frame_index in zip(anns, batch)
+                ]
+                for pred, this_slides_data in zip(predictions, slides_data):
+                    pred.extra_data["slides_data"] = this_slides_data
+                uploader.put(predictions)
+        video_ann_json = None
+        if inference_request.tracker is not None:
             inference_request.set_stage("Postprocess...", 0, 1)
-            video_ann_json = self._tracker.video_annotation.to_json()
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
             inference_request.done()
         inference_request.final_result = {"video_ann": video_ann_json}
         return video_ann_json
2188
2496
  upload_mode = state.get("upload_mode", None)
2189
2497
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
2190
2498
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
2191
- iou_merge_threshold = 0.7
2499
+ iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
2192
2500
  cache_project_on_model = state.get("cache_project_on_model", False)
2193
2501
 
2194
- project_info = api.project.get_info_by_id(project_id)
2195
2502
  inference_request.context.setdefault("project_info", {})[project_id] = project_info
2196
2503
  dataset_ids = state.get("dataset_ids", None)
2197
2504
  if dataset_ids is None:
@@ -2226,7 +2533,11 @@
 
         if cache_project_on_model:
             download_to_cache(
-                api, project_info.id, datasets_infos, progress_cb=inference_request.done
+                api,
+                project_info.id,
+                datasets_infos,
+                progress_cb=inference_request.done,
+                skip_create_readme=True,
             )
 
         images_infos_dict = {}
2235
2546
  if not cache_project_on_model:
2236
2547
  inference_request.done(dataset_info.items_count)
2237
2548
 
2238
- def _download_images(datasets_infos: List[DatasetInfo]):
2239
- for dataset_info in datasets_infos:
2240
- image_ids = [image_info.id for image_info in images_infos_dict[dataset_info.id]]
2241
- with ThreadPoolExecutor(max(8, min(batch_size, 64))) as executor:
2242
- for image_id in image_ids:
2243
- executor.submit(
2244
- self.cache.download_image,
2245
- api,
2246
- image_id,
2247
- )
2248
-
2249
- if not cache_project_on_model:
2250
- # start downloading in parallel
2251
- threading.Thread(target=_download_images, args=[datasets_infos], daemon=True).start()
2549
+ def download_f(item: int):
2550
+ self.cache.download_image(api, item)
2551
+ return item
2252
2552
 
2253
2553
  _upload_predictions = partial(
2254
2554
  self.upload_predictions,
@@ -2263,7 +2563,9 @@ class Inference:
2263
2563
  )
2264
2564
 
2265
2565
  _add_results_to_request = partial(
2266
- self.add_results_to_request, inference_request=inference_request
2566
+ self.add_results_to_request,
2567
+ inference_request=inference_request,
2568
+ progress_cb=inference_request.done,
2267
2569
  )
2268
2570
 
2269
2571
  if upload_mode is None:
@@ -2271,57 +2573,78 @@
         else:
            upload_f = _upload_predictions
 
+        download_workers = max(8, min(batch_size, 64))
         inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, inference_progress_total)
         with Uploader(upload_f, logger=logger) as uploader:
-            for dataset_info in datasets_infos:
-                for images_infos_batch in batched(
-                    images_infos_dict[dataset_info.id], batch_size=batch_size
-                ):
-                    if inference_request.is_stopped():
-                        logger.debug(
-                            f"Cancelling inference project...",
-                            extra={"inference_request_uuid": inference_request.uuid},
-                        )
-                        return
-                    if uploader.has_exception():
-                        exception = uploader.exception
-                        raise exception
-                    if cache_project_on_model:
-                        images_paths, _ = zip(
-                            *read_from_cached_project(
-                                project_info.id,
-                                dataset_info.name,
-                                [ii.name for ii in images_infos_batch],
+            with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
+                for images in images_infos_dict.values():
+                    for image in images:
+                        downloader.put(image.id)
+                downloader.next(100)
+                for dataset_info in datasets_infos:
+                    for images_infos_batch in batched(
+                        images_infos_dict[dataset_info.id], batch_size=batch_size
+                    ):
+                        if uploader.has_exception():
+                            exception = uploader.exception
+                            raise exception
+                        if inference_request.is_stopped():
+                            logger.debug(
+                                f"Cancelling inference project...",
+                                extra={"inference_request_uuid": inference_request.uuid},
                             )
+                            return
+                        if inference_request.is_paused():
+                            logger.info("Inference request is paused. Waiting...")
+                            while inference_request.is_paused():
+                                if (
+                                    inference_request.paused_for()
+                                    > inference_request.PAUSE_SLEEP_MAX_WAIT
+                                ):
+                                    logger.info(
+                                        "Inference request has been paused for too long. Cancelling..."
+                                    )
+                                    raise RuntimeError(
+                                        "Inference request cancelled due to long pause."
+                                    )
+                                time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
+                        if cache_project_on_model:
+                            images_paths, _ = zip(
+                                *read_from_cached_project(
+                                    project_info.id,
+                                    dataset_info.name,
+                                    [ii.name for ii in images_infos_batch],
+                                )
+                            )
+                            images_nps = [sly_image.read(img_path) for img_path in images_paths]
+                        else:
+                            images_nps = self.cache.download_images(
+                                api,
+                                dataset_info.id,
+                                [info.id for info in images_infos_batch],
+                                return_images=True,
+                            )
+                        downloader.next(len(images_infos_batch))
+                        anns, slides_data = self._inference_auto(
+                            source=images_nps,
+                            settings=inference_settings,
                         )
-                        images_nps = [sly_image.read(img_path) for img_path in images_paths]
-                    else:
-                        images_nps = self.cache.download_images(
-                            api,
-                            dataset_info.id,
-                            [info.id for info in images_infos_batch],
-                            return_images=True,
-                        )
-                    anns, slides_data = self._inference_auto(
-                        source=images_nps,
-                        settings=inference_settings,
-                    )
-                    predictions = [
-                        Prediction(
-                            ann,
-                            model_meta=self.model_meta,
-                            image_id=image_info.id,
-                            name=image_info.name,
-                            dataset_id=dataset_info.id,
-                            project_id=dataset_info.project_id,
-                            image_name=image_info.name,
-                        )
-                        for ann, image_info in zip(anns, images_infos_batch)
-                    ]
-                    for pred, this_slides_data in zip(predictions, slides_data):
-                        pred.extra_data["slides_data"] = this_slides_data
+                        predictions = [
+                            Prediction(
+                                ann,
+                                model_meta=self.model_meta,
+                                image_id=image_info.id,
+                                name=image_info.name,
+                                dataset_id=dataset_info.id,
+                                project_id=dataset_info.project_id,
+                                image_name=image_info.name,
+                            )
+                            for ann, image_info in zip(anns, images_infos_batch)
+                        ]
+                        for pred, this_slides_data in zip(predictions, slides_data):
+                            pred.extra_data["slides_data"] = this_slides_data
 
-                    uploader.put(predictions)
+                        uploader.put(predictions)
 
     def _run_speedtest(
         self,
@@ -2364,7 +2687,13 @@
         inference_request.done()
 
         if cache_project_on_model:
-            download_to_cache(api, project_id, datasets_infos, progress_cb=inference_request.done)
+            download_to_cache(
+                api,
+                project_id,
+                datasets_infos,
+                progress_cb=inference_request.done,
+                skip_create_readme=True,
+            )
 
         inference_request.set_stage("warmup", 0, num_warmup)
@@ -2485,6 +2814,11 @@
     def _freeze_model(self):
         if self._model_frozen or not self._model_served:
             return
+
+        if not self._deploy_params:
+            logger.warning("Deploy params are not set, cannot freeze the model.")
+            return
+
         logger.debug("Freezing model...")
         runtime = self._deploy_params.get("runtime")
         if runtime and runtime.lower() != RuntimeType.PYTORCH.lower():
@@ -2524,7 +2858,6 @@
         timer.daemon = True
         timer.start()
         self._freeze_timer = timer
-        logger.debug("Model will be frozen in %s seconds due to inactivity.", self._inactivity_timeout)
 
     def _set_served_callback(self):
         self._model_served = True
@@ -2637,6 +2970,10 @@
         for prediction in predictions:
             ds_predictions[prediction.dataset_id].append(prediction)
 
+        def update_labeling_status(ann: Annotation) -> Annotation:
+            for label in ann.labels:
+                label.status = LabelingStatus.AUTO
+
         def _new_name(image_info: ImageInfo):
             name = Path(image_info.name)
             stem = name.stem
@@ -2669,10 +3006,10 @@
             context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
             return created_dataset.id
 
-        created_names = []
         if context is None:
             context = {}
         for dataset_id, preds in ds_predictions.items():
+            created_names = set()
            if dst_project_id is not None:
                 # upload to the destination project
                 dst_dataset_id = _get_or_create_dataset(
@@ -2712,8 +3049,15 @@
                     iou=iou_merge_threshold,
                     meta=project_meta,
                 )
+
+                # Update labeling status of new predictions before upload
+                anns_with_nn_flags = []
                 for pred, ann in zip(preds, anns):
+                    update_labeling_status(ann)
                     pred.annotation = ann
+                    anns_with_nn_flags.append(ann)
+
+                anns = anns_with_nn_flags
 
                 context.setdefault("image_info", {})
                 missing = [
@@ -2741,7 +3085,7 @@
                     with_annotations=False,
                     save_source_date=False,
                 )
-                created_names.extend([image_info.name for image_info in dst_image_infos])
+                created_names.update([image_info.name for image_info in dst_image_infos])
                 api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
             else:
                 # upload to the source dataset
@@ -2778,7 +3122,10 @@
                     iou=iou_merge_threshold,
                     meta=project_meta,
                 )
+
+                # Update labeling status of predicted labels before optional merge
                 for pred, ann in zip(preds, anns):
+                    update_labeling_status(ann)
                     pred.annotation = ann
 
                 if upload_mode in ["iou_merge", "append"]:
@@ -2814,11 +3161,89 @@
         inference_request.add_results(results)
 
     def add_results_to_request(
-        self, predictions: List[Prediction], inference_request: InferenceRequest
+        self, predictions: List[Prediction], inference_request: InferenceRequest, progress_cb=None
     ):
         results = self._format_output(predictions)
         inference_request.add_results(results)
-        inference_request.done(len(results))
+        if progress_cb:
+            progress_cb(len(results))
+
+    def upload_predictions_to_video(
+        self,
+        predictions: List[Prediction],
+        api: Api,
+        video_info: VideoInfo,
+        track_id: str,
+        context: Dict,
+        progress_cb=None,
+        inference_request: InferenceRequest = None,
+    ):
+        key_id_map = KeyIdMap()
+        project_meta = context.get("project_meta", None)
+        if project_meta is None:
+            project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
+            context["project_meta"] = project_meta
+        meta_changed = False
+        for prediction in predictions:
+            project_meta, ann, meta_changed_ = update_meta_and_ann(
+                project_meta, prediction.annotation, None
+            )
+            prediction.annotation = ann
+            meta_changed = meta_changed or meta_changed_
+        if meta_changed:
+            project_meta = api.project.update_meta(video_info.project_id, project_meta)
+            context["project_meta"] = project_meta
+
+        figure_data_by_object_id = defaultdict(list)
+
+        tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
+        new_tracks: Dict[int, VideoObject] = {}
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                if track not in tracks_to_object_ids and track not in new_tracks:
+                    video_object = VideoObject(obj_class=label.obj_class)
+                    new_tracks[track] = video_object
+        if new_tracks:
+            tracks, video_objects = zip(*new_tracks.items())
+            added_object_ids = api.video.object.append_bulk(
+                video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
+            )
+            for track, object_id in zip(tracks, added_object_ids):
+                tracks_to_object_ids[track] = object_id
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                object_id = tracks_to_object_ids[track]
+                figure_data_by_object_id[object_id].append(
+                    {
+                        ApiField.OBJECT_ID: object_id,
+                        ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
+                        ApiField.GEOMETRY: label.geometry.to_json(),
+                        ApiField.META: {ApiField.FRAME: prediction.frame_index},
+                        ApiField.TRACK_ID: track_id,
+                    }
+                )
+
+        for object_id, figures_data in figure_data_by_object_id.items():
+            figures_keys = [uuid.uuid4() for _ in figures_data]
+            api.video.figure._append_bulk(
+                entity_id=video_info.id,
+                figures_json=figures_data,
+                figures_keys=figures_keys,
+                key_id_map=key_id_map,
+            )
+            logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
+        if progress_cb:
+            progress_cb(len(predictions))
+        if inference_request is not None:
+            results = self._format_output(predictions)
+            for result in results:
+                result["annotation"] = None
+                result["data"] = None
+            inference_request.add_results(results)
 
     def serve(self):
         if not self._use_gui and not self._is_cli_deploy:
@@ -2902,7 +3327,7 @@
 
         if not self._use_gui:
             Progress("Model deployed", 1).iter_done_report()
-        else:
+        elif self.api is not None:
             autostart_func()
 
         @server.exception_handler(HTTPException)
@@ -2929,6 +3354,11 @@
         def get_session_info(response: Response):
             return self.get_info()
 
+        @server.post("/get_tracking_settings")
+        @self._check_serve_before_call
+        def get_tracking_settings(response: Response):
+            return self.get_tracking_settings()
+
         @server.post("/get_custom_inference_settings")
         def get_custom_inference_settings():
             return {"settings": self.custom_inference_settings}
@@ -3212,6 +3642,22 @@ class Inference:
                 "inference_request_uuid": inference_request.uuid,
             }
 
+        @server.post("/tracking_by_detection")
+        def tracking_by_detection(response: Response, request: Request):
+            state = request.state.state
+            context = request.state.context
+            state.update(context)
+            if state.get("tracker") is None:
+                state["tracker"] = "botsort"
+
+            logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
+            self.validate_inference_state(state)
+            api = self.api_from_request(request)
+            inference_request, future = self.inference_requests_manager.schedule_task(
+                self._tracking_by_detection, api, state
+            )
+            return {"message": "Track task started."}
+
         @server.post("/inference_project_id_async")
         def inference_project_id_async(response: Response, request: Request):
             state = request.state.state
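Editor's note: the handler above only schedules the job and returns immediately; figures are uploaded asynchronously by _tracking_by_detection. A hedged sketch of a call; the state fields (e.g. video_id) are assumptions, since the exact schema is enforced by validate_inference_state:

    import supervisely as sly

    api = sly.Api.from_env()
    task_id = 12345  # hypothetical serving task

    resp = api.task.send_request(
        task_id,
        "tracking_by_detection",
        data={"video_id": 4242, "tracker": "botsort"},  # field names are assumptions
    )
    print(resp)  # {"message": "Track task started."}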
@@ -3275,10 +3721,7 @@ class Inference:
             data = {**inference_request.to_json(), **log_extra}
             if inference_request.stage != InferenceRequest.Stage.INFERENCE:
                 data["progress"] = {"current": 0, "total": 1}
-            logger.debug(
-                f"Sending inference progress with uuid:",
-                extra=data,
-            )
+            logger.debug("Sending inference progress with uuid", extra=data)
             return data
 
         @server.post(f"/pop_inference_results")
@@ -4135,20 +4578,20 @@ class Inference:
             self._args.draw,
         )
 
-    def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
+    def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
         updated_anns = []
         for frame, ann in zip(frames, anns):
-            matches = self._tracker.update(frame, ann)
+            matches = tracker.update(frame, ann)
             track_ids = [match["track_id"] for match in matches]
             tracked_labels = [match["label"] for match in matches]
-
+
             filtered_annotation = ann.clone(
                 labels=tracked_labels,
                 custom_data=track_ids
             )
             updated_anns.append(filtered_annotation)
         return updated_anns
-
+
     def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
         if model_source == ModelSource.PRETRAINED:
             checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
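Editor's note: the refactor makes the tracker an explicit argument instead of reading self._tracker, so any object exposing update(frame, ann) -> List[dict] with "track_id" and "label" keys fits. A toy sketch of that assumed interface (DummyTracker is illustrative, not part of the SDK):

    from typing import List

    import numpy as np
    import supervisely as sly

    class DummyTracker:
        """Toy tracker matching the interface used by _apply_tracker_to_anns."""

        def __init__(self):
            self._next_id = 0

        def update(self, frame: np.ndarray, ann: sly.Annotation) -> List[dict]:
            # Assign a fresh track id to every label; a real tracker (e.g. BoT-SORT)
            # would associate detections across frames instead.
            matches = []
            for label in ann.labels:
                matches.append({"track_id": self._next_id, "label": label})
                self._next_id += 1
            return matches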
@@ -4198,62 +4641,78 @@ class Inference:
             return
         self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
 
+    def export_onnx(self, deploy_params: dict):
+        raise NotImplementedError("Must be implemented in a child class")
 
-def _exclude_duplicated_predictions(
-    api: Api,
-    pred_anns: List[Annotation],
-    dataset_id: int,
-    gt_image_ids: List[int],
-    iou: float = None,
-    meta: Optional[ProjectMeta] = None,
+    def export_tensorrt(self, deploy_params: dict):
+        raise NotImplementedError("Must be implemented in a child class")
+
+
+def _filter_duplicated_predictions_from_ann_cpu(
+    gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
 ):
     """
-    Filter out predictions that significantly overlap with ground truth (GT) objects.
+    Filter out predicted labels whose bboxes have IoU >= iou_threshold with any GT label.
+    Uses Shapely for geometric operations.
 
-    This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
-    - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
-    - Gets ProjectMeta object if not provided
-    - Downloads GT annotations for the specified image IDs
-    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+    Args:
+        gt_ann: Ground truth annotation object
+        pred_ann: Predicted annotation object
+        iou_threshold: IoU threshold for filtering
 
-    :param api: Supervisely API object
-    :type api: Api
-    :param pred_anns: List of Annotation objects containing predictions
-    :type pred_anns: List[Annotation]
-    :param dataset_id: ID of the dataset containing the images
-    :type dataset_id: int
-    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
-    :type gt_image_ids: List[int]
-    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
-        ground truth box of the same class will be removed. None if no filtering is needed
-    :type iou: Optional[float]
-    :param meta: ProjectMeta object
-    :type meta: Optional[ProjectMeta]
-    :return: List of Annotation objects containing filtered predictions
-    :rtype: List[Annotation]
-
-    Notes:
-    ------
-    - Requires PyTorch and torchvision for IoU calculations
-    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    Returns:
+        New annotation with filtered labels
     """
-    if isinstance(iou, float) and 0 < iou <= 1:
-        if meta is None:
-            ds = api.dataset.get_info_by_id(dataset_id)
-            meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
-        gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
-        gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
-        for i in range(0, len(pred_anns)):
-            before = len(pred_anns[i].labels)
-            with Timer() as timer:
-                pred_anns[i] = _filter_duplicated_predictions_from_ann(
-                    gt_anns[i], pred_anns[i], iou
-                )
-            after = len(pred_anns[i].labels)
-            logger.debug(
-                f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
-            )
-    return pred_anns
+    if not iou_threshold:
+        return pred_ann
+
+    from shapely.geometry import box
+
+    def calculate_iou(geom1: Geometry, geom2: Geometry):
+        """Calculate IoU between two geometries using Shapely."""
+        bbox1 = geom1.to_bbox()
+        bbox2 = geom2.to_bbox()
+
+        box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
+        box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
+
+        intersection = box1.intersection(box2).area
+        union = box1.union(box2).area
+
+        return intersection / union if union > 0 else 0.0
+
+    new_labels = []
+    pred_cls_bboxes = defaultdict(list)
+    for label in pred_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)
+
+    gt_cls_bboxes = defaultdict(list)
+    for label in gt_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
+            continue
+        gt_cls_bboxes[name_shape].append(label)
+
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
+        if len(gt) == 0:
+            new_labels.extend(pred)
+            continue
+
+        for pred_label in pred:
+            # Keep this prediction only if its IoU with every GT box is below the threshold
+            keep = True
+            for gt_label in gt:
+                iou = calculate_iou(pred_label.geometry, gt_label.geometry)
+                if iou >= iou_threshold:
+                    keep = False
+                    break
+
+            if keep:
+                new_labels.append(pred_label)
+
+    return pred_ann.clone(labels=new_labels)
 
 
 def _filter_duplicated_predictions_from_ann(
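Editor's note: a quick sanity check of the Shapely-based IoU above; the boxes are illustrative:

    from shapely.geometry import box

    # Two 10x10 axis-aligned boxes overlapping in a 5x10 region.
    b1 = box(0, 0, 10, 10)  # area 100
    b2 = box(5, 0, 15, 10)  # area 100
    inter = b1.intersection(b2).area  # 50.0
    union = b1.union(b2).area  # 150.0
    print(inter / union)  # 0.333...; such a pair is filtered only if iou_threshold <= 1/3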
@@ -4284,13 +4743,15 @@ def _filter_duplicated_predictions_from_ann(
    - Predictions with classes not present in ground truth will be kept
    - Requires PyTorch and torchvision for IoU calculations
    """
+    if not iou_threshold:
+        return pred_ann
 
    try:
        import torch
        from torchvision.ops import box_iou
 
    except ImportError:
-        raise ImportError("Please install PyTorch and torchvision to use this feature.")
+        return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)
 
    def _to_tensor(geom):
        return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
@@ -4298,16 +4759,18 @@ def _filter_duplicated_predictions_from_ann(
    new_labels = []
    pred_cls_bboxes = defaultdict(list)
    for label in pred_ann.labels:
-        pred_cls_bboxes[label.obj_class.name].append(label)
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)
 
    gt_cls_bboxes = defaultdict(list)
    for label in gt_ann.labels:
-        if label.obj_class.name not in pred_cls_bboxes:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
            continue
-        gt_cls_bboxes[label.obj_class.name].append(label)
+        gt_cls_bboxes[name_shape].append(label)
 
-    for name, pred in pred_cls_bboxes.items():
-        gt = gt_cls_bboxes[name]
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
        if len(gt) == 0:
            new_labels.extend(pred)
            continue
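Editor's note: switching the grouping key from the class name to a (class name, geometry name) tuple means predictions are only suppressed by GT labels of the same class and shape. Illustrative keys ("rectangle" and "bitmap" are standard Supervisely geometry names):

    # A "car" bounding box is no longer compared against a "car" mask.
    key_bbox = ("car", "rectangle")
    key_mask = ("car", "bitmap")
    assert key_bbox != key_mask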
@@ -4321,6 +4784,63 @@ def _filter_duplicated_predictions_from_ann(
     return pred_ann.clone(labels=new_labels)
 
 
+def _exclude_duplicated_predictions(
+    api: Api,
+    pred_anns: List[Annotation],
+    dataset_id: int,
+    gt_image_ids: List[int],
+    iou: float = None,
+    meta: Optional[ProjectMeta] = None,
+):
+    """
+    Filter out predictions that significantly overlap with ground truth (GT) objects.
+
+    This is a wrapper around the `_filter_duplicated_predictions_from_ann` function that does the following:
+    - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
+    - Gets ProjectMeta object if not provided
+    - Downloads GT annotations for the specified image IDs
+    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+
+    :param api: Supervisely API object
+    :type api: Api
+    :param pred_anns: List of Annotation objects containing predictions
+    :type pred_anns: List[Annotation]
+    :param dataset_id: ID of the dataset containing the images
+    :type dataset_id: int
+    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
+    :type gt_image_ids: List[int]
+    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
+        ground truth box of the same class will be removed. None if no filtering is needed
+    :type iou: Optional[float]
+    :param meta: ProjectMeta object
+    :type meta: Optional[ProjectMeta]
+    :return: List of Annotation objects containing filtered predictions
+    :rtype: List[Annotation]
+
+    Notes:
+    ------
+    - Uses PyTorch and torchvision for IoU calculations when available; falls back to a Shapely-based CPU implementation otherwise
+    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    """
+    if isinstance(iou, float) and 0 < iou <= 1:
+        if meta is None:
+            ds = api.dataset.get_info_by_id(dataset_id)
+            meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
+        gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
+        gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
+        for i in range(0, len(pred_anns)):
+            before = len(pred_anns[i].labels)
+            with Timer() as timer:
+                pred_anns[i] = _filter_duplicated_predictions_from_ann(
+                    gt_anns[i], pred_anns[i], iou
+                )
+            after = len(pred_anns[i].labels)
+            logger.debug(
+                f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
+            )
+    return pred_anns
+
+
 def _get_log_extra_for_inference_request(
     inference_request_uuid, inference_request: Union[InferenceRequest, dict]
 ):
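Editor's note: a minimal usage sketch of the re-added helper; the ids are placeholders, and since the function is private this illustrates the data flow rather than a public API:

    import supervisely as sly

    api = sly.Api.from_env()
    dataset_id = 111  # hypothetical
    gt_image_ids = [201, 202]  # images that already have GT annotations
    pred_anns = []  # List[sly.Annotation] produced by a model, one per image id

    # Predictions overlapping existing GT objects of the same class and shape
    # with IoU >= 0.7 are dropped; the rest are kept as genuinely new objects.
    pred_anns = _exclude_duplicated_predictions(
        api, pred_anns, dataset_id, gt_image_ids, iou=0.7
    )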
@@ -4347,8 +4867,8 @@ def _get_log_extra_for_inference_request(
         "has_result": inference_request.final_result is not None,
         "pending_results": inference_request.pending_num(),
         "exception": inference_request.exception_json(),
-        "result": inference_request._final_result,
         "preparing_progress": progress,
+        "result": inference_request.final_result is not None,  # for backward compatibility
     }
     return log_extra
 
@@ -4428,7 +4948,7 @@ def get_gpu_count():
        gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
        return gpu_count
    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
-        logger.warn("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
+        logger.warning(f"Calling nvidia-smi caused an error: {exc}. Assuming there is no GPU.")
        return 0
 
 
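Editor's note: the GPU count is recovered by counting "GPU N:" headers in nvidia-smi output; a quick illustration with fabricated output:

    import re

    # Fabricated nvidia-smi -L style output, for illustration only.
    nvidia_smi_output = (
        "GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-xxxx)\n"
        "GPU 1: NVIDIA GeForce RTX 3090 (UUID: GPU-yyyy)\n"
    )
    print(len(re.findall(r"GPU \d+:", nvidia_smi_output)))  # 2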
@@ -4608,7 +5128,180 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suf
        img_tags = None
    if not any_label_updated:
        labels = None
-    ann = ann.clone(img_tags=TagCollection(img_tags))
+    ann = ann.clone(img_tags=img_tags)
+    return meta, ann, meta_changed
+
+
+def update_meta_and_ann_for_video_annotation(
+    meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
+):
+    """Update project meta and annotation to match each other.
+    If an obj class or tag meta from the annotation conflicts with the project meta,
+    add a suffix to the obj class or tag meta name.
+    Return a tuple of the updated project meta, the annotation, and a boolean flag
+    indicating whether the meta was changed.
+    """
+    obj_classes_suffixes = ["_nn"]
+    tag_meta_suffixes = ["_nn"]
+    if model_prediction_suffix is not None:
+        obj_classes_suffixes = [model_prediction_suffix]
+        tag_meta_suffixes = [model_prediction_suffix]
+        logger.debug(
+            f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
+        )
+    logger.debug("source meta", extra={"meta": meta.to_json()})
+    meta_changed = False
+
+    # meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
+    # if replaced_classes_in_meta:
+    #     meta_changed = True
+    #     logger.warning(
+    #         "Some classes names were fixed in project meta",
+    #         extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
+    #     )
+
+    new_objects: List[VideoObject] = []
+    new_figures: List[VideoFigure] = []
+    any_object_updated = False
+    for video_object in ann.objects:
+        this_object_figures = [
+            figure for figure in ann.figures if figure.video_object.key() == video_object.key()
+        ]
+        this_object_changed = False
+        original_obj_class_name = video_object.obj_class.name
+        suffix_found = False
+        for suffix in ["", *obj_classes_suffixes]:
+            obj_class = video_object.obj_class
+            obj_class_name = obj_class.name + suffix
+            if suffix:
+                obj_class = obj_class.clone(name=obj_class_name)
+                video_object = video_object.clone(obj_class=obj_class)
+                any_object_updated = True
+                this_object_changed = True
+            meta_obj_class = meta.get_obj_class(obj_class_name)
+            if meta_obj_class is None:
+                # obj class is not in meta, add it with suffix
+                meta = meta.add_obj_class(obj_class)
+                new_objects.append(video_object)
+                meta_changed = True
+                suffix_found = True
+                break
+            elif (
+                meta_obj_class.geometry_type.geometry_name()
+                == video_object.obj_class.geometry_type.geometry_name()
+            ):
+                # if object geometry is the same as in meta, use meta obj class
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+            elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
+                # if meta obj class is AnyGeometry, use it in object
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+        if not suffix_found:
+            # if no suffix found, raise error
+            raise ValueError(
+                f"Can't add obj class {original_obj_class_name} to project meta. "
+                "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
+                "Please check if model geometry type is compatible with existing obj classes."
+            )
+        elif this_object_changed:
+            this_object_figures = [
+                figure.clone(video_object=video_object) for figure in this_object_figures
+            ]
+        new_figures.extend(this_object_figures)
+    if any_object_updated:
+        frames_figures = {}
+        for figure in new_figures:
+            frames_figures.setdefault(figure.frame_index, []).append(figure)
+        new_frames = FrameCollection(
+            [
+                Frame(index=frame_index, figures=figures)
+                for frame_index, figures in frames_figures.items()
+            ]
+        )
+        ann = ann.clone(objects=new_objects, frames=new_frames)
+
+    # check if tag metas are in project meta
+    # if not, add them with suffix
+    ann_tag_metas: Dict[str, TagMeta] = {}
+    for video_object in ann.objects:
+        for tag in video_object.tags:
+            tag_name = tag.meta.name
+            if tag_name not in ann_tag_metas:
+                ann_tag_metas[tag_name] = tag.meta
+    for tag in ann.tags:
+        tag_name = tag.meta.name
+        if tag_name not in ann_tag_metas:
+            ann_tag_metas[tag_name] = tag.meta
+
+    changed_tag_metas = {}
+    for ann_tag_meta in ann_tag_metas.values():
+        meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
+        if meta_tag_meta is None:
+            meta = meta.add_tag_meta(ann_tag_meta)
+            meta_changed = True
+        elif not meta_tag_meta.is_compatible(ann_tag_meta):
+            suffix_found = False
+            for suffix in tag_meta_suffixes:
+                new_tag_meta_name = ann_tag_meta.name + suffix
+                meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
+                if meta_tag_meta is None:
+                    new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
+                    meta = meta.add_tag_meta(new_tag_meta)
+                    changed_tag_metas[ann_tag_meta.name] = new_tag_meta
+                    meta_changed = True
+                    suffix_found = True
+                    break
+                if meta_tag_meta.is_compatible(ann_tag_meta):
+                    changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
+                    suffix_found = True
+                    break
+            if not suffix_found:
+                raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
+
+    if changed_tag_metas:
+        objects = []
+        any_object_updated = False
+        for video_object in ann.objects:
+            any_tag_updated = False
+            object_tags = []
+            for tag in video_object.tags:
+                if tag.meta.name in changed_tag_metas:
+                    object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                    any_tag_updated = True
+                else:
+                    object_tags.append(tag)
+            if any_tag_updated:
+                video_object = video_object.clone(tags=TagCollection(object_tags))
+                any_object_updated = True
+            objects.append(video_object)
+
+        video_tags = []
+        any_tag_updated = False
+        for tag in ann.tags:
+            if tag.meta.name in changed_tag_metas:
+                video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                any_tag_updated = True
+            else:
+                video_tags.append(tag)
+        if any_tag_updated or any_object_updated:
+            if any_tag_updated:
+                video_tags = VideoTagCollection(video_tags)
+            else:
+                video_tags = None
+            if any_object_updated:
+                objects = VideoObjectCollection(objects)
+            else:
+                objects = None
+            ann = ann.clone(tags=video_tags, objects=objects)
+
    return meta, ann, meta_changed
 
 
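Editor's note: a sketch of the conflict-resolution rule implemented above: if the annotation brings a class whose name already exists in the project meta with an incompatible geometry, the object is re-pointed to a suffixed class (default suffix "_nn"). The setup below is hypothetical:

    import supervisely as sly

    # The project already defines "car" as a bitmap class.
    meta = sly.ProjectMeta(obj_classes=[sly.ObjClass("car", sly.Bitmap)])

    # A video annotation arriving with a rectangle "car" would conflict, so the
    # helper registers a "car_nn" rectangle class, rewrites the affected video
    # objects and figures to use it, and returns meta_changed=True.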
@@ -4722,7 +5415,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
            return data[key]
    return None
 
-def torch_load_safe(checkpoint_path: str, device:str = "cpu"):
+
+def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
    import torch  # pylint: disable=import-error
 
    # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0