supervisely 6.73.444-py3-none-any.whl → 6.73.468-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of supervisely might be problematic.

Files changed (68)
  1. supervisely/__init__.py +24 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/api/dataset_api.py +74 -12
  5. supervisely/api/entity_annotation/figure_api.py +8 -5
  6. supervisely/api/image_api.py +4 -0
  7. supervisely/api/video/video_annotation_api.py +4 -2
  8. supervisely/api/video/video_api.py +41 -1
  9. supervisely/app/__init__.py +1 -1
  10. supervisely/app/content.py +14 -6
  11. supervisely/app/fastapi/__init__.py +1 -0
  12. supervisely/app/fastapi/custom_static_files.py +1 -1
  13. supervisely/app/fastapi/multi_user.py +88 -0
  14. supervisely/app/fastapi/subapp.py +88 -42
  15. supervisely/app/fastapi/websocket.py +77 -9
  16. supervisely/app/singleton.py +21 -0
  17. supervisely/app/v1/app_service.py +18 -2
  18. supervisely/app/v1/constants.py +7 -1
  19. supervisely/app/widgets/card/card.py +20 -0
  20. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  21. supervisely/app/widgets/dialog/dialog.py +12 -0
  22. supervisely/app/widgets/dialog/template.html +2 -1
  23. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  24. supervisely/app/widgets/fast_table/fast_table.py +121 -31
  25. supervisely/app/widgets/fast_table/template.html +1 -1
  26. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  27. supervisely/app/widgets/radio_tabs/template.html +1 -0
  28. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
  29. supervisely/app/widgets/table/table.py +68 -13
  30. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  31. supervisely/convert/image/csv/csv_converter.py +24 -15
  32. supervisely/convert/video/video_converter.py +2 -2
  33. supervisely/geometry/polyline_3d.py +110 -0
  34. supervisely/io/env.py +76 -1
  35. supervisely/nn/inference/cache.py +37 -17
  36. supervisely/nn/inference/inference.py +667 -114
  37. supervisely/nn/inference/inference_request.py +15 -8
  38. supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
  39. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  40. supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
  41. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  42. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  43. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  44. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  45. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  46. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  47. supervisely/nn/inference/session.py +43 -35
  48. supervisely/nn/model/model_api.py +9 -0
  49. supervisely/nn/model/prediction_session.py +8 -7
  50. supervisely/nn/prediction_dto.py +7 -0
  51. supervisely/nn/tracker/base_tracker.py +11 -1
  52. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  53. supervisely/nn/tracker/botsort_tracker.py +14 -7
  54. supervisely/nn/tracker/visualize.py +70 -72
  55. supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
  56. supervisely/nn/training/train_app.py +10 -5
  57. supervisely/project/project.py +9 -1
  58. supervisely/video/sampling.py +39 -20
  59. supervisely/video/video.py +41 -12
  60. supervisely/volume/stl_converter.py +2 -0
  61. supervisely/worker_api/agent_rpc.py +24 -1
  62. supervisely/worker_api/rpc_servicer.py +31 -7
  63. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/METADATA +14 -11
  64. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/RECORD +68 -66
  65. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/LICENSE +0 -0
  66. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/WHEEL +0 -0
  67. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/entry_points.txt +0 -0
  68. {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,7 @@ import subprocess
  import tempfile
  import threading
  import time
+ import uuid
  from collections import OrderedDict, defaultdict
  from concurrent.futures import ThreadPoolExecutor
  from dataclasses import asdict, dataclass
@@ -52,6 +53,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
  from supervisely.api.api import Api, ApiField
  from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
  from supervisely.api.image_api import ImageInfo
+ from supervisely.api.video.video_api import VideoInfo
  from supervisely.app.content import get_data_dir
  from supervisely.app.fastapi.subapp import (
  Application,
@@ -67,6 +69,7 @@ from supervisely.decorators.inference import (
  process_images_batch_sliding_window,
  )
  from supervisely.geometry.any_geometry import AnyGeometry
+ from supervisely.geometry.geometry import Geometry
  from supervisely.imaging.color import get_predefined_colors
  from supervisely.io.fs import list_files
  from supervisely.nn.experiments import ExperimentInfo
@@ -94,6 +97,18 @@ from supervisely.project.project_meta import ProjectMeta
  from supervisely.sly_logger import logger
  from supervisely.task.progress import Progress
  from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
+ from supervisely.video_annotation.frame import Frame
+ from supervisely.video_annotation.frame_collection import FrameCollection
+ from supervisely.video_annotation.video_annotation import VideoAnnotation
+ from supervisely.video_annotation.video_figure import VideoFigure
+ from supervisely.video_annotation.video_object import VideoObject
+ from supervisely.video_annotation.video_object_collection import VideoObjectCollection
+ from supervisely.video_annotation.video_tag_collection import VideoTagCollection
+ from supervisely.video_annotation.key_id_map import KeyIdMap
+ from supervisely.video_annotation.video_object_collection import (
+ VideoObject,
+ VideoObjectCollection,
+ )

  try:
  from typing import Literal
@@ -140,6 +155,7 @@ class Inference:
  """Default batch size for inference"""
  INFERENCE_SETTINGS: str = None
  """Path to file with custom inference settings"""
+ DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9

  def __init__(
  self,
@@ -193,7 +209,6 @@ class Inference:
  self._task_id = None
  self._sliding_window_mode = sliding_window_mode
  self._autostart_delay_time = 5 * 60 # 5 min
- self._tracker = None
  self._hardware: str = None
  if custom_inference_settings is None:
  if self.INFERENCE_SETTINGS is not None:
@@ -427,7 +442,7 @@ class Inference:

  device = "cuda" if torch.cuda.is_available() else "cpu"
  except Exception as e:
- logger.warn(
+ logger.warning(
  f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
  )
  device = "cpu"
@@ -1105,31 +1120,37 @@ class Inference:
  self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
  self._hardware = get_hardware_info(self.device)

- checkpoint_path = deploy_params["model_files"]["checkpoint"]
- checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
- if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
- try:
- self.load_model(**deploy_params)
- except Exception as e:
- logger.warning(f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}")
- checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
- deploy_params["model_files"]["checkpoint"] = checkpoint_path
- logger.info("Exporting PyTorch model to TensorRT...")
- self._remove_exported_checkpoints(checkpoint_path)
- checkpoint_path = self.export_tensorrt(deploy_params)
+ model_files = deploy_params.get("model_files", None)
+ if model_files is not None:
+ checkpoint_path = deploy_params["model_files"]["checkpoint"]
+ checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+ if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
+ try:
+ self.load_model(**deploy_params)
+ except Exception as e:
+ logger.warning(
+ f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}"
+ )
+ checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
+ logger.info("Exporting PyTorch model to TensorRT...")
+ self._remove_exported_checkpoints(checkpoint_path)
+ checkpoint_path = self.export_tensorrt(deploy_params)
+ deploy_params["model_files"]["checkpoint"] = checkpoint_path
+ self.load_model(**deploy_params)
+ if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
+ if self.runtime == RuntimeType.ONNXRUNTIME:
+ logger.info("Exporting PyTorch model to ONNX...")
+ self._remove_exported_checkpoints(checkpoint_path)
+ checkpoint_path = self.export_onnx(deploy_params)
+ elif self.runtime == RuntimeType.TENSORRT:
+ logger.info("Exporting PyTorch model to TensorRT...")
+ self._remove_exported_checkpoints(checkpoint_path)
+ checkpoint_path = self.export_tensorrt(deploy_params)
  deploy_params["model_files"]["checkpoint"] = checkpoint_path
  self.load_model(**deploy_params)
- if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
- if self.runtime == RuntimeType.ONNXRUNTIME:
- logger.info("Exporting PyTorch model to ONNX...")
- self._remove_exported_checkpoints(checkpoint_path)
- checkpoint_path = self.export_onnx(deploy_params)
- elif self.runtime == RuntimeType.TENSORRT:
- logger.info("Exporting PyTorch model to TensorRT...")
- self._remove_exported_checkpoints(checkpoint_path)
- checkpoint_path = self.export_tensorrt(deploy_params)
- deploy_params["model_files"]["checkpoint"] = checkpoint_path
- self.load_model(**deploy_params)
+ else:
+ self.load_model(**deploy_params)
  else:
  self.load_model(**deploy_params)

@@ -1253,7 +1274,6 @@ class Inference:
  if self._model_meta is None:
  self._set_model_meta_from_classes()

-
  def _set_model_meta_custom_model(self, model_info: dict):
  model_meta = model_info.get("model_meta")
  if model_meta is None:
@@ -1354,6 +1374,7 @@ class Inference:

  if tracker == "botsort":
  from supervisely.nn.tracker import BotSortTracker
+
  device = tracker_settings.get("device", self.device)
  logger.debug(f"Initializing BotSort tracker with device: {device}")
  return BotSortTracker(settings=tracker_settings, device=device)
@@ -1370,15 +1391,15 @@ class Inference:
  if classes is not None:
  num_classes = len(classes)
  except NotImplementedError:
- logger.warn(f"get_classes() function not implemented for {type(self)} object.")
+ logger.warning(f"get_classes() function not implemented for {type(self)} object.")
  except AttributeError:
- logger.warn("Probably, get_classes() function not working without model deploy.")
+ logger.warning("Probably, get_classes() function not working without model deploy.")
  except Exception as exc:
- logger.warn("Unknown exception. Please, contact support")
+ logger.warning("Unknown exception. Please, contact support")
  logger.exception(exc)

  if num_classes is None:
- logger.warn(f"get_classes() function return {classes}; skip classes processing.")
+ logger.warning(f"get_classes() function return {classes}; skip classes processing.")

  return {
  "app_name": get_name_from_env(default="Neural Network Serving"),
@@ -1396,6 +1417,42 @@ class Inference:

  # pylint: enable=method-hidden

+ def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
+ """
+ Get default parameters for all available tracking algorithms.
+
+ Returns:
+ {"botsort": {"track_high_thresh": 0.6, ...}}
+ Empty dict if tracking not supported.
+ """
+ info = self.get_info()
+ trackers_params = {}
+
+ tracking_support = info.get("tracking_on_videos_support")
+ if not tracking_support:
+ return trackers_params
+
+ tracking_algorithms = info.get("tracking_algorithms", [])
+
+ for tracker_name in tracking_algorithms:
+ try:
+ if tracker_name == "botsort":
+ from supervisely.nn.tracker import BotSortTracker
+
+ trackers_params[tracker_name] = BotSortTracker.get_default_params()
+ # Add other trackers here as elif blocks
+ else:
+ logger.debug(f"Tracker '{tracker_name}' not implemented")
+ except Exception as e:
+ logger.warning(f"Failed to get params for '{tracker_name}': {e}")
+
+ INTERNAL_FIELDS = {"device", "fps"}
+ for tracker_name, params in trackers_params.items():
+ trackers_params[tracker_name] = {
+ k: v for k, v in params.items() if k not in INTERNAL_FIELDS
+ }
+ return trackers_params
+
  def get_human_readable_info(self, replace_none_with: Optional[str] = None):
  hr_info = {}
  info = self.get_info()
@@ -1947,7 +2004,7 @@ class Inference:
  else:
  n_frames = frames_reader.frames_count()

- self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))

  progress_total = (n_frames + step - 1) // step
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
@@ -1973,8 +2030,8 @@ class Inference:
  settings=inference_settings,
  )

- if self._tracker is not None:
- anns = self._apply_tracker_to_anns(frames, anns)
+ if inference_request.tracker is not None:
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)

  predictions = [
  Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
@@ -1989,10 +2046,9 @@ class Inference:
  inference_request.done(len(batch_results))
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
  video_ann_json = None
- if self._tracker is not None:
+ if inference_request.tracker is not None:
  inference_request.set_stage("Postprocess...", 0, 1)
-
- video_ann_json = self._tracker.video_annotation.to_json()
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
  inference_request.done()
  result = {"ann": results, "video_ann": video_ann_json}
  inference_request.final_result = result.copy()
@@ -2024,7 +2080,7 @@ class Inference:
  upload_mode = state.get("upload_mode", None)
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
- iou_merge_threshold = 0.7
+ iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD # TODO: change to 0.9

  images_infos = api.image.get_info_by_id_batch(image_ids)
  images_infos_dict = {im_info.id: im_info for im_info in images_infos}
@@ -2146,7 +2202,7 @@ class Inference:
  video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
  if video_id is None:
  raise ValueError("Video id is not provided")
- video_info = api.video.get_info_by_id(video_id)
+ video_info = api.video.get_info_by_id(video_id, force_metadata_for_links=True)
  start_frame_index = get_value_for_keys(
  state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
  )
@@ -2176,7 +2232,7 @@ class Inference:
  else:
  n_frames = video_info.frames_count

- self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))

  logger.debug(
  f"Video info:",
@@ -2213,8 +2269,8 @@ class Inference:
  settings=inference_settings,
  )

- if self._tracker is not None:
- anns = self._apply_tracker_to_anns(frames, anns)
+ if inference_request.tracker is not None:
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)

  predictions = [
  Prediction(
@@ -2223,8 +2279,8 @@ class Inference:
  frame_index=frame_index,
  video_id=video_info.id,
  dataset_id=video_info.dataset_id,
- project_id=video_info.project_id,
- )
+ project_id=video_info.project_id,
+ )
  for ann, frame_index in zip(anns, batch)
  ]
  for pred, this_slides_data in zip(predictions, slides_data):
@@ -2235,13 +2291,169 @@ class Inference:
  inference_request.done(len(batch_results))
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
  video_ann_json = None
- if self._tracker is not None:
+ if inference_request.tracker is not None:
+ inference_request.set_stage("Postprocess...", 0, 1)
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
+ inference_request.done()
+ inference_request.final_result = {"video_ann": video_ann_json}
+ return video_ann_json
+
+ def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
+ logger.debug("Inferring video_id...", extra={"state": state})
+ inference_settings = self._get_inference_settings(state)
+ logger.debug(f"Inference settings:", extra=inference_settings)
+ batch_size = self._get_batch_size_from_state(state)
+ video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
+ if video_id is None:
+ raise ValueError("Video id is not provided")
+ video_info = api.video.get_info_by_id(video_id)
+ start_frame_index = get_value_for_keys(
+ state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
+ )
+ if start_frame_index is None:
+ start_frame_index = 0
+ step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
+ if step is None:
+ step = 1
+ end_frame_index = get_value_for_keys(
+ state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
+ )
+ duration = state.get("duration", None)
+ frames_count = get_value_for_keys(
+ state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
+ )
+ tracking = state.get("tracker", None)
+ direction = state.get("direction", "forward")
+ direction = 1 if direction == "forward" else -1
+ track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
+
+ if frames_count is not None:
+ n_frames = frames_count
+ elif end_frame_index is not None:
+ n_frames = end_frame_index - start_frame_index
+ elif duration is not None:
+ fps = video_info.frames_count / video_info.duration
+ n_frames = int(duration * fps)
+ else:
+ n_frames = video_info.frames_count
+
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+
+ logger.debug(
+ f"Video info:",
+ extra=dict(
+ w=video_info.frame_width,
+ h=video_info.frame_height,
+ start_frame_index=start_frame_index,
+ n_frames=n_frames,
+ ),
+ )
+
+ # start downloading video in background
+ self.cache.run_cache_task_manually(api, None, video_id=video_id)
+
+ progress_total = (n_frames + step - 1) // step
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
+
+ _upload_f = partial(
+ self.upload_predictions_to_video,
+ api=api,
+ video_info=video_info,
+ track_id=track_id,
+ context=inference_request.context,
+ progress_cb=inference_request.done,
+ inference_request=inference_request,
+ )
+
+ _range = (start_frame_index, start_frame_index + direction * n_frames)
+ if _range[0] > _range[1]:
+ _range = (_range[1], _range[0])
+
+ def _notify_f(predictions: List[Prediction]):
+ logger.debug(
+ "Notifying tracking progress...",
+ extra={
+ "track_id": track_id,
+ "range": _range,
+ "current": inference_request.progress.current,
+ "total": inference_request.progress.total,
+ },
+ )
+ stopped = self.api.video.notify_progress(
+ track_id=track_id,
+ video_id=video_info.id,
+ frame_start=_range[0],
+ frame_end=_range[1],
+ current=inference_request.progress.current,
+ total=inference_request.progress.total,
+ )
+ if stopped:
+ inference_request.stop()
+ logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
+
+ def _exception_handler(e: Exception):
+ self.api.video.notify_tracking_error(
+ track_id=track_id,
+ error=str(type(e)),
+ message=str(e),
+ )
+ raise e
+
+ with Uploader(
+ upload_f=_upload_f,
+ notify_f=_notify_f,
+ exception_handler=_exception_handler,
+ logger=logger,
+ ) as uploader:
+ for batch in batched(
+ range(
+ start_frame_index, start_frame_index + direction * n_frames, direction * step
+ ),
+ batch_size,
+ ):
+ if inference_request.is_stopped():
+ logger.debug(
+ f"Cancelling inference video...",
+ extra={"inference_request_uuid": inference_request.uuid},
+ )
+ break
+ logger.debug(
+ f"Inferring frames {batch[0]}-{batch[-1]}:",
+ )
+ frames = self.cache.download_frames(
+ api, video_info.id, batch, redownload_video=True
+ )
+ anns, slides_data = self._inference_auto(
+ source=frames,
+ settings=inference_settings,
+ )
+
+ if inference_request.tracker is not None:
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
+
+ predictions = [
+ Prediction(
+ ann,
+ model_meta=self.model_meta,
+ frame_index=frame_index,
+ video_id=video_info.id,
+ dataset_id=video_info.dataset_id,
+ project_id=video_info.project_id,
+ )
+ for ann, frame_index in zip(anns, batch)
+ ]
+ for pred, this_slides_data in zip(predictions, slides_data):
+ pred.extra_data["slides_data"] = this_slides_data
+ uploader.put(predictions)
+ video_ann_json = None
+ if inference_request.tracker is not None:
  inference_request.set_stage("Postprocess...", 0, 1)
- video_ann_json = self._tracker.video_annotation.to_json()
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
  inference_request.done()
  inference_request.final_result = {"video_ann": video_ann_json}
  return video_ann_json

+
  def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
  """Inference project images.
  If "output_project_id" in state, upload images and annotations to the output project.
@@ -2263,7 +2475,7 @@ class Inference:
  upload_mode = state.get("upload_mode", None)
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
- iou_merge_threshold = 0.7
+ iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
  cache_project_on_model = state.get("cache_project_on_model", False)

  project_info = api.project.get_info_by_id(project_id)
@@ -2747,10 +2959,10 @@ class Inference:
  context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
  return created_dataset.id

- created_names = []
  if context is None:
  context = {}
  for dataset_id, preds in ds_predictions.items():
+ created_names = set()
  if dst_project_id is not None:
  # upload to the destination project
  dst_dataset_id = _get_or_create_dataset(
@@ -2826,7 +3038,7 @@ class Inference:
  with_annotations=False,
  save_source_date=False,
  )
- created_names.extend([image_info.name for image_info in dst_image_infos])
+ created_names.update([image_info.name for image_info in dst_image_infos])
  api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
  else:
  # upload to the source dataset
@@ -2908,6 +3120,83 @@ class Inference:
  inference_request.add_results(results)
  inference_request.done(len(results))

+ def upload_predictions_to_video(
+ self,
+ predictions: List[Prediction],
+ api: Api,
+ video_info: VideoInfo,
+ track_id: str,
+ context: Dict,
+ progress_cb=None,
+ inference_request: InferenceRequest = None,
+ ):
+ key_id_map = KeyIdMap()
+ project_meta = context.get("project_meta", None)
+ if project_meta is None:
+ project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
+ context["project_meta"] = project_meta
+ meta_changed = False
+ for prediction in predictions:
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
+ project_meta, prediction.annotation, None
+ )
+ prediction.annotation = ann
+ meta_changed = meta_changed or meta_changed_
+ if meta_changed:
+ project_meta = api.project.update_meta(video_info.project_id, project_meta)
+ context["project_meta"] = project_meta
+
+ figure_data_by_object_id = defaultdict(list)
+
+ tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
+ new_tracks: Dict[int, VideoObject] = {}
+ for prediction in predictions:
+ annotation = prediction.annotation
+ tracks = annotation.custom_data
+ for track, label in zip(tracks, annotation.labels):
+ if track not in tracks_to_object_ids and track not in new_tracks:
+ video_object = VideoObject(obj_class=label.obj_class)
+ new_tracks[track] = video_object
+ if new_tracks:
+ tracks, video_objects = zip(*new_tracks.items())
+ added_object_ids = api.video.object.append_bulk(
+ video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
+ )
+ for track, object_id in zip(tracks, added_object_ids):
+ tracks_to_object_ids[track] = object_id
+ for prediction in predictions:
+ annotation = prediction.annotation
+ tracks = annotation.custom_data
+ for track, label in zip(tracks, annotation.labels):
+ object_id = tracks_to_object_ids[track]
+ figure_data_by_object_id[object_id].append(
+ {
+ ApiField.OBJECT_ID: object_id,
+ ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
+ ApiField.GEOMETRY: label.geometry.to_json(),
+ ApiField.META: {ApiField.FRAME: prediction.frame_index},
+ ApiField.TRACK_ID: track_id,
+ }
+ )
+
+ for object_id, figures_data in figure_data_by_object_id.items():
+ figures_keys = [uuid.uuid4() for _ in figures_data]
+ api.video.figure._append_bulk(
+ entity_id=video_info.id,
+ figures_json=figures_data,
+ figures_keys=figures_keys,
+ key_id_map=key_id_map,
+ )
+ logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
+ if progress_cb:
+ progress_cb(len(predictions))
+ if inference_request is not None:
+ results = self._format_output(predictions)
+ for result in results:
+ result["annotation"] = None
+ result["data"] = None
+ inference_request.add_results(results)
+
  def serve(self):
  if not self._use_gui and not self._is_cli_deploy:
  Progress("Deploying model ...", 1)
@@ -3017,6 +3306,11 @@ class Inference:
  def get_session_info(response: Response):
  return self.get_info()

+ @server.post("/get_tracking_settings")
+ @self._check_serve_before_call
+ def get_tracking_settings(response: Response):
+ return self.get_tracking_settings()
+
  @server.post("/get_custom_inference_settings")
  def get_custom_inference_settings():
  return {"settings": self.custom_inference_settings}
@@ -3300,6 +3594,22 @@ class Inference:
  "inference_request_uuid": inference_request.uuid,
  }

+ @server.post("/tracking_by_detection")
+ def tracking_by_detection(response: Response, request: Request):
+ state = request.state.state
+ context = request.state.context
+ state.update(context)
+ if state.get("tracker") is None:
+ state["tracker"] = "botsort"
+
+ logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
+ self.validate_inference_state(state)
+ api = self.api_from_request(request)
+ inference_request, future = self.inference_requests_manager.schedule_task(
+ self._tracking_by_detection, api, state
+ )
+ return {"message": "Track task started."}
+
  @server.post("/inference_project_id_async")
  def inference_project_id_async(response: Response, request: Request):
  state = request.state.state
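The `/tracking_by_detection` handler above reads `request.state.state` and `request.state.context`, which the app's middleware fills from the request JSON. A hedged client sketch: the `{"state": ..., "context": ...}` envelope and all ids are illustrative assumptions, while the field names follow the keys read in `_tracking_by_detection`:

```python
import requests

payload = {
    "state": {
        "videoId": 123456,        # hypothetical video id
        "startFrameIndex": 0,
        "framesCount": 300,
        "trackId": "demo-track",  # hypothetical id used for progress notifications
        "tracker": "botsort",     # default when omitted
        "tracker_settings": {},
    },
    "context": {},
}
resp = requests.post("http://localhost:8000/tracking_by_detection", json=payload)
print(resp.json())  # {"message": "Track task started."}
```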
@@ -3363,10 +3673,7 @@ class Inference:
  data = {**inference_request.to_json(), **log_extra}
  if inference_request.stage != InferenceRequest.Stage.INFERENCE:
  data["progress"] = {"current": 0, "total": 1}
- logger.debug(
- f"Sending inference progress with uuid:",
- extra=data,
- )
+ logger.debug(f"Sending inference progress with uuid:", extra=data)
  return data

  @server.post(f"/pop_inference_results")
@@ -4223,10 +4530,10 @@ class Inference:
  self._args.draw,
  )

- def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
+ def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
  updated_anns = []
  for frame, ann in zip(frames, anns):
- matches = self._tracker.update(frame, ann)
+ matches = tracker.update(frame, ann)
  track_ids = [match["track_id"] for match in matches]
  tracked_labels = [match["label"] for match in matches]

@@ -4292,61 +4599,72 @@ class Inference:
  def export_tensorrt(self, deploy_params: dict):
  raise NotImplementedError("Have to be implemented in child class after inheritance")

- def _exclude_duplicated_predictions(
- api: Api,
- pred_anns: List[Annotation],
- dataset_id: int,
- gt_image_ids: List[int],
- iou: float = None,
- meta: Optional[ProjectMeta] = None,
+
+ def _filter_duplicated_predictions_from_ann_cpu(
+ gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
  ):
  """
- Filter out predictions that significantly overlap with ground truth (GT) objects.
-
- This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
- - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
- - Gets ProjectMeta object if not provided
- - Downloads GT annotations for the specified image IDs
- - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+ Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
+ Uses Shapely for geometric operations.

- :param api: Supervisely API object
- :type api: Api
- :param pred_anns: List of Annotation objects containing predictions
- :type pred_anns: List[Annotation]
- :param dataset_id: ID of the dataset containing the images
- :type dataset_id: int
- :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
- :type gt_image_ids: List[int]
- :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
- ground truth box of the same class will be removed. None if no filtering is needed
- :type iou: Optional[float]
- :param meta: ProjectMeta object
- :type meta: Optional[ProjectMeta]
- :return: List of Annotation objects containing filtered predictions
- :rtype: List[Annotation]
+ Args:
+ pred_ann: Predicted annotation object
+ gt_ann: Ground truth annotation object
+ iou_threshold: IoU threshold for filtering

- Notes:
- ------
- - Requires PyTorch and torchvision for IoU calculations
- - This method is useful for identifying new objects that aren't already annotated in the ground truth
+ Returns:
+ New annotation with filtered labels
  """
- if isinstance(iou, float) and 0 < iou <= 1:
- if meta is None:
- ds = api.dataset.get_info_by_id(dataset_id)
- meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
- gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
- gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
- for i in range(0, len(pred_anns)):
- before = len(pred_anns[i].labels)
- with Timer() as timer:
- pred_anns[i] = _filter_duplicated_predictions_from_ann(
- gt_anns[i], pred_anns[i], iou
- )
- after = len(pred_anns[i].labels)
- logger.debug(
- f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
- )
- return pred_anns
+ if not iou_threshold:
+ return pred_ann
+
+ from shapely.geometry import box
+
+ def calculate_iou(geom1: Geometry, geom2: Geometry):
+ """Calculate IoU between two geometries using Shapely."""
+ bbox1 = geom1.to_bbox()
+ bbox2 = geom2.to_bbox()
+
+ box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
+ box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
+
+ intersection = box1.intersection(box2).area
+ union = box1.union(box2).area
+
+ return intersection / union if union > 0 else 0.0
+
+ new_labels = []
+ pred_cls_bboxes = defaultdict(list)
+ for label in pred_ann.labels:
+ name_shape = (label.obj_class.name, label.geometry.name())
+ pred_cls_bboxes[name_shape].append(label)
+
+ gt_cls_bboxes = defaultdict(list)
+ for label in gt_ann.labels:
+ name_shape = (label.obj_class.name, label.geometry.name())
+ if name_shape not in pred_cls_bboxes:
+ continue
+ gt_cls_bboxes[name_shape].append(label)
+
+ for name_shape, pred in pred_cls_bboxes.items():
+ gt = gt_cls_bboxes[name_shape]
+ if len(gt) == 0:
+ new_labels.extend(pred)
+ continue
+
+ for pred_label in pred:
+ # Check if this prediction has IoU < threshold with ALL GT boxes
+ keep = True
+ for gt_label in gt:
+ iou = calculate_iou(pred_label.geometry, gt_label.geometry)
+ if iou >= iou_threshold:
+ keep = False
+ break
+
+ if keep:
+ new_labels.append(pred_label)
+
+ return pred_ann.clone(labels=new_labels)


  def _filter_duplicated_predictions_from_ann(
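The CPU fallback above reduces every geometry to its bounding box, so the metric is plain axis-aligned box IoU: intersection area divided by union area. A standalone sketch of the same Shapely computation on two concrete boxes:

```python
from shapely.geometry import box

a = box(0, 0, 2, 2)  # 2x2 box, area 4
b = box(1, 1, 3, 3)  # 2x2 box shifted by (1, 1), area 4

intersection = a.intersection(b).area  # 1.0 (the 1x1 overlap)
union = a.union(b).area                # 4 + 4 - 1 = 7.0
print(intersection / union)            # ~0.143; the prediction is kept unless IoU >= iou_threshold
```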
@@ -4377,13 +4695,15 @@ def _filter_duplicated_predictions_from_ann(
  - Predictions with classes not present in ground truth will be kept
  - Requires PyTorch and torchvision for IoU calculations
  """
+ if not iou_threshold:
+ return pred_ann

  try:
  import torch
  from torchvision.ops import box_iou

  except ImportError:
- raise ImportError("Please install PyTorch and torchvision to use this feature.")
+ return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)

  def _to_tensor(geom):
  return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
@@ -4391,16 +4711,18 @@ def _filter_duplicated_predictions_from_ann(
  new_labels = []
  pred_cls_bboxes = defaultdict(list)
  for label in pred_ann.labels:
- pred_cls_bboxes[label.obj_class.name].append(label)
+ name_shape = (label.obj_class.name, label.geometry.name())
+ pred_cls_bboxes[name_shape].append(label)

  gt_cls_bboxes = defaultdict(list)
  for label in gt_ann.labels:
- if label.obj_class.name not in pred_cls_bboxes:
+ name_shape = (label.obj_class.name, label.geometry.name())
+ if name_shape not in pred_cls_bboxes:
  continue
- gt_cls_bboxes[label.obj_class.name].append(label)
+ gt_cls_bboxes[name_shape].append(label)

- for name, pred in pred_cls_bboxes.items():
- gt = gt_cls_bboxes[name]
+ for name_shape, pred in pred_cls_bboxes.items():
+ gt = gt_cls_bboxes[name_shape]
  if len(gt) == 0:
  new_labels.extend(pred)
  continue
@@ -4414,6 +4736,63 @@ def _filter_duplicated_predictions_from_ann(
  return pred_ann.clone(labels=new_labels)


+ def _exclude_duplicated_predictions(
+ api: Api,
+ pred_anns: List[Annotation],
+ dataset_id: int,
+ gt_image_ids: List[int],
+ iou: float = None,
+ meta: Optional[ProjectMeta] = None,
+ ):
+ """
+ Filter out predictions that significantly overlap with ground truth (GT) objects.
+
+ This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
+ - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
+ - Gets ProjectMeta object if not provided
+ - Downloads GT annotations for the specified image IDs
+ - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+
+ :param api: Supervisely API object
+ :type api: Api
+ :param pred_anns: List of Annotation objects containing predictions
+ :type pred_anns: List[Annotation]
+ :param dataset_id: ID of the dataset containing the images
+ :type dataset_id: int
+ :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
+ :type gt_image_ids: List[int]
+ :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
+ ground truth box of the same class will be removed. None if no filtering is needed
+ :type iou: Optional[float]
+ :param meta: ProjectMeta object
+ :type meta: Optional[ProjectMeta]
+ :return: List of Annotation objects containing filtered predictions
+ :rtype: List[Annotation]
+
+ Notes:
+ ------
+ - Requires PyTorch and torchvision for IoU calculations
+ - This method is useful for identifying new objects that aren't already annotated in the ground truth
+ """
+ if isinstance(iou, float) and 0 < iou <= 1:
+ if meta is None:
+ ds = api.dataset.get_info_by_id(dataset_id)
+ meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
+ gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
+ gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
+ for i in range(0, len(pred_anns)):
+ before = len(pred_anns[i].labels)
+ with Timer() as timer:
+ pred_anns[i] = _filter_duplicated_predictions_from_ann(
+ gt_anns[i], pred_anns[i], iou
+ )
+ after = len(pred_anns[i].labels)
+ logger.debug(
+ f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
+ )
+ return pred_anns
+
+
  def _get_log_extra_for_inference_request(
  inference_request_uuid, inference_request: Union[InferenceRequest, dict]
  ):
@@ -4440,8 +4819,8 @@ def _get_log_extra_for_inference_request(
  "has_result": inference_request.final_result is not None,
  "pending_results": inference_request.pending_num(),
  "exception": inference_request.exception_json(),
- "result": inference_request._final_result,
  "preparing_progress": progress,
+ "result": inference_request.final_result is not None, # for backward compatibility
  }
  return log_extra

@@ -4521,7 +4900,7 @@ def get_gpu_count():
  gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
  return gpu_count
  except (subprocess.CalledProcessError, FileNotFoundError) as exc:
- logger.warn("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
+ logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
  return 0

@@ -4701,7 +5080,180 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suffix: str = None
  img_tags = None
  if not any_label_updated:
  labels = None
- ann = ann.clone(img_tags=TagCollection(img_tags))
+ ann = ann.clone(img_tags=img_tags)
+ return meta, ann, meta_changed
+
+
+ def update_meta_and_ann_for_video_annotation(
+ meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
+ ):
+ """Update project meta and annotation to match each other
+ If obj class or tag meta from annotation conflicts with project meta
+ add suffix to obj class or tag meta.
+ Return tuple of updated project meta, annotation and boolean flag if meta was changed.
+ """
+ obj_classes_suffixes = ["_nn"]
+ tag_meta_suffixes = ["_nn"]
+ if model_prediction_suffix is not None:
+ obj_classes_suffixes = [model_prediction_suffix]
+ tag_meta_suffixes = [model_prediction_suffix]
+ logger.debug(
+ f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
+ )
+ logger.debug("source meta", extra={"meta": meta.to_json()})
+ meta_changed = False
+
+ # meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
+ # if replaced_classes_in_meta:
+ # meta_changed = True
+ # logger.warning(
+ # "Some classes names were fixed in project meta",
+ # extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
+ # )
+
+ new_objects: List[VideoObject] = []
+ new_figures: List[VideoFigure] = []
+ any_object_updated = False
+ for video_object in ann.objects:
+ this_object_figures = [
+ figure for figure in ann.figures if figure.video_object.key() == video_object.key()
+ ]
+ this_object_changed = False
+ original_obj_class_name = video_object.obj_class.name
+ suffix_found = False
+ for suffix in ["", *obj_classes_suffixes]:
+ obj_class = video_object.obj_class
+ obj_class_name = obj_class.name + suffix
+ if suffix:
+ obj_class = obj_class.clone(name=obj_class_name)
+ video_object = video_object.clone(obj_class=obj_class)
+ any_object_updated = True
+ this_object_changed = True
+ meta_obj_class = meta.get_obj_class(obj_class_name)
+ if meta_obj_class is None:
+ # obj class is not in meta, add it with suffix
+ meta = meta.add_obj_class(obj_class)
+ new_objects.append(video_object)
+ meta_changed = True
+ suffix_found = True
+ break
+ elif (
+ meta_obj_class.geometry_type.geometry_name()
+ == video_object.obj_class.geometry_type.geometry_name()
+ ):
+ # if object geometry is the same as in meta, use meta obj class
+ video_object = video_object.clone(obj_class=meta_obj_class)
+ new_objects.append(video_object)
+ suffix_found = True
+ any_object_updated = True
+ this_object_changed = True
+ break
+ elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
+ # if meta obj class is AnyGeometry, use it in object
+ video_object = video_object.clone(obj_class=meta_obj_class)
+ new_objects.append(video_object)
+ suffix_found = True
+ any_object_updated = True
+ this_object_changed = True
+ break
+ if not suffix_found:
+ # if no suffix found, raise error
+ raise ValueError(
+ f"Can't add obj class {original_obj_class_name} to project meta. "
+ "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
+ "Please check if model geometry type is compatible with existing obj classes."
+ )
+ elif this_object_changed:
+ this_object_figures = [
+ figure.clone(video_object=video_object) for figure in this_object_figures
+ ]
+ new_figures.extend(this_object_figures)
+ if any_object_updated:
+ frames_figures = {}
+ for figure in new_figures:
+ frames_figures.setdefault(figure.frame_index, []).append(figure)
+ new_frames = FrameCollection(
+ [
+ Frame(index=frame_index, figures=figures)
+ for frame_index, figures in frames_figures.items()
+ ]
+ )
+ ann = ann.clone(objects=new_objects, frames=new_frames)
+
+ # check if tag metas are in project meta
+ # if not, add them with suffix
+ ann_tag_metas: Dict[str, TagMeta] = {}
+ for video_object in ann.objects:
+ for tag in video_object.tags:
+ tag_name = tag.meta.name
+ if tag_name not in ann_tag_metas:
+ ann_tag_metas[tag_name] = tag.meta
+ for tag in ann.tags:
+ tag_name = tag.meta.name
+ if tag_name not in ann_tag_metas:
+ ann_tag_metas[tag_name] = tag.meta
+
+ changed_tag_metas = {}
+ for ann_tag_meta in ann_tag_metas.values():
+ meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
+ if meta_tag_meta is None:
+ meta = meta.add_tag_meta(ann_tag_meta)
+ meta_changed = True
+ elif not meta_tag_meta.is_compatible(ann_tag_meta):
+ suffix_found = False
+ for suffix in tag_meta_suffixes:
+ new_tag_meta_name = ann_tag_meta.name + suffix
+ meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
+ if meta_tag_meta is None:
+ new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
+ meta = meta.add_tag_meta(new_tag_meta)
+ changed_tag_metas[ann_tag_meta.name] = new_tag_meta
+ meta_changed = True
+ suffix_found = True
+ break
+ if meta_tag_meta.is_compatible(ann_tag_meta):
+ changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
+ suffix_found = True
+ break
+ if not suffix_found:
+ raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
+
+ if changed_tag_metas:
+ objects = []
+ any_object_updated = False
+ for video_object in ann.objects:
+ any_tag_updated = False
+ object_tags = []
+ for tag in video_object.tags:
+ if tag.meta.name in changed_tag_metas:
+ object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+ any_tag_updated = True
+ else:
+ object_tags.append(tag)
+ if any_tag_updated:
+ video_object = video_object.clone(tags=TagCollection(object_tags))
+ any_object_updated = True
+ objects.append(video_object)
+
+ video_tags = []
+ any_tag_updated = False
+ for tag in ann.tags:
+ if tag.meta.name in changed_tag_metas:
+ video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+ any_tag_updated = True
+ else:
+ video_tags.append(tag)
+ if any_tag_updated or any_object_updated:
+ if any_tag_updated:
+ video_tags = VideoTagCollection(video_tags)
+ else:
+ video_tags = None
+ if any_object_updated:
+ objects = VideoObjectCollection(objects)
+ else:
+ objects = None
+ ann = ann.clone(tags=video_tags, objects=objects)
+
  return meta, ann, meta_changed

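A hedged usage sketch for the new `update_meta_and_ann_for_video_annotation` helper. The import path follows the module this diff edits; the API setup and ids around it are illustrative assumptions:

```python
import supervisely as sly
from supervisely.nn.inference.inference import update_meta_and_ann_for_video_annotation

api = sly.Api.from_env()
video_info = api.video.get_info_by_id(123)  # hypothetical video id
meta = sly.ProjectMeta.from_json(api.project.get_meta(video_info.project_id))

# An empty VideoAnnotation stands in for real model output here.
video_ann = sly.VideoAnnotation(
    img_size=(video_info.frame_height, video_info.frame_width),
    frames_count=video_info.frames_count,
)

meta, video_ann, meta_changed = update_meta_and_ann_for_video_annotation(meta, video_ann)
if meta_changed:
    meta = api.project.update_meta(video_info.project_id, meta)
```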
@@ -4815,7 +5367,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
  return data[key]
  return None

- def torch_load_safe(checkpoint_path: str, device:str = "cpu"):
+
+ def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
  import torch # pylint: disable=import-error

  # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0