supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (190)
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,7 @@ import subprocess
11
11
  import tempfile
12
12
  import threading
13
13
  import time
14
+ import uuid
14
15
  from collections import OrderedDict, defaultdict
15
16
  from concurrent.futures import ThreadPoolExecutor
16
17
  from dataclasses import asdict, dataclass
@@ -19,6 +20,7 @@ from pathlib import Path
19
20
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
20
21
  from urllib.request import urlopen
21
22
 
23
+ import _pickle
22
24
  import numpy as np
23
25
  import requests
24
26
  import uvicorn
@@ -34,7 +36,6 @@ import supervisely.io.env as sly_env
34
36
  import supervisely.io.fs as sly_fs
35
37
  import supervisely.io.json as sly_json
36
38
  import supervisely.nn.inference.gui as GUI
37
- from supervisely.nn.experiments import ExperimentInfo
38
39
  from supervisely import DatasetInfo, batched
39
40
  from supervisely._utils import (
40
41
  add_callback,
@@ -45,13 +46,14 @@ from supervisely._utils import (
45
46
  rand_str,
46
47
  )
47
48
  from supervisely.annotation.annotation import Annotation
48
- from supervisely.annotation.label import Label
49
+ from supervisely.annotation.label import Label, LabelingStatus
49
50
  from supervisely.annotation.obj_class import ObjClass
50
51
  from supervisely.annotation.tag_collection import TagCollection
51
52
  from supervisely.annotation.tag_meta import TagMeta, TagValueType
52
53
  from supervisely.api.api import Api, ApiField
53
54
  from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
54
55
  from supervisely.api.image_api import ImageInfo
56
+ from supervisely.api.video.video_api import VideoInfo
55
57
  from supervisely.app.content import get_data_dir
56
58
  from supervisely.app.fastapi.subapp import (
57
59
  Application,
@@ -67,15 +69,17 @@ from supervisely.decorators.inference import (
67
69
  process_images_batch_sliding_window,
68
70
  )
69
71
  from supervisely.geometry.any_geometry import AnyGeometry
72
+ from supervisely.geometry.geometry import Geometry
70
73
  from supervisely.imaging.color import get_predefined_colors
71
74
  from supervisely.io.fs import list_files
75
+ from supervisely.nn.experiments import ExperimentInfo
72
76
  from supervisely.nn.inference.cache import InferenceImageCache
73
77
  from supervisely.nn.inference.inference_request import (
74
78
  InferenceRequest,
75
79
  InferenceRequestsManager,
76
80
  )
77
81
  from supervisely.nn.inference.uploader import Uploader
78
- from supervisely.nn.model.model_api import Prediction
82
+ from supervisely.nn.model.model_api import ModelAPI, Prediction
79
83
  from supervisely.nn.prediction_dto import Prediction as PredictionDTO
80
84
  from supervisely.nn.utils import (
81
85
  CheckpointInfo,
@@ -93,7 +97,18 @@ from supervisely.project.project_meta import ProjectMeta
93
97
  from supervisely.sly_logger import logger
94
98
  from supervisely.task.progress import Progress
95
99
  from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
96
- from supervisely.nn.model.model_api import ModelAPI
100
+ from supervisely.video_annotation.frame import Frame
101
+ from supervisely.video_annotation.frame_collection import FrameCollection
102
+ from supervisely.video_annotation.video_annotation import VideoAnnotation
103
+ from supervisely.video_annotation.video_figure import VideoFigure
104
+ from supervisely.video_annotation.video_object import VideoObject
105
+ from supervisely.video_annotation.video_object_collection import VideoObjectCollection
106
+ from supervisely.video_annotation.video_tag_collection import VideoTagCollection
107
+ from supervisely.video_annotation.key_id_map import KeyIdMap
108
+ from supervisely.video_annotation.video_object_collection import (
109
+ VideoObject,
110
+ VideoObjectCollection,
111
+ )
97
112
 
98
113
  try:
99
114
  from typing import Literal
@@ -140,6 +155,7 @@ class Inference:
140
155
  """Default batch size for inference"""
141
156
  INFERENCE_SETTINGS: str = None
142
157
  """Path to file with custom inference settings"""
158
+ DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
143
159
 
144
160
  def __init__(
145
161
  self,
@@ -193,7 +209,6 @@ class Inference:
193
209
  self._task_id = None
194
210
  self._sliding_window_mode = sliding_window_mode
195
211
  self._autostart_delay_time = 5 * 60 # 5 min
196
- self._tracker = None
197
212
  self._hardware: str = None
198
213
  if custom_inference_settings is None:
199
214
  if self.INFERENCE_SETTINGS is not None:
@@ -383,7 +398,7 @@ class Inference:
383
398
  if m_name and m_name.lower() == model.lower():
384
399
  return m
385
400
  return None
386
-
401
+
387
402
  runtime = get_runtime(runtime)
388
403
  logger.debug(f"Runtime: {runtime}")
389
404
 
@@ -427,7 +442,7 @@ class Inference:
427
442
 
428
443
  device = "cuda" if torch.cuda.is_available() else "cpu"
429
444
  except Exception as e:
430
- logger.warn(
445
+ logger.warning(
431
446
  f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
432
447
  )
433
448
  device = "cpu"
@@ -863,13 +878,57 @@ class Inference:
863
878
  self.gui.download_progress.hide()
864
879
  return local_model_files
865
880
 
881
def _fallback_download_custom_model_pt(self, deploy_params: dict):
    """
    Download the PyTorch checkpoint from Team Files after a TensorRT engine failed to load.

    The remote file is ``<artifacts_dir>/checkpoints/<name><ext>``, where ``<name>`` is the
    stem of ``deploy_params["model_files"]["checkpoint"]`` and ``<ext>`` is the extension of
    the first entry of ``deploy_params["model_info"]["checkpoints"]``. The file is saved into
    ``self.model_dir`` under the same file name.

    :param deploy_params: Deploy parameters containing ``model_files["checkpoint"]``,
        ``model_info["artifacts_dir"]`` and ``model_info["checkpoints"]``.
    :return: Local path of the downloaded PyTorch checkpoint.
    :raises FileNotFoundError: If the checkpoint is not found in Team Files.
    """
    team_id = sly_env.team_id()

    checkpoint_name = sly_fs.get_file_name(deploy_params["model_files"]["checkpoint"])
    artifacts_dir = deploy_params["model_info"]["artifacts_dir"]
    checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")
    checkpoint_ext = sly_fs.get_file_ext(deploy_params["model_info"]["checkpoints"][0])

    pt_checkpoint_name = f"{checkpoint_name}{checkpoint_ext}"
    remote_checkpoint_path = os.path.join(checkpoints_dir, pt_checkpoint_name)
    local_checkpoint_path = os.path.join(self.model_dir, pt_checkpoint_name)

    file_info = self.api.file.get_info_by_path(team_id, remote_checkpoint_path)
    if file_info is None:
        # Without this check the failure surfaces as the opaque
        # "'NoneType' object has no attribute 'sizeb'".
        raise FileNotFoundError(
            f"PyTorch checkpoint '{remote_checkpoint_path}' was not found in Team Files "
            f"(team id: {team_id}); cannot fall back from TensorRT."
        )

    if self.gui is None:
        self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path)
        return local_checkpoint_path

    with self.gui.download_progress(
        message=f"Fallback. Downloading PyTorch checkpoint: '{pt_checkpoint_name}'",
        total=file_info.sizeb,
        unit="bytes",
        unit_scale=True,
    ) as download_pbar:
        self.gui.download_progress.show()
        try:
            self.api.file.download(
                team_id,
                remote_checkpoint_path,
                local_checkpoint_path,
                progress_cb=download_pbar.update,
            )
        finally:
            # Hide the progress widget even when the download fails.
            self.gui.download_progress.hide()

    return local_checkpoint_path
912
+
913
def _remove_exported_checkpoints(self, checkpoint_path: str):
    """
    Remove ONNX / TensorRT exports that sit next to a PyTorch checkpoint.

    For a checkpoint ``<dir>/<name><ext>``, the files ``<dir>/<name>.onnx`` and
    ``<dir>/<name>.engine`` are deleted if they exist. This forces a fresh export
    instead of reusing a stale one. Missing files are ignored. The checkpoint
    file itself is never removed.

    :param checkpoint_path: Path to the PyTorch checkpoint (e.g. ``.pt`` / ``.pth``).
    """
    # Swap only the trailing extension. str.replace(ext, ...) rewrote every
    # occurrence of the extension in the path (e.g. a ".pt" inside a directory
    # name). For an extension-less path, replace("", ...) inserted the suffix
    # between every character.
    base_path, _ = os.path.splitext(checkpoint_path)
    for exported_ext in (".onnx", ".engine"):
        exported_path = base_path + exported_ext
        if exported_path == checkpoint_path:
            # The given checkpoint is itself an export; do not delete it.
            continue
        try:
            os.remove(exported_path)
        except FileNotFoundError:
            pass
924
+
866
925
  def _download_custom_model(self, model_files: dict, log_progress: bool = True):
867
926
  """
868
927
  Downloads the custom model data.
869
928
  """
870
929
  team_id = sly_env.team_id()
871
930
  local_model_files = {}
872
-
931
+
873
932
  # Sort files to download 'checkpoint' first
874
933
  files_order = sorted(model_files.keys(), key=lambda x: (0 if x == "checkpoint" else 1, x))
875
934
  for file in files_order:
@@ -905,17 +964,23 @@ class Inference:
905
964
  if extracted_files:
906
965
  local_model_files[file] = file_path
907
966
  return local_model_files
967
+ except _pickle.UnpicklingError as e:
968
+ # TODO: raise error - checkpoint is corrupted
969
+ logger.warning(f"Couldn't load '{file_name}'. Checkpoint might be corrupted. Error: {repr(e)}")
970
+ logger.warning("Model files will be downloaded from Team Files")
971
+ local_model_files[file] = file_path
972
+ continue
908
973
  except Exception as e:
909
- logger.debug(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
910
- logger.debug("Model files will be downloaded from Team Files")
974
+ logger.warning(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
975
+ logger.warning("Model files will be downloaded from Team Files")
911
976
  local_model_files[file] = file_path
912
977
  continue
913
-
978
+
914
979
  local_model_files[file] = file_path
915
980
  if log_progress:
916
981
  self.gui.download_progress.hide()
917
982
  return local_model_files
918
-
983
+
919
984
  def _get_deploy_parameters_from_custom_checkpoint(self, checkpoint_path: str, device: str, runtime: str) -> dict:
920
985
  def _read_experiment_info(artifacts_dir: str) -> Optional[dict]:
921
986
  exp_path = os.path.join(artifacts_dir, "experiment_info.json")
@@ -976,8 +1041,7 @@ class Inference:
976
1041
  # --- LOCAL ---
977
1042
  try:
978
1043
  logger.debug("Reading state dict...")
979
- import torch # pylint: disable=import-error
980
- ckpt = torch.load(checkpoint_path, map_location="cpu")
1044
+ ckpt = torch_load_safe(checkpoint_path)
981
1045
  model_info = ckpt.get("model_info", {})
982
1046
  model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
983
1047
  model_files["checkpoint"] = checkpoint_path
@@ -1017,10 +1081,8 @@ class Inference:
1017
1081
  if file_ext not in (".pth", ".pt"):
1018
1082
  return extracted_files
1019
1083
 
1020
- import torch # pylint: disable=import-error
1021
1084
  logger.debug(f"Reading checkpoint: {checkpoint_path}")
1022
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
1023
-
1085
+ checkpoint = torch_load_safe(checkpoint_path)
1024
1086
  # 1. Extract additional model files embedded into checkpoint (if any)
1025
1087
  ckpt_files = checkpoint.get("model_files", None)
1026
1088
  if ckpt_files and isinstance(ckpt_files, dict):
@@ -1057,7 +1119,41 @@ class Inference:
1057
1119
  self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
1058
1120
  self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
1059
1121
  self._hardware = get_hardware_info(self.device)
1060
- self.load_model(**deploy_params)
1122
+
1123
+ model_files = deploy_params.get("model_files", None)
1124
+ if model_files is not None:
1125
+ checkpoint_path = deploy_params["model_files"]["checkpoint"]
1126
+ checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
1127
+ if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
1128
+ try:
1129
+ self.load_model(**deploy_params)
1130
+ except Exception as e:
1131
+ logger.warning(
1132
+ f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}"
1133
+ )
1134
+ checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
1135
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
1136
+ logger.info("Exporting PyTorch model to TensorRT...")
1137
+ self._remove_exported_checkpoints(checkpoint_path)
1138
+ checkpoint_path = self.export_tensorrt(deploy_params)
1139
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
1140
+ self.load_model(**deploy_params)
1141
+ if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
1142
+ if self.runtime == RuntimeType.ONNXRUNTIME:
1143
+ logger.info("Exporting PyTorch model to ONNX...")
1144
+ self._remove_exported_checkpoints(checkpoint_path)
1145
+ checkpoint_path = self.export_onnx(deploy_params)
1146
+ elif self.runtime == RuntimeType.TENSORRT:
1147
+ logger.info("Exporting PyTorch model to TensorRT...")
1148
+ self._remove_exported_checkpoints(checkpoint_path)
1149
+ checkpoint_path = self.export_tensorrt(deploy_params)
1150
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
1151
+ self.load_model(**deploy_params)
1152
+ else:
1153
+ self.load_model(**deploy_params)
1154
+ else:
1155
+ self.load_model(**deploy_params)
1156
+
1061
1157
  self._model_served = True
1062
1158
  self._deploy_params = deploy_params
1063
1159
  if self._task_id is not None and is_production():
@@ -1159,6 +1255,8 @@ class Inference:
1159
1255
  if model_source == ModelSource.CUSTOM:
1160
1256
  self._set_model_meta_custom_model(model_info)
1161
1257
  self._set_checkpoint_info_custom_model(deploy_params)
1258
+ elif model_source == ModelSource.PRETRAINED:
1259
+ self._set_checkpoint_info_pretrained(deploy_params)
1162
1260
 
1163
1261
  try:
1164
1262
  if is_production():
@@ -1232,6 +1330,19 @@ class Inference:
1232
1330
  model_source=ModelSource.CUSTOM,
1233
1331
  )
1234
1332
 
1333
def _set_checkpoint_info_pretrained(self, deploy_params: dict):
    """
    Populate ``self.checkpoint_info`` for a model deployed from a pretrained checkpoint.

    The checkpoint name is the base name of ``deploy_params["model_files"]["checkpoint"]``.
    The checkpoint URL is ``deploy_params["model_info"]["meta"]["model_files"]["checkpoint"]``.
    The architecture is ``self.FRAMEWORK_NAME``.

    :param deploy_params: Deploy parameters with ``model_files`` and ``model_info`` entries.
    """
    checkpoint_name = os.path.basename(deploy_params["model_files"]["checkpoint"])
    model_info = deploy_params["model_info"]
    model_name = _get_model_name(model_info)
    checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
    self.checkpoint_info = CheckpointInfo(
        checkpoint_name=checkpoint_name,
        model_name=model_name,
        architecture=self.FRAMEWORK_NAME,
        checkpoint_url=checkpoint_url,
        model_source=ModelSource.PRETRAINED,
    )
1345
+
1235
1346
  def shutdown_model(self):
1236
1347
  self._model_served = False
1237
1348
  self._model_frozen = False
@@ -1252,6 +1363,26 @@ class Inference:
1252
1363
  def get_classes(self) -> List[str]:
1253
1364
  return self.classes
1254
1365
 
1366
def _tracker_init(self, tracker: Optional[str], tracker_settings: Optional[dict]):
    """
    Create the tracker requested for video inference, if any.

    :param tracker: Tracker name. Only ``"botsort"`` is recognized. ``None`` disables tracking.
    :param tracker_settings: Settings passed to the tracker. ``"device"`` overrides
        ``self.device``. ``None`` is treated as an empty dict.
    :return: A ``BotSortTracker`` instance, or ``None`` when no tracker was requested,
        the model does not report ``tracking_on_videos_support``, or the name is unknown.
    """
    if tracker is None:
        # Nothing requested: skip the comparatively expensive get_info() call.
        return None

    info = self.get_info()
    if not info.get("tracking_on_videos_support", False):
        # The caller explicitly asked for tracking, so make the downgrade visible.
        logger.warning(
            f"Tracker '{tracker}' was requested, but tracking is not supported "
            "for this model. Tracking is disabled."
        )
        return None

    if tracker_settings is None:
        # state.get("tracker_settings", {}) still yields None when the key holds None.
        tracker_settings = {}

    if tracker == "botsort":
        from supervisely.nn.tracker import BotSortTracker

        device = tracker_settings.get("device", self.device)
        logger.debug(f"Initializing BotSort tracker with device: {device}")
        return BotSortTracker(settings=tracker_settings, device=device)

    logger.warning(f"Unknown tracking type: {tracker}. Tracking is disabled.")
    return None
1385
+
1255
1386
  def get_info(self) -> Dict[str, Any]:
1256
1387
  num_classes = None
1257
1388
  classes = None
@@ -1260,15 +1391,15 @@ class Inference:
1260
1391
  if classes is not None:
1261
1392
  num_classes = len(classes)
1262
1393
  except NotImplementedError:
1263
- logger.warn(f"get_classes() function not implemented for {type(self)} object.")
1394
+ logger.warning(f"get_classes() function not implemented for {type(self)} object.")
1264
1395
  except AttributeError:
1265
- logger.warn("Probably, get_classes() function not working without model deploy.")
1396
+ logger.warning("Probably, get_classes() function not working without model deploy.")
1266
1397
  except Exception as exc:
1267
- logger.warn("Unknown exception. Please, contact support")
1398
+ logger.warning("Unknown exception. Please, contact support")
1268
1399
  logger.exception(exc)
1269
1400
 
1270
1401
  if num_classes is None:
1271
- logger.warn(f"get_classes() function return {classes}; skip classes processing.")
1402
+ logger.warning(f"get_classes() function return {classes}; skip classes processing.")
1272
1403
 
1273
1404
  return {
1274
1405
  "app_name": get_name_from_env(default="Neural Network Serving"),
@@ -1277,15 +1408,51 @@ class Inference:
1277
1408
  "sliding_window_support": self.sliding_window_mode,
1278
1409
  "videos_support": True,
1279
1410
  "async_video_inference_support": True,
1280
- "tracking_on_videos_support": True,
1411
+ "tracking_on_videos_support": False,
1281
1412
  "async_image_inference_support": True,
1282
- "tracking_algorithms": ["bot", "deepsort"],
1413
+ "tracking_algorithms": ["botsort"],
1283
1414
  "batch_inference_support": self.is_batch_inference_supported(),
1284
1415
  "max_batch_size": self.max_batch_size,
1285
1416
  }
1286
1417
 
1287
1418
  # pylint: enable=method-hidden
1288
1419
 
1420
def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
    """
    Collect default parameters of every tracking algorithm listed by ``get_info()``.

    Internal keys (``device``, ``fps``) are stripped from each parameter dict.
    Algorithms that are not implemented, or whose defaults cannot be read, are skipped.

    Returns:
        Mapping of tracker name to its default params, e.g.
        {"botsort": {"track_high_thresh": 0.6, ...}}.
        Empty dict if the model does not report tracking support.
    """
    info = self.get_info()
    if not info.get("tracking_on_videos_support"):
        return {}

    hidden_keys = frozenset(("device", "fps"))
    collected = {}
    for name in info.get("tracking_algorithms", []):
        try:
            if name != "botsort":
                # Other trackers would be added as additional branches.
                logger.debug(f"Tracker '{name}' not implemented")
                continue
            from supervisely.nn.tracker import BotSortTracker

            collected[name] = BotSortTracker.get_default_params()
        except Exception as e:
            logger.warning(f"Failed to get params for '{name}': {e}")

    return {
        name: {key: value for key, value in params.items() if key not in hidden_keys}
        for name, params in collected.items()
    }
1455
+
1289
1456
  def get_human_readable_info(self, replace_none_with: Optional[str] = None):
1290
1457
  hr_info = {}
1291
1458
  info = self.get_info()
@@ -1401,8 +1568,12 @@ class Inference:
1401
1568
  # for example empty mask
1402
1569
  continue
1403
1570
  if isinstance(label, list):
1571
+ for lb in label:
1572
+ lb.status = LabelingStatus.AUTO
1404
1573
  labels.extend(label)
1405
1574
  continue
1575
+
1576
+ label.status = LabelingStatus.AUTO
1406
1577
  labels.append(label)
1407
1578
 
1408
1579
  # create annotation with correct image resolution
@@ -1447,7 +1618,7 @@ class Inference:
1447
1618
  if api is None:
1448
1619
  api = self.api
1449
1620
  return api
1450
-
1621
+
1451
1622
  def _inference_auto(
1452
1623
  self,
1453
1624
  source: List[Union[str, np.ndarray]],
@@ -1833,24 +2004,12 @@ class Inference:
1833
2004
  else:
1834
2005
  n_frames = frames_reader.frames_count()
1835
2006
 
1836
- if tracking == "bot":
1837
- from supervisely.nn.tracker import BoTTracker
1838
-
1839
- tracker = BoTTracker(state)
1840
- elif tracking == "deepsort":
1841
- from supervisely.nn.tracker import DeepSortTracker
1842
-
1843
- tracker = DeepSortTracker(state)
1844
- else:
1845
- if tracking is not None:
1846
- logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
1847
- tracker = None
2007
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
1848
2008
 
1849
2009
  progress_total = (n_frames + step - 1) // step
1850
2010
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
1851
2011
 
1852
2012
  results = []
1853
- tracks_data = {}
1854
2013
  for batch in batched(
1855
2014
  range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
1856
2015
  batch_size,
@@ -1870,28 +2029,30 @@ class Inference:
1870
2029
  source=frames,
1871
2030
  settings=inference_settings,
1872
2031
  )
2032
+
2033
+ if inference_request.tracker is not None:
2034
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
2035
+
1873
2036
  predictions = [
1874
2037
  Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
1875
2038
  for ann, frame_index in zip(anns, batch)
1876
2039
  ]
2040
+
1877
2041
  for pred, this_slides_data in zip(predictions, slides_data):
1878
2042
  pred.extra_data["slides_data"] = this_slides_data
1879
2043
  batch_results = self._format_output(predictions)
1880
- if tracker is not None:
1881
- for frame_index, frame, ann in zip(batch, frames, anns):
1882
- tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
2044
+
1883
2045
  inference_request.add_results(batch_results)
1884
2046
  inference_request.done(len(batch_results))
1885
2047
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
1886
2048
  video_ann_json = None
1887
- if tracker is not None:
2049
+ if inference_request.tracker is not None:
1888
2050
  inference_request.set_stage("Postprocess...", 0, 1)
1889
- video_ann_json = tracker.get_annotation(
1890
- tracks_data, (video_height, video_witdth), n_frames
1891
- ).to_json()
2051
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
1892
2052
  inference_request.done()
1893
2053
  result = {"ann": results, "video_ann": video_ann_json}
1894
2054
  inference_request.final_result = result.copy()
2055
+ return video_ann_json
1895
2056
 
1896
2057
  def _inference_image_ids(
1897
2058
  self,
@@ -1915,10 +2076,11 @@ class Inference:
1915
2076
  raise ValueError("Image ids are not provided")
1916
2077
  if not isinstance(image_ids, list):
1917
2078
  image_ids = [image_ids]
2079
+ model_prediction_suffix = state.get("model_prediction_suffix", None)
1918
2080
  upload_mode = state.get("upload_mode", None)
1919
2081
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
1920
2082
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
1921
- iou_merge_threshold = 0.7
2083
+ iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD # TODO: change to 0.9
1922
2084
 
1923
2085
  images_infos = api.image.get_info_by_id_batch(image_ids)
1924
2086
  images_infos_dict = {im_info.id: im_info for im_info in images_infos}
@@ -1979,6 +2141,7 @@ class Inference:
1979
2141
  progress_cb=inference_request.done,
1980
2142
  iou_merge_threshold=iou_merge_threshold,
1981
2143
  inference_request=inference_request,
2144
+ model_prediction_suffix=model_prediction_suffix,
1982
2145
  )
1983
2146
 
1984
2147
  _add_results_to_request = partial(
@@ -1994,8 +2157,8 @@ class Inference:
1994
2157
  with Uploader(upload_f, logger=logger) as uploader:
1995
2158
  for image_ids_batch in batched(image_ids, batch_size=batch_size):
1996
2159
  if uploader.has_exception():
1997
- exception = uploader.exception()
1998
- raise RuntimeError(f"Error in upload loop: {exception}") from exception
2160
+ exception = uploader.exception
2161
+ raise exception
1999
2162
  if inference_request.is_stopped():
2000
2163
  logger.debug(
2001
2164
  f"Cancelling inference project...",
@@ -2039,7 +2202,7 @@ class Inference:
2039
2202
  video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
2040
2203
  if video_id is None:
2041
2204
  raise ValueError("Video id is not provided")
2042
- video_info = api.video.get_info_by_id(video_id)
2205
+ video_info = api.video.get_info_by_id(video_id, force_metadata_for_links=True)
2043
2206
  start_frame_index = get_value_for_keys(
2044
2207
  state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
2045
2208
  )
@@ -2069,18 +2232,8 @@ class Inference:
2069
2232
  else:
2070
2233
  n_frames = video_info.frames_count
2071
2234
 
2072
- if tracking == "bot":
2073
- from supervisely.nn.tracker import BoTTracker
2235
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2074
2236
 
2075
- tracker = BoTTracker(state)
2076
- elif tracking == "deepsort":
2077
- from supervisely.nn.tracker import DeepSortTracker
2078
-
2079
- tracker = DeepSortTracker(state)
2080
- else:
2081
- if tracking is not None:
2082
- logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
2083
- tracker = None
2084
2237
  logger.debug(
2085
2238
  f"Video info:",
2086
2239
  extra=dict(
@@ -2097,7 +2250,6 @@ class Inference:
2097
2250
  progress_total = (n_frames + step - 1) // step
2098
2251
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
2099
2252
 
2100
- tracks_data = {}
2101
2253
  for batch in batched(
2102
2254
  range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
2103
2255
  batch_size,
@@ -2116,6 +2268,10 @@ class Inference:
2116
2268
  source=frames,
2117
2269
  settings=inference_settings,
2118
2270
  )
2271
+
2272
+ if inference_request.tracker is not None:
2273
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
2274
+
2119
2275
  predictions = [
2120
2276
  Prediction(
2121
2277
  ann,
@@ -2130,20 +2286,173 @@ class Inference:
2130
2286
  for pred, this_slides_data in zip(predictions, slides_data):
2131
2287
  pred.extra_data["slides_data"] = this_slides_data
2132
2288
  batch_results = self._format_output(predictions)
2133
- if tracker is not None:
2134
- for frame_index, frame, ann in zip(batch, frames, anns):
2135
- tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
2289
+
2136
2290
  inference_request.add_results(batch_results)
2137
2291
  inference_request.done(len(batch_results))
2138
2292
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
2139
2293
  video_ann_json = None
2140
- if tracker is not None:
2294
+ if inference_request.tracker is not None:
2141
2295
  inference_request.set_stage("Postprocess...", 0, 1)
2142
- video_ann_json = tracker.get_annotation(
2143
- tracks_data, (video_info.frame_height, video_info.frame_width), n_frames
2144
- ).to_json()
2296
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
2145
2297
  inference_request.done()
2146
2298
  inference_request.final_result = {"video_ann": video_ann_json}
2299
+ return video_ann_json
2300
+
2301
+ def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
2302
+ logger.debug("Inferring video_id...", extra={"state": state})
2303
+ inference_settings = self._get_inference_settings(state)
2304
+ logger.debug(f"Inference settings:", extra=inference_settings)
2305
+ batch_size = self._get_batch_size_from_state(state)
2306
+ video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
2307
+ if video_id is None:
2308
+ raise ValueError("Video id is not provided")
2309
+ video_info = api.video.get_info_by_id(video_id)
2310
+ start_frame_index = get_value_for_keys(
2311
+ state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
2312
+ )
2313
+ if start_frame_index is None:
2314
+ start_frame_index = 0
2315
+ step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
2316
+ if step is None:
2317
+ step = 1
2318
+ end_frame_index = get_value_for_keys(
2319
+ state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
2320
+ )
2321
+ duration = state.get("duration", None)
2322
+ frames_count = get_value_for_keys(
2323
+ state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
2324
+ )
2325
+ tracking = state.get("tracker", None)
2326
+ direction = state.get("direction", "forward")
2327
+ direction = 1 if direction == "forward" else -1
2328
+ track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
2329
+
2330
+ if frames_count is not None:
2331
+ n_frames = frames_count
2332
+ elif end_frame_index is not None:
2333
+ n_frames = end_frame_index - start_frame_index
2334
+ elif duration is not None:
2335
+ fps = video_info.frames_count / video_info.duration
2336
+ n_frames = int(duration * fps)
2337
+ else:
2338
+ n_frames = video_info.frames_count
2339
+
2340
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2341
+
2342
+ logger.debug(
2343
+ f"Video info:",
2344
+ extra=dict(
2345
+ w=video_info.frame_width,
2346
+ h=video_info.frame_height,
2347
+ start_frame_index=start_frame_index,
2348
+ n_frames=n_frames,
2349
+ ),
2350
+ )
2351
+
2352
+ # start downloading video in background
2353
+ self.cache.run_cache_task_manually(api, None, video_id=video_id)
2354
+
2355
+ progress_total = (n_frames + step - 1) // step
2356
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
2357
+
2358
+ _upload_f = partial(
2359
+ self.upload_predictions_to_video,
2360
+ api=api,
2361
+ video_info=video_info,
2362
+ track_id=track_id,
2363
+ context=inference_request.context,
2364
+ progress_cb=inference_request.done,
2365
+ inference_request=inference_request,
2366
+ )
2367
+
2368
+ _range = (start_frame_index, start_frame_index + direction * n_frames)
2369
+ if _range[0] > _range[1]:
2370
+ _range = (_range[1], _range[0])
2371
+
2372
+ def _notify_f(predictions: List[Prediction]):
2373
+ logger.debug(
2374
+ "Notifying tracking progress...",
2375
+ extra={
2376
+ "track_id": track_id,
2377
+ "range": _range,
2378
+ "current": inference_request.progress.current,
2379
+ "total": inference_request.progress.total,
2380
+ },
2381
+ )
2382
+ stopped = self.api.video.notify_progress(
2383
+ track_id=track_id,
2384
+ video_id=video_info.id,
2385
+ frame_start=_range[0],
2386
+ frame_end=_range[1],
2387
+ current=inference_request.progress.current,
2388
+ total=inference_request.progress.total,
2389
+ )
2390
+ if stopped:
2391
+ inference_request.stop()
2392
+ logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
2393
+
2394
+ def _exception_handler(e: Exception):
2395
+ self.api.video.notify_tracking_error(
2396
+ track_id=track_id,
2397
+ error=str(type(e)),
2398
+ message=str(e),
2399
+ )
2400
+ raise e
2401
+
2402
+ with Uploader(
2403
+ upload_f=_upload_f,
2404
+ notify_f=_notify_f,
2405
+ exception_handler=_exception_handler,
2406
+ logger=logger,
2407
+ ) as uploader:
2408
+ for batch in batched(
2409
+ range(
2410
+ start_frame_index, start_frame_index + direction * n_frames, direction * step
2411
+ ),
2412
+ batch_size,
2413
+ ):
2414
+ if inference_request.is_stopped():
2415
+ logger.debug(
2416
+ f"Cancelling inference video...",
2417
+ extra={"inference_request_uuid": inference_request.uuid},
2418
+ )
2419
+ break
2420
+ logger.debug(
2421
+ f"Inferring frames {batch[0]}-{batch[-1]}:",
2422
+ )
2423
+ frames = self.cache.download_frames(
2424
+ api, video_info.id, batch, redownload_video=True
2425
+ )
2426
+ anns, slides_data = self._inference_auto(
2427
+ source=frames,
2428
+ settings=inference_settings,
2429
+ )
2430
+
2431
+ if inference_request.tracker is not None:
2432
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
2433
+
2434
+ predictions = [
2435
+ Prediction(
2436
+ ann,
2437
+ model_meta=self.model_meta,
2438
+ frame_index=frame_index,
2439
+ video_id=video_info.id,
2440
+ dataset_id=video_info.dataset_id,
2441
+ project_id=video_info.project_id,
2442
+ )
2443
+ for ann, frame_index in zip(anns, batch)
2444
+ ]
2445
+ for pred, this_slides_data in zip(predictions, slides_data):
2446
+ pred.extra_data["slides_data"] = this_slides_data
2447
+ uploader.put(predictions)
2448
+ video_ann_json = None
2449
+ if inference_request.tracker is not None:
2450
+ inference_request.set_stage("Postprocess...", 0, 1)
2451
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
2452
+ inference_request.done()
2453
+ inference_request.final_result = {"video_ann": video_ann_json}
2454
+ return video_ann_json
2455
+
2147
2456
 
2148
2457
  def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
2149
2458
  """Inference project images.
@@ -2161,10 +2470,12 @@ class Inference:
2161
2470
  project_info = api.project.get_info_by_id(project_id)
2162
2471
  if project_info.type != str(ProjectType.IMAGES):
2163
2472
  raise ValueError("Only images projects are supported.")
2473
+
2474
+ model_prediction_suffix = state.get("model_prediction_suffix", None)
2164
2475
  upload_mode = state.get("upload_mode", None)
2165
2476
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
2166
2477
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
2167
- iou_merge_threshold = 0.7
2478
+ iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
2168
2479
  cache_project_on_model = state.get("cache_project_on_model", False)
2169
2480
 
2170
2481
  project_info = api.project.get_info_by_id(project_id)
@@ -2235,6 +2546,7 @@ class Inference:
2235
2546
  progress_cb=inference_request.done,
2236
2547
  iou_merge_threshold=iou_merge_threshold,
2237
2548
  inference_request=inference_request,
2549
+ model_prediction_suffix=model_prediction_suffix,
2238
2550
  )
2239
2551
 
2240
2552
  _add_results_to_request = partial(
@@ -2260,7 +2572,7 @@ class Inference:
2260
2572
  return
2261
2573
  if uploader.has_exception():
2262
2574
  exception = uploader.exception
2263
- raise RuntimeError(f"Error in upload loop: {exception}") from exception
2575
+ raise exception
2264
2576
  if cache_project_on_model:
2265
2577
  images_paths, _ = zip(
2266
2578
  *read_from_cached_project(
@@ -2389,7 +2701,7 @@ class Inference:
2389
2701
  return
2390
2702
  if uploader.has_exception():
2391
2703
  exception = uploader.exception
2392
- raise RuntimeError(f"Error in upload loop: {exception}") from exception
2704
+ raise exception
2393
2705
  if i == num_warmup:
2394
2706
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, num_iterations)
2395
2707
 
@@ -2454,6 +2766,7 @@ class Inference:
2454
2766
  # raise DialogWindowError(title="Call undeployed model.", description=msg)
2455
2767
  raise RuntimeError(msg)
2456
2768
  return func(*args, **kwargs)
2769
+
2457
2770
  return wrapper
2458
2771
 
2459
2772
  def _freeze_model(self):
@@ -2498,7 +2811,6 @@ class Inference:
2498
2811
  timer.daemon = True
2499
2812
  timer.start()
2500
2813
  self._freeze_timer = timer
2501
- logger.debug("Model will be frozen in %s seconds due to inactivity.", self._inactivity_timeout)
2502
2814
 
2503
2815
  def _set_served_callback(self):
2504
2816
  self._model_served = True
@@ -2605,11 +2917,16 @@ class Inference:
2605
2917
  progress_cb=None,
2606
2918
  iou_merge_threshold: float = None,
2607
2919
  inference_request: InferenceRequest = None,
2920
+ model_prediction_suffix: str = None,
2608
2921
  ):
2609
2922
  ds_predictions: Dict[int, List[Prediction]] = defaultdict(list)
2610
2923
  for prediction in predictions:
2611
2924
  ds_predictions[prediction.dataset_id].append(prediction)
2612
2925
 
2926
+ def update_labeling_status(ann: Annotation) -> Annotation:
2927
+ for label in ann.labels:
2928
+ label.status = LabelingStatus.AUTO
2929
+
2613
2930
  def _new_name(image_info: ImageInfo):
2614
2931
  name = Path(image_info.name)
2615
2932
  stem = name.stem
@@ -2642,10 +2959,10 @@ class Inference:
2642
2959
  context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
2643
2960
  return created_dataset.id
2644
2961
 
2645
- created_names = []
2646
2962
  if context is None:
2647
2963
  context = {}
2648
2964
  for dataset_id, preds in ds_predictions.items():
2965
+ created_names = set()
2649
2966
  if dst_project_id is not None:
2650
2967
  # upload to the destination project
2651
2968
  dst_dataset_id = _get_or_create_dataset(
@@ -2666,7 +2983,9 @@ class Inference:
2666
2983
  meta_changed = False
2667
2984
  for pred in preds:
2668
2985
  ann = pred.annotation
2669
- project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
2986
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
2987
+ project_meta, ann, model_prediction_suffix
2988
+ )
2670
2989
  meta_changed = meta_changed or meta_changed_
2671
2990
  pred.annotation = ann
2672
2991
  prediction.model_meta = project_meta
@@ -2683,8 +3002,15 @@ class Inference:
2683
3002
  iou=iou_merge_threshold,
2684
3003
  meta=project_meta,
2685
3004
  )
3005
+
3006
+ # Update labeling status of new predictions before upload
3007
+ anns_with_nn_flags = []
2686
3008
  for pred, ann in zip(preds, anns):
3009
+ update_labeling_status(ann)
2687
3010
  pred.annotation = ann
3011
+ anns_with_nn_flags.append(ann)
3012
+
3013
+ anns = anns_with_nn_flags
2688
3014
 
2689
3015
  context.setdefault("image_info", {})
2690
3016
  missing = [
@@ -2712,7 +3038,7 @@ class Inference:
2712
3038
  with_annotations=False,
2713
3039
  save_source_date=False,
2714
3040
  )
2715
- created_names.extend([image_info.name for image_info in dst_image_infos])
3041
+ created_names.update([image_info.name for image_info in dst_image_infos])
2716
3042
  api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
2717
3043
  else:
2718
3044
  # upload to the source dataset
@@ -2730,7 +3056,9 @@ class Inference:
2730
3056
  meta_changed = False
2731
3057
  for pred in preds:
2732
3058
  ann = pred.annotation
2733
- project_meta, ann, meta_changed_ = update_meta_and_ann(project_meta, ann)
3059
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
3060
+ project_meta, ann, model_prediction_suffix
3061
+ )
2734
3062
  meta_changed = meta_changed or meta_changed_
2735
3063
  pred.annotation = ann
2736
3064
  prediction.model_meta = project_meta
@@ -2747,7 +3075,10 @@ class Inference:
2747
3075
  iou=iou_merge_threshold,
2748
3076
  meta=project_meta,
2749
3077
  )
3078
+
3079
+ # Update labeling status of predicted labels before optional merge
2750
3080
  for pred, ann in zip(preds, anns):
3081
+ update_labeling_status(ann)
2751
3082
  pred.annotation = ann
2752
3083
 
2753
3084
  if upload_mode in ["iou_merge", "append"]:
@@ -2789,6 +3120,83 @@ class Inference:
2789
3120
  inference_request.add_results(results)
2790
3121
  inference_request.done(len(results))
2791
3122
 
3123
+ def upload_predictions_to_video(
3124
+ self,
3125
+ predictions: List[Prediction],
3126
+ api: Api,
3127
+ video_info: VideoInfo,
3128
+ track_id: str,
3129
+ context: Dict,
3130
+ progress_cb=None,
3131
+ inference_request: InferenceRequest = None,
3132
+ ):
3133
+ key_id_map = KeyIdMap()
3134
+ project_meta = context.get("project_meta", None)
3135
+ if project_meta is None:
3136
+ project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
3137
+ context["project_meta"] = project_meta
3138
+ meta_changed = False
3139
+ for prediction in predictions:
3140
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
3141
+ project_meta, prediction.annotation, None
3142
+ )
3143
+ prediction.annotation = ann
3144
+ meta_changed = meta_changed or meta_changed_
3145
+ if meta_changed:
3146
+ project_meta = api.project.update_meta(video_info.project_id, project_meta)
3147
+ context["project_meta"] = project_meta
3148
+
3149
+ figure_data_by_object_id = defaultdict(list)
3150
+
3151
+ tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
3152
+ new_tracks: Dict[int, VideoObject] = {}
3153
+ for prediction in predictions:
3154
+ annotation = prediction.annotation
3155
+ tracks = annotation.custom_data
3156
+ for track, label in zip(tracks, annotation.labels):
3157
+ if track not in tracks_to_object_ids and track not in new_tracks:
3158
+ video_object = VideoObject(obj_class=label.obj_class)
3159
+ new_tracks[track] = video_object
3160
+ if new_tracks:
3161
+ tracks, video_objects = zip(*new_tracks.items())
3162
+ added_object_ids = api.video.object.append_bulk(
3163
+ video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
3164
+ )
3165
+ for track, object_id in zip(tracks, added_object_ids):
3166
+ tracks_to_object_ids[track] = object_id
3167
+ for prediction in predictions:
3168
+ annotation = prediction.annotation
3169
+ tracks = annotation.custom_data
3170
+ for track, label in zip(tracks, annotation.labels):
3171
+ object_id = tracks_to_object_ids[track]
3172
+ figure_data_by_object_id[object_id].append(
3173
+ {
3174
+ ApiField.OBJECT_ID: object_id,
3175
+ ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
3176
+ ApiField.GEOMETRY: label.geometry.to_json(),
3177
+ ApiField.META: {ApiField.FRAME: prediction.frame_index},
3178
+ ApiField.TRACK_ID: track_id,
3179
+ }
3180
+ )
3181
+
3182
+ for object_id, figures_data in figure_data_by_object_id.items():
3183
+ figures_keys = [uuid.uuid4() for _ in figures_data]
3184
+ api.video.figure._append_bulk(
3185
+ entity_id=video_info.id,
3186
+ figures_json=figures_data,
3187
+ figures_keys=figures_keys,
3188
+ key_id_map=key_id_map,
3189
+ )
3190
+ logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
3191
+ if progress_cb:
3192
+ progress_cb(len(predictions))
3193
+ if inference_request is not None:
3194
+ results = self._format_output(predictions)
3195
+ for result in results:
3196
+ result["annotation"] = None
3197
+ result["data"] = None
3198
+ inference_request.add_results(results)
3199
+
2792
3200
  def serve(self):
2793
3201
  if not self._use_gui and not self._is_cli_deploy:
2794
3202
  Progress("Deploying model ...", 1)
@@ -2812,12 +3220,12 @@ class Inference:
2812
3220
  # Predict and shutdown
2813
3221
  if self._args.mode == "predict":
2814
3222
  if any(
2815
- [
2816
- self._args.input,
2817
- self._args.project_id,
2818
- self._args.dataset_id,
2819
- self._args.image_id,
2820
- ]
3223
+ [
3224
+ self._args.input,
3225
+ self._args.project_id,
3226
+ self._args.dataset_id,
3227
+ self._args.image_id,
3228
+ ]
2821
3229
  ):
2822
3230
  self._parse_inference_settings_from_args()
2823
3231
  self._inference_by_cli_deploy_args()
@@ -2898,6 +3306,11 @@ class Inference:
2898
3306
  def get_session_info(response: Response):
2899
3307
  return self.get_info()
2900
3308
 
3309
+ @server.post("/get_tracking_settings")
3310
+ @self._check_serve_before_call
3311
+ def get_tracking_settings(response: Response):
3312
+ return self.get_tracking_settings()
3313
+
2901
3314
  @server.post("/get_custom_inference_settings")
2902
3315
  def get_custom_inference_settings():
2903
3316
  return {"settings": self.custom_inference_settings}
@@ -3181,6 +3594,22 @@ class Inference:
3181
3594
  "inference_request_uuid": inference_request.uuid,
3182
3595
  }
3183
3596
 
3597
+ @server.post("/tracking_by_detection")
3598
+ def tracking_by_detection(response: Response, request: Request):
3599
+ state = request.state.state
3600
+ context = request.state.context
3601
+ state.update(context)
3602
+ if state.get("tracker") is None:
3603
+ state["tracker"] = "botsort"
3604
+
3605
+ logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
3606
+ self.validate_inference_state(state)
3607
+ api = self.api_from_request(request)
3608
+ inference_request, future = self.inference_requests_manager.schedule_task(
3609
+ self._tracking_by_detection, api, state
3610
+ )
3611
+ return {"message": "Track task started."}
3612
+
3184
3613
  @server.post("/inference_project_id_async")
3185
3614
  def inference_project_id_async(response: Response, request: Request):
3186
3615
  state = request.state.state
@@ -3244,10 +3673,7 @@ class Inference:
3244
3673
  data = {**inference_request.to_json(), **log_extra}
3245
3674
  if inference_request.stage != InferenceRequest.Stage.INFERENCE:
3246
3675
  data["progress"] = {"current": 0, "total": 1}
3247
- logger.debug(
3248
- f"Sending inference progress with uuid:",
3249
- extra=data,
3250
- )
3676
+ logger.debug(f"Sending inference progress with uuid:", extra=data)
3251
3677
  return data
3252
3678
 
3253
3679
  @server.post(f"/pop_inference_results")
@@ -3671,6 +4097,7 @@ class Inference:
3671
4097
 
3672
4098
  def _parse_inference_settings_from_args(self):
3673
4099
  logger.debug("Parsing inference settings from args")
4100
+
3674
4101
  def parse_value(value: str):
3675
4102
  if value.lower() in ("true", "false"):
3676
4103
  return value.lower() == "true"
@@ -3797,8 +4224,7 @@ class Inference:
3797
4224
  try:
3798
4225
  # Read data from checkpoint
3799
4226
  logger.debug(f"Reading data from checkpoint: {checkpoint_path}")
3800
- import torch # pylint: disable=import-error
3801
- checkpoint = torch.load(checkpoint_path)
4227
+ checkpoint = torch_load_safe(checkpoint_path)
3802
4228
  model_info = checkpoint["model_info"]
3803
4229
  model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
3804
4230
  model_meta = os.path.join(self.model_dir, "model_meta.json")
@@ -4028,6 +4454,7 @@ class Inference:
4028
4454
  draw: bool = False,
4029
4455
  ):
4030
4456
  logger.info(f"Predicting Local Data: {input_path}")
4457
+
4031
4458
  def postprocess_image(image_path: str, ann: Annotation, pred_dir: str = None):
4032
4459
  image_name = sly_fs.get_file_name_with_ext(image_path)
4033
4460
  if pred_dir is not None:
@@ -4103,6 +4530,20 @@ class Inference:
4103
4530
  self._args.draw,
4104
4531
  )
4105
4532
 
4533
+ def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
4534
+ updated_anns = []
4535
+ for frame, ann in zip(frames, anns):
4536
+ matches = tracker.update(frame, ann)
4537
+ track_ids = [match["track_id"] for match in matches]
4538
+ tracked_labels = [match["label"] for match in matches]
4539
+
4540
+ filtered_annotation = ann.clone(
4541
+ labels=tracked_labels,
4542
+ custom_data=track_ids
4543
+ )
4544
+ updated_anns.append(filtered_annotation)
4545
+ return updated_anns
4546
+
4106
4547
  def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
4107
4548
  if model_source == ModelSource.PRETRAINED:
4108
4549
  checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
@@ -4136,13 +4577,14 @@ class Inference:
4136
4577
 
4137
4578
  task_id = experiment_info.task_id
4138
4579
  self.gui.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
4139
- self.gui.experiment_selector.set_by_task_id(task_id)
4580
+ self.gui.experiment_selector.set_selected_row_by_task_id(task_id)
4140
4581
 
4141
4582
  best_ckpt = experiment_info.best_checkpoint
4142
4583
  if best_ckpt:
4143
- row = self.gui.experiment_selector.get_by_task_id(task_id)
4584
+ row = self.gui.experiment_selector.get_selected_row_by_task_id(task_id)
4144
4585
  if row is not None:
4145
4586
  row.set_selected_checkpoint_by_name(best_ckpt)
4587
+
4146
4588
  except Exception as e:
4147
4589
  logger.warning(f"Failed to set checkpoint from experiment info: {repr(e)}")
4148
4590
 
@@ -4151,61 +4593,78 @@ class Inference:
4151
4593
  return
4152
4594
  self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
4153
4595
 
4154
- def _exclude_duplicated_predictions(
4155
- api: Api,
4156
- pred_anns: List[Annotation],
4157
- dataset_id: int,
4158
- gt_image_ids: List[int],
4159
- iou: float = None,
4160
- meta: Optional[ProjectMeta] = None,
4596
+ def export_onnx(self, deploy_params: dict):
4597
+ raise NotImplementedError("Have to be implemented in child class after inheritance")
4598
+
4599
+ def export_tensorrt(self, deploy_params: dict):
4600
+ raise NotImplementedError("Have to be implemented in child class after inheritance")
4601
+
4602
+
4603
+ def _filter_duplicated_predictions_from_ann_cpu(
4604
+ gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
4161
4605
  ):
4162
4606
  """
4163
- Filter out predictions that significantly overlap with ground truth (GT) objects.
4164
-
4165
- This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
4166
- - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
4167
- - Gets ProjectMeta object if not provided
4168
- - Downloads GT annotations for the specified image IDs
4169
- - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
4607
+ Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
4608
+ Uses Shapely for geometric operations.
4170
4609
 
4171
- :param api: Supervisely API object
4172
- :type api: Api
4173
- :param pred_anns: List of Annotation objects containing predictions
4174
- :type pred_anns: List[Annotation]
4175
- :param dataset_id: ID of the dataset containing the images
4176
- :type dataset_id: int
4177
- :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
4178
- :type gt_image_ids: List[int]
4179
- :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
4180
- ground truth box of the same class will be removed. None if no filtering is needed
4181
- :type iou: Optional[float]
4182
- :param meta: ProjectMeta object
4183
- :type meta: Optional[ProjectMeta]
4184
- :return: List of Annotation objects containing filtered predictions
4185
- :rtype: List[Annotation]
4610
+ Args:
4611
+ pred_ann: Predicted annotation object
4612
+ gt_ann: Ground truth annotation object
4613
+ iou_threshold: IoU threshold for filtering
4186
4614
 
4187
- Notes:
4188
- ------
4189
- - Requires PyTorch and torchvision for IoU calculations
4190
- - This method is useful for identifying new objects that aren't already annotated in the ground truth
4615
+ Returns:
4616
+ New annotation with filtered labels
4191
4617
  """
4192
- if isinstance(iou, float) and 0 < iou <= 1:
4193
- if meta is None:
4194
- ds = api.dataset.get_info_by_id(dataset_id)
4195
- meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
4196
- gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
4197
- gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
4198
- for i in range(0, len(pred_anns)):
4199
- before = len(pred_anns[i].labels)
4200
- with Timer() as timer:
4201
- pred_anns[i] = _filter_duplicated_predictions_from_ann(
4202
- gt_anns[i], pred_anns[i], iou
4203
- )
4204
- after = len(pred_anns[i].labels)
4205
- logger.debug(
4206
- f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
4207
- )
4208
- return pred_anns
4618
+ if not iou_threshold:
4619
+ return pred_ann
4620
+
4621
+ from shapely.geometry import box
4622
+
4623
+ def calculate_iou(geom1: Geometry, geom2: Geometry):
4624
+ """Calculate IoU between two geometries using Shapely."""
4625
+ bbox1 = geom1.to_bbox()
4626
+ bbox2 = geom2.to_bbox()
4627
+
4628
+ box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
4629
+ box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
4630
+
4631
+ intersection = box1.intersection(box2).area
4632
+ union = box1.union(box2).area
4633
+
4634
+ return intersection / union if union > 0 else 0.0
4635
+
4636
+ new_labels = []
4637
+ pred_cls_bboxes = defaultdict(list)
4638
+ for label in pred_ann.labels:
4639
+ name_shape = (label.obj_class.name, label.geometry.name())
4640
+ pred_cls_bboxes[name_shape].append(label)
4641
+
4642
+ gt_cls_bboxes = defaultdict(list)
4643
+ for label in gt_ann.labels:
4644
+ name_shape = (label.obj_class.name, label.geometry.name())
4645
+ if name_shape not in pred_cls_bboxes:
4646
+ continue
4647
+ gt_cls_bboxes[name_shape].append(label)
4648
+
4649
+ for name_shape, pred in pred_cls_bboxes.items():
4650
+ gt = gt_cls_bboxes[name_shape]
4651
+ if len(gt) == 0:
4652
+ new_labels.extend(pred)
4653
+ continue
4654
+
4655
+ for pred_label in pred:
4656
+ # Check if this prediction has IoU < threshold with ALL GT boxes
4657
+ keep = True
4658
+ for gt_label in gt:
4659
+ iou = calculate_iou(pred_label.geometry, gt_label.geometry)
4660
+ if iou >= iou_threshold:
4661
+ keep = False
4662
+ break
4663
+
4664
+ if keep:
4665
+ new_labels.append(pred_label)
4666
+
4667
+ return pred_ann.clone(labels=new_labels)
4209
4668
 
4210
4669
 
4211
4670
  def _filter_duplicated_predictions_from_ann(
@@ -4236,13 +4695,15 @@ def _filter_duplicated_predictions_from_ann(
4236
4695
  - Predictions with classes not present in ground truth will be kept
4237
4696
  - Requires PyTorch and torchvision for IoU calculations
4238
4697
  """
4698
+ if not iou_threshold:
4699
+ return pred_ann
4239
4700
 
4240
4701
  try:
4241
4702
  import torch
4242
4703
  from torchvision.ops import box_iou
4243
4704
 
4244
4705
  except ImportError:
4245
- raise ImportError("Please install PyTorch and torchvision to use this feature.")
4706
+ return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)
4246
4707
 
4247
4708
  def _to_tensor(geom):
4248
4709
  return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
@@ -4250,16 +4711,18 @@ def _filter_duplicated_predictions_from_ann(
4250
4711
  new_labels = []
4251
4712
  pred_cls_bboxes = defaultdict(list)
4252
4713
  for label in pred_ann.labels:
4253
- pred_cls_bboxes[label.obj_class.name].append(label)
4714
+ name_shape = (label.obj_class.name, label.geometry.name())
4715
+ pred_cls_bboxes[name_shape].append(label)
4254
4716
 
4255
4717
  gt_cls_bboxes = defaultdict(list)
4256
4718
  for label in gt_ann.labels:
4257
- if label.obj_class.name not in pred_cls_bboxes:
4719
+ name_shape = (label.obj_class.name, label.geometry.name())
4720
+ if name_shape not in pred_cls_bboxes:
4258
4721
  continue
4259
- gt_cls_bboxes[label.obj_class.name].append(label)
4722
+ gt_cls_bboxes[name_shape].append(label)
4260
4723
 
4261
- for name, pred in pred_cls_bboxes.items():
4262
- gt = gt_cls_bboxes[name]
4724
+ for name_shape, pred in pred_cls_bboxes.items():
4725
+ gt = gt_cls_bboxes[name_shape]
4263
4726
  if len(gt) == 0:
4264
4727
  new_labels.extend(pred)
4265
4728
  continue
@@ -4273,6 +4736,63 @@ def _filter_duplicated_predictions_from_ann(
4273
4736
  return pred_ann.clone(labels=new_labels)
4274
4737
 
4275
4738
 
4739
+ def _exclude_duplicated_predictions(
4740
+ api: Api,
4741
+ pred_anns: List[Annotation],
4742
+ dataset_id: int,
4743
+ gt_image_ids: List[int],
4744
+ iou: float = None,
4745
+ meta: Optional[ProjectMeta] = None,
4746
+ ):
4747
+ """
4748
+ Filter out predictions that significantly overlap with ground truth (GT) objects.
4749
+
4750
+ This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
4751
+ - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
4752
+ - Gets ProjectMeta object if not provided
4753
+ - Downloads GT annotations for the specified image IDs
4754
+ - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
4755
+
4756
+ :param api: Supervisely API object
4757
+ :type api: Api
4758
+ :param pred_anns: List of Annotation objects containing predictions
4759
+ :type pred_anns: List[Annotation]
4760
+ :param dataset_id: ID of the dataset containing the images
4761
+ :type dataset_id: int
4762
+ :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
4763
+ :type gt_image_ids: List[int]
4764
+ :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
4765
+ ground truth box of the same class will be removed. None if no filtering is needed
4766
+ :type iou: Optional[float]
4767
+ :param meta: ProjectMeta object
4768
+ :type meta: Optional[ProjectMeta]
4769
+ :return: List of Annotation objects containing filtered predictions
4770
+ :rtype: List[Annotation]
4771
+
4772
+ Notes:
4773
+ ------
4774
+ - Requires PyTorch and torchvision for IoU calculations
4775
+ - This method is useful for identifying new objects that aren't already annotated in the ground truth
4776
+ """
4777
+ if isinstance(iou, float) and 0 < iou <= 1:
4778
+ if meta is None:
4779
+ ds = api.dataset.get_info_by_id(dataset_id)
4780
+ meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
4781
+ gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
4782
+ gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
4783
+ for i in range(0, len(pred_anns)):
4784
+ before = len(pred_anns[i].labels)
4785
+ with Timer() as timer:
4786
+ pred_anns[i] = _filter_duplicated_predictions_from_ann(
4787
+ gt_anns[i], pred_anns[i], iou
4788
+ )
4789
+ after = len(pred_anns[i].labels)
4790
+ logger.debug(
4791
+ f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
4792
+ )
4793
+ return pred_anns
4794
+
4795
+
4276
4796
  def _get_log_extra_for_inference_request(
4277
4797
  inference_request_uuid, inference_request: Union[InferenceRequest, dict]
4278
4798
  ):
@@ -4299,8 +4819,8 @@ def _get_log_extra_for_inference_request(
4299
4819
  "has_result": inference_request.final_result is not None,
4300
4820
  "pending_results": inference_request.pending_num(),
4301
4821
  "exception": inference_request.exception_json(),
4302
- "result": inference_request._final_result,
4303
4822
  "preparing_progress": progress,
4823
+ "result": inference_request.final_result is not None, # for backward compatibility
4304
4824
  }
4305
4825
  return log_extra
4306
4826
 
@@ -4380,7 +4900,7 @@ def get_gpu_count():
4380
4900
  gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
4381
4901
  return gpu_count
4382
4902
  except (subprocess.CalledProcessError, FileNotFoundError) as exc:
4383
- logger.warn("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
4903
+ logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
4384
4904
  return 0
4385
4905
 
4386
4906
 
@@ -4426,7 +4946,7 @@ def _fix_classes_names(meta: ProjectMeta, ann: Annotation):
4426
4946
  return meta, ann, replaced_classes_in_meta, list(replaced_classes_in_ann)
4427
4947
 
4428
4948
 
4429
- def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
4949
+ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suffix: str = None):
4430
4950
  """Update project meta and annotation to match each other
4431
4951
  If obj class or tag meta from annotation conflicts with project meta
4432
4952
  add suffix to obj class or tag meta.
@@ -4434,8 +4954,13 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
4434
4954
  """
4435
4955
  obj_classes_suffixes = ["_nn"]
4436
4956
  tag_meta_suffixes = ["_nn"]
4437
- ann_obj_classes = {}
4438
- ann_tag_metas = {}
4957
+ if model_prediction_suffix is not None:
4958
+ obj_classes_suffixes = [model_prediction_suffix]
4959
+ tag_meta_suffixes = [model_prediction_suffix]
4960
+ logger.debug(
4961
+ f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
4962
+ )
4963
+ logger.debug("source meta", extra={"meta": meta.to_json()})
4439
4964
  meta_changed = False
4440
4965
 
4441
4966
  meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
@@ -4446,91 +4971,289 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
4446
4971
  extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
4447
4972
  )
4448
4973
 
4449
- # get all obj classes and tag metas from annotation
4974
+ updated_labels = []
4975
+ any_label_updated = False
4976
+ for label in ann.labels:
4977
+ original_obj_class_name = label.obj_class.name
4978
+ suffix_found = False
4979
+ for suffix in ["", *obj_classes_suffixes]:
4980
+ label_obj_class = label.obj_class
4981
+ label_obj_class_name = label_obj_class.name + suffix
4982
+ if suffix:
4983
+ label_obj_class = label_obj_class.clone(name=label_obj_class_name)
4984
+ label = label.clone(obj_class=label_obj_class)
4985
+ any_label_updated = True
4986
+ meta_obj_class = meta.get_obj_class(label_obj_class_name)
4987
+ if meta_obj_class is None:
4988
+ # if obj class is not in meta, add it with suffix
4989
+ meta = meta.add_obj_class(label_obj_class)
4990
+ updated_labels.append(label)
4991
+ meta_changed = True
4992
+ suffix_found = True
4993
+ break
4994
+ elif meta_obj_class.geometry_type.geometry_name() == label.geometry.geometry_name():
4995
+ # if label geometry is the same as in meta, use meta obj class
4996
+ label = label.clone(obj_class=meta_obj_class)
4997
+ updated_labels.append(label)
4998
+ suffix_found = True
4999
+ any_label_updated = True
5000
+ break
5001
+ elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
5002
+ # if meta obj class is AnyGeometry, use it in label
5003
+ label = label.clone(obj_class=meta_obj_class)
5004
+ updated_labels.append(label)
5005
+ suffix_found = True
5006
+ any_label_updated = True
5007
+ break
5008
+ if not suffix_found:
5009
+ # if no suffix found, raise error
5010
+ raise ValueError(
5011
+ f"Can't add obj class {original_obj_class_name} to project meta. "
5012
+ "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
5013
+ "Please check if model geometry type is compatible with existing obj classes."
5014
+ )
5015
+ if any_label_updated:
5016
+ ann = ann.clone(labels=updated_labels)
5017
+
5018
+ # check if tag metas are in project meta
5019
+ # if not, add them with suffix
5020
+ ann_tag_metas = {}
4450
5021
  for label in ann.labels:
4451
- ann_obj_classes[label.obj_class.name] = label.obj_class
4452
5022
  for tag in label.tags:
4453
5023
  ann_tag_metas[tag.meta.name] = tag.meta
4454
5024
  for tag in ann.img_tags:
4455
5025
  ann_tag_metas[tag.meta.name] = tag.meta
4456
5026
 
4457
- # check if obj classes are in project meta
4458
- # if not, add them.
4459
- # if shape is different, add them with suffix
4460
- changed_obj_classes = {}
4461
- for ann_obj_class in ann_obj_classes.values():
4462
- if meta.get_obj_class(ann_obj_class.name) is None:
4463
- meta = meta.add_obj_class(ann_obj_class)
5027
+ changed_tag_metas = {}
5028
+ for ann_tag_meta in ann_tag_metas.values():
5029
+ meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
5030
+ if meta_tag_meta is None:
5031
+ meta = meta.add_tag_meta(ann_tag_meta)
4464
5032
  meta_changed = True
4465
- elif (
4466
- meta.get_obj_class(ann_obj_class.name).geometry_type != ann_obj_class.geometry_type
4467
- and meta.get_obj_class(ann_obj_class.name).geometry_type != AnyGeometry
4468
- ):
4469
- found = False
4470
- for suffix in obj_classes_suffixes:
4471
- new_obj_class_name = ann_obj_class.name + suffix
4472
- meta_obj_class = meta.get_obj_class(new_obj_class_name)
4473
- if meta_obj_class is None:
4474
- new_obj_class = ann_obj_class.clone(name=new_obj_class_name)
4475
- meta = meta.add_obj_class(new_obj_class)
5033
+ elif not meta_tag_meta.is_compatible(ann_tag_meta):
5034
+ suffix_found = False
5035
+ for suffix in tag_meta_suffixes:
5036
+ new_tag_meta_name = ann_tag_meta.name + suffix
5037
+ meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
5038
+ if meta_tag_meta is None:
5039
+ new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
5040
+ meta = meta.add_tag_meta(new_tag_meta)
5041
+ changed_tag_metas[ann_tag_meta.name] = new_tag_meta
4476
5042
  meta_changed = True
4477
- changed_obj_classes[ann_obj_class.name] = new_obj_class
4478
- found = True
5043
+ suffix_found = True
4479
5044
  break
4480
- if meta_obj_class.geometry_type == ann_obj_class.geometry_type:
4481
- changed_obj_classes[ann_obj_class.name] = meta_obj_class
4482
- found = True
5045
+ if meta_tag_meta.is_compatible(ann_tag_meta):
5046
+ changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
5047
+ suffix_found = True
4483
5048
  break
4484
- if not found:
4485
- raise ValueError(f"Can't add obj class {ann_obj_class.name} to project meta")
5049
+ if not suffix_found:
5050
+ raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
5051
+
5052
+ if changed_tag_metas:
5053
+ labels = []
5054
+ any_label_updated = False
5055
+ for label in ann.labels:
5056
+ any_tag_updated = False
5057
+ label_tags = []
5058
+ for tag in label.tags:
5059
+ if tag.meta.name in changed_tag_metas:
5060
+ label_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
5061
+ any_tag_updated = True
5062
+ else:
5063
+ label_tags.append(tag)
5064
+ if any_tag_updated:
5065
+ label = label.clone(tags=TagCollection(label_tags))
5066
+ any_label_updated = True
5067
+ labels.append(label)
5068
+ img_tags = []
5069
+ any_tag_updated = False
5070
+ for tag in ann.img_tags:
5071
+ if tag.meta.name in changed_tag_metas:
5072
+ img_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
5073
+ any_tag_updated = True
5074
+ else:
5075
+ img_tags.append(tag)
5076
+ if any_tag_updated or any_label_updated:
5077
+ if any_tag_updated:
5078
+ img_tags = TagCollection(img_tags)
5079
+ else:
5080
+ img_tags = None
5081
+ if not any_label_updated:
5082
+ labels = None
5083
+ ann = ann.clone(img_tags=img_tags)
5084
+ return meta, ann, meta_changed
5085
+
5086
+
5087
+ def update_meta_and_ann_for_video_annotation(
5088
+ meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
5089
+ ):
5090
+ """Update project meta and annotation to match each other
5091
+ If obj class or tag meta from annotation conflicts with project meta
5092
+ add suffix to obj class or tag meta.
5093
+ Return tuple of updated project meta, annotation and boolean flag if meta was changed.
5094
+ """
5095
+ obj_classes_suffixes = ["_nn"]
5096
+ tag_meta_suffixes = ["_nn"]
5097
+ if model_prediction_suffix is not None:
5098
+ obj_classes_suffixes = [model_prediction_suffix]
5099
+ tag_meta_suffixes = [model_prediction_suffix]
5100
+ logger.debug(
5101
+ f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
5102
+ )
5103
+ logger.debug("source meta", extra={"meta": meta.to_json()})
5104
+ meta_changed = False
5105
+
5106
+ # meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
5107
+ # if replaced_classes_in_meta:
5108
+ # meta_changed = True
5109
+ # logger.warning(
5110
+ # "Some classes names were fixed in project meta",
5111
+ # extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
5112
+ # )
5113
+
5114
+ new_objects: List[VideoObject] = []
5115
+ new_figures: List[VideoFigure] = []
5116
+ any_object_updated = False
5117
+ for video_object in ann.objects:
5118
+ this_object_figures = [
5119
+ figure for figure in ann.figures if figure.video_object.key() == video_object.key()
5120
+ ]
5121
+ this_object_changed = False
5122
+ original_obj_class_name = video_object.obj_class.name
5123
+ suffix_found = False
5124
+ for suffix in ["", *obj_classes_suffixes]:
5125
+ obj_class = video_object.obj_class
5126
+ obj_class_name = obj_class.name + suffix
5127
+ if suffix:
5128
+ obj_class = obj_class.clone(name=obj_class_name)
5129
+ video_object = video_object.clone(obj_class=obj_class)
5130
+ any_object_updated = True
5131
+ this_object_changed = True
5132
+ meta_obj_class = meta.get_obj_class(obj_class_name)
5133
+ if meta_obj_class is None:
5134
+ # obj class is not in meta, add it with suffix
5135
+ meta = meta.add_obj_class(obj_class)
5136
+ new_objects.append(video_object)
5137
+ meta_changed = True
5138
+ suffix_found = True
5139
+ break
5140
+ elif (
5141
+ meta_obj_class.geometry_type.geometry_name()
5142
+ == video_object.obj_class.geometry_type.geometry_name()
5143
+ ):
5144
+ # if object geometry is the same as in meta, use meta obj class
5145
+ video_object = video_object.clone(obj_class=meta_obj_class)
5146
+ new_objects.append(video_object)
5147
+ suffix_found = True
5148
+ any_object_updated = True
5149
+ this_object_changed = True
5150
+ break
5151
+ elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
5152
+ # if meta obj class is AnyGeometry, use it in object
5153
+ video_object = video_object.clone(obj_class=meta_obj_class)
5154
+ new_objects.append(video_object)
5155
+ suffix_found = True
5156
+ any_object_updated = True
5157
+ this_object_changed = True
5158
+ break
5159
+ if not suffix_found:
5160
+ # if no suffix found, raise error
5161
+ raise ValueError(
5162
+ f"Can't add obj class {original_obj_class_name} to project meta. "
5163
+ "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
5164
+ "Please check if model geometry type is compatible with existing obj classes."
5165
+ )
5166
+ elif this_object_changed:
5167
+ this_object_figures = [
5168
+ figure.clone(video_object=video_object) for figure in this_object_figures
5169
+ ]
5170
+ new_figures.extend(this_object_figures)
5171
+ if any_object_updated:
5172
+ frames_figures = {}
5173
+ for figure in new_figures:
5174
+ frames_figures.setdefault(figure.frame_index, []).append(figure)
5175
+ new_frames = FrameCollection(
5176
+ [
5177
+ Frame(index=frame_index, figures=figures)
5178
+ for frame_index, figures in frames_figures.items()
5179
+ ]
5180
+ )
5181
+ ann = ann.clone(objects=new_objects, frames=new_frames)
4486
5182
 
4487
5183
  # check if tag metas are in project meta
4488
5184
  # if not, add them with suffix
5185
+ ann_tag_metas: Dict[str, TagMeta] = {}
5186
+ for video_object in ann.objects:
5187
+ for tag in video_object.tags:
5188
+ tag_name = tag.meta.name
5189
+ if tag_name not in ann_tag_metas:
5190
+ ann_tag_metas[tag_name] = tag.meta
5191
+ for tag in ann.tags:
5192
+ tag_name = tag.meta.name
5193
+ if tag_name not in ann_tag_metas:
5194
+ ann_tag_metas[tag_name] = tag.meta
5195
+
4489
5196
  changed_tag_metas = {}
4490
- for tag_meta in ann_tag_metas.values():
4491
- if meta.get_tag_meta(tag_meta.name) is None:
4492
- meta = meta.add_tag_meta(tag_meta)
5197
+ for ann_tag_meta in ann_tag_metas.values():
5198
+ meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
5199
+ if meta_tag_meta is None:
5200
+ meta = meta.add_tag_meta(ann_tag_meta)
4493
5201
  meta_changed = True
4494
- elif not meta.get_tag_meta(tag_meta.name).is_compatible(tag_meta):
4495
- found = False
5202
+ elif not meta_tag_meta.is_compatible(ann_tag_meta):
5203
+ suffix_found = False
4496
5204
  for suffix in tag_meta_suffixes:
4497
- new_tag_meta_name = tag_meta.name + suffix
5205
+ new_tag_meta_name = ann_tag_meta.name + suffix
4498
5206
  meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
4499
5207
  if meta_tag_meta is None:
4500
- new_tag_meta = tag_meta.clone(name=new_tag_meta_name)
5208
+ new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
4501
5209
  meta = meta.add_tag_meta(new_tag_meta)
4502
- changed_tag_metas[tag_meta.name] = new_tag_meta
5210
+ changed_tag_metas[ann_tag_meta.name] = new_tag_meta
4503
5211
  meta_changed = True
4504
- found = True
5212
+ suffix_found = True
4505
5213
  break
4506
- if meta_tag_meta.is_compatible(tag_meta):
4507
- changed_tag_metas[tag_meta.name] = meta_tag_meta
4508
- found = True
5214
+ if meta_tag_meta.is_compatible(ann_tag_meta):
5215
+ changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
5216
+ suffix_found = True
4509
5217
  break
4510
- if not found:
4511
- raise ValueError(f"Can't add tag meta {tag_meta.name} to project meta")
4512
-
4513
- labels = []
4514
- for label in ann.labels:
4515
- if label.obj_class.name in changed_obj_classes:
4516
- label = label.clone(obj_class=changed_obj_classes[label.obj_class.name])
4517
-
4518
- label_tags = []
4519
- for tag in label.tags:
5218
+ if not suffix_found:
5219
+ raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
5220
+
5221
+ if changed_tag_metas:
5222
+ objects = []
5223
+ any_object_updated = False
5224
+ for video_object in ann.objects:
5225
+ any_tag_updated = False
5226
+ object_tags = []
5227
+ for tag in video_object.tags:
5228
+ if tag.meta.name in changed_tag_metas:
5229
+ object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
5230
+ any_tag_updated = True
5231
+ else:
5232
+ object_tags.append(tag)
5233
+ if any_tag_updated:
5234
+ video_object = video_object.clone(tags=TagCollection(object_tags))
5235
+ any_object_updated = True
5236
+ objects.append(video_object)
5237
+
5238
+ video_tags = []
5239
+ any_tag_updated = False
5240
+ for tag in ann.tags:
4520
5241
  if tag.meta.name in changed_tag_metas:
4521
- label_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
5242
+ video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
5243
+ any_tag_updated = True
4522
5244
  else:
4523
- label_tags.append(tag)
4524
-
4525
- labels.append(label.clone(tags=TagCollection(label_tags)))
4526
- img_tags = []
4527
- for tag in ann.img_tags:
4528
- if tag.meta.name in changed_tag_metas:
4529
- img_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
4530
- else:
4531
- img_tags.append(tag)
5245
+ video_tags.append(tag)
5246
+ if any_tag_updated or any_object_updated:
5247
+ if any_tag_updated:
5248
+ video_tags = VideoTagCollection(video_tags)
5249
+ else:
5250
+ video_tags = None
5251
+ if any_object_updated:
5252
+ objects = VideoObjectCollection(objects)
5253
+ else:
5254
+ objects = None
5255
+ ann = ann.clone(tags=video_tags, objects=objects)
4532
5256
 
4533
- ann = ann.clone(labels=labels, img_tags=TagCollection(img_tags))
4534
5257
  return meta, ann, meta_changed
4535
5258
 
4536
5259
 
@@ -4643,3 +5366,22 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
4643
5366
  continue
4644
5367
  return data[key]
4645
5368
  return None
5369
+
5370
+
5371
+ def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
5372
+ import torch # pylint: disable=import-error
5373
+
5374
+ # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0
5375
+ try:
5376
+ logger.debug(f"Loading checkpoint from {checkpoint_path} on {device}")
5377
+ checkpoint = torch.load(checkpoint_path, map_location=device)
5378
+ logger.debug(f"Checkpoint loaded from {checkpoint_path} on {device}")
5379
+ except:
5380
+ logger.debug(
5381
+ f"Failed to load checkpoint from {checkpoint_path} on {device}. Trying again with weights_only=False"
5382
+ )
5383
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
5384
+ logger.debug(
5385
+ f"Checkpoint loaded from {checkpoint_path} on {device} with weights_only=False"
5386
+ )
5387
+ return checkpoint