supervisely 6.73.452__py3-none-any.whl → 6.73.513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. supervisely/__init__.py +25 -1
  2. supervisely/annotation/annotation.py +8 -2
  3. supervisely/annotation/json_geometries_map.py +13 -12
  4. supervisely/api/annotation_api.py +6 -3
  5. supervisely/api/api.py +2 -0
  6. supervisely/api/app_api.py +10 -1
  7. supervisely/api/dataset_api.py +74 -12
  8. supervisely/api/entities_collection_api.py +10 -0
  9. supervisely/api/entity_annotation/figure_api.py +28 -0
  10. supervisely/api/entity_annotation/object_api.py +3 -3
  11. supervisely/api/entity_annotation/tag_api.py +63 -12
  12. supervisely/api/guides_api.py +210 -0
  13. supervisely/api/image_api.py +4 -0
  14. supervisely/api/labeling_job_api.py +83 -1
  15. supervisely/api/labeling_queue_api.py +33 -7
  16. supervisely/api/module_api.py +5 -0
  17. supervisely/api/project_api.py +71 -26
  18. supervisely/api/storage_api.py +3 -1
  19. supervisely/api/task_api.py +13 -2
  20. supervisely/api/team_api.py +4 -3
  21. supervisely/api/video/video_annotation_api.py +119 -3
  22. supervisely/api/video/video_api.py +65 -14
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +23 -7
  25. supervisely/app/development/development.py +18 -2
  26. supervisely/app/fastapi/__init__.py +1 -0
  27. supervisely/app/fastapi/custom_static_files.py +1 -1
  28. supervisely/app/fastapi/multi_user.py +105 -0
  29. supervisely/app/fastapi/subapp.py +88 -42
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +6 -0
  35. supervisely/app/widgets/activity_feed/__init__.py +0 -0
  36. supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
  37. supervisely/app/widgets/activity_feed/style.css +78 -0
  38. supervisely/app/widgets/activity_feed/template.html +22 -0
  39. supervisely/app/widgets/card/card.py +20 -0
  40. supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
  41. supervisely/app/widgets/classes_list_selector/template.html +60 -93
  42. supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
  43. supervisely/app/widgets/classes_table/classes_table.py +1 -0
  44. supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
  45. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
  46. supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
  47. supervisely/app/widgets/fast_table/fast_table.py +184 -60
  48. supervisely/app/widgets/fast_table/template.html +1 -1
  49. supervisely/app/widgets/heatmap/__init__.py +0 -0
  50. supervisely/app/widgets/heatmap/heatmap.py +564 -0
  51. supervisely/app/widgets/heatmap/script.js +533 -0
  52. supervisely/app/widgets/heatmap/style.css +233 -0
  53. supervisely/app/widgets/heatmap/template.html +21 -0
  54. supervisely/app/widgets/modal/__init__.py +0 -0
  55. supervisely/app/widgets/modal/modal.py +198 -0
  56. supervisely/app/widgets/modal/template.html +10 -0
  57. supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
  58. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  59. supervisely/app/widgets/radio_tabs/template.html +1 -0
  60. supervisely/app/widgets/select/select.py +6 -3
  61. supervisely/app/widgets/select_class/__init__.py +0 -0
  62. supervisely/app/widgets/select_class/select_class.py +363 -0
  63. supervisely/app/widgets/select_class/template.html +50 -0
  64. supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
  65. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
  66. supervisely/app/widgets/select_tag/__init__.py +0 -0
  67. supervisely/app/widgets/select_tag/select_tag.py +352 -0
  68. supervisely/app/widgets/select_tag/template.html +64 -0
  69. supervisely/app/widgets/select_team/select_team.py +37 -4
  70. supervisely/app/widgets/select_team/template.html +4 -5
  71. supervisely/app/widgets/select_user/__init__.py +0 -0
  72. supervisely/app/widgets/select_user/select_user.py +270 -0
  73. supervisely/app/widgets/select_user/template.html +13 -0
  74. supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
  75. supervisely/app/widgets/select_workspace/template.html +9 -12
  76. supervisely/app/widgets/table/table.py +68 -13
  77. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  78. supervisely/aug/aug.py +6 -2
  79. supervisely/convert/base_converter.py +1 -0
  80. supervisely/convert/converter.py +2 -2
  81. supervisely/convert/image/image_converter.py +3 -1
  82. supervisely/convert/image/image_helper.py +48 -4
  83. supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
  84. supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
  85. supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
  86. supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
  87. supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
  88. supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
  89. supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
  90. supervisely/convert/pointcloud/las/las_converter.py +13 -1
  91. supervisely/convert/pointcloud/las/las_helper.py +110 -11
  92. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
  93. supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
  94. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
  95. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
  96. supervisely/convert/video/__init__.py +1 -0
  97. supervisely/convert/video/multi_view/__init__.py +0 -0
  98. supervisely/convert/video/multi_view/multi_view.py +543 -0
  99. supervisely/convert/video/sly/sly_video_converter.py +359 -3
  100. supervisely/convert/video/video_converter.py +22 -2
  101. supervisely/convert/volume/dicom/dicom_converter.py +13 -5
  102. supervisely/convert/volume/dicom/dicom_helper.py +30 -18
  103. supervisely/geometry/constants.py +1 -0
  104. supervisely/geometry/geometry.py +4 -0
  105. supervisely/geometry/helpers.py +5 -1
  106. supervisely/geometry/oriented_bbox.py +676 -0
  107. supervisely/geometry/rectangle.py +2 -1
  108. supervisely/io/env.py +76 -1
  109. supervisely/io/fs.py +21 -0
  110. supervisely/nn/benchmark/base_evaluator.py +104 -11
  111. supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
  112. supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
  113. supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
  114. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
  115. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
  116. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
  117. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
  118. supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
  119. supervisely/nn/inference/cache.py +43 -18
  120. supervisely/nn/inference/gui/serving_gui_template.py +5 -2
  121. supervisely/nn/inference/inference.py +795 -199
  122. supervisely/nn/inference/inference_request.py +42 -9
  123. supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
  124. supervisely/nn/inference/predict_app/gui/gui.py +676 -488
  125. supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
  126. supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
  127. supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
  128. supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
  129. supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
  130. supervisely/nn/inference/predict_app/gui/utils.py +236 -119
  131. supervisely/nn/inference/predict_app/predict_app.py +2 -2
  132. supervisely/nn/inference/session.py +43 -35
  133. supervisely/nn/inference/tracking/bbox_tracking.py +113 -34
  134. supervisely/nn/inference/tracking/tracker_interface.py +7 -2
  135. supervisely/nn/inference/uploader.py +139 -12
  136. supervisely/nn/live_training/__init__.py +7 -0
  137. supervisely/nn/live_training/api_server.py +111 -0
  138. supervisely/nn/live_training/artifacts_utils.py +243 -0
  139. supervisely/nn/live_training/checkpoint_utils.py +229 -0
  140. supervisely/nn/live_training/dynamic_sampler.py +44 -0
  141. supervisely/nn/live_training/helpers.py +14 -0
  142. supervisely/nn/live_training/incremental_dataset.py +146 -0
  143. supervisely/nn/live_training/live_training.py +497 -0
  144. supervisely/nn/live_training/loss_plateau_detector.py +111 -0
  145. supervisely/nn/live_training/request_queue.py +52 -0
  146. supervisely/nn/model/model_api.py +9 -0
  147. supervisely/nn/prediction_dto.py +12 -1
  148. supervisely/nn/tracker/base_tracker.py +11 -1
  149. supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
  150. supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
  151. supervisely/nn/tracker/botsort_tracker.py +94 -65
  152. supervisely/nn/tracker/visualize.py +87 -90
  153. supervisely/nn/training/gui/classes_selector.py +16 -1
  154. supervisely/nn/training/train_app.py +28 -29
  155. supervisely/project/data_version.py +115 -51
  156. supervisely/project/download.py +1 -1
  157. supervisely/project/pointcloud_episode_project.py +37 -8
  158. supervisely/project/pointcloud_project.py +30 -2
  159. supervisely/project/project.py +14 -2
  160. supervisely/project/project_meta.py +27 -1
  161. supervisely/project/project_settings.py +32 -18
  162. supervisely/project/versioning/__init__.py +1 -0
  163. supervisely/project/versioning/common.py +20 -0
  164. supervisely/project/versioning/schema_fields.py +35 -0
  165. supervisely/project/versioning/video_schema.py +221 -0
  166. supervisely/project/versioning/volume_schema.py +87 -0
  167. supervisely/project/video_project.py +717 -15
  168. supervisely/project/volume_project.py +623 -5
  169. supervisely/template/experiment/experiment.html.jinja +4 -4
  170. supervisely/template/experiment/experiment_generator.py +14 -21
  171. supervisely/template/live_training/__init__.py +0 -0
  172. supervisely/template/live_training/header.html.jinja +96 -0
  173. supervisely/template/live_training/live_training.html.jinja +51 -0
  174. supervisely/template/live_training/live_training_generator.py +464 -0
  175. supervisely/template/live_training/sly-style.css +402 -0
  176. supervisely/template/live_training/template.html.jinja +18 -0
  177. supervisely/versions.json +28 -26
  178. supervisely/video/sampling.py +39 -20
  179. supervisely/video/video.py +40 -11
  180. supervisely/video_annotation/video_object.py +29 -4
  181. supervisely/volume/stl_converter.py +2 -0
  182. supervisely/worker_api/agent_rpc.py +24 -1
  183. supervisely/worker_api/rpc_servicer.py +31 -7
  184. {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/METADATA +56 -39
  185. {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/RECORD +189 -142
  186. {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
  187. {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
  188. {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
  189. {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
@@ -5,12 +5,14 @@ import asyncio
5
5
  import inspect
6
6
  import json
7
7
  import os
8
+ import queue
8
9
  import re
9
10
  import shutil
10
11
  import subprocess
11
12
  import tempfile
12
13
  import threading
13
14
  import time
15
+ import uuid
14
16
  from collections import OrderedDict, defaultdict
15
17
  from concurrent.futures import ThreadPoolExecutor
16
18
  from dataclasses import asdict, dataclass
@@ -52,6 +54,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
52
54
  from supervisely.api.api import Api, ApiField
53
55
  from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
54
56
  from supervisely.api.image_api import ImageInfo
57
+ from supervisely.api.video.video_api import VideoInfo
55
58
  from supervisely.app.content import get_data_dir
56
59
  from supervisely.app.fastapi.subapp import (
57
60
  Application,
@@ -67,6 +70,7 @@ from supervisely.decorators.inference import (
67
70
  process_images_batch_sliding_window,
68
71
  )
69
72
  from supervisely.geometry.any_geometry import AnyGeometry
73
+ from supervisely.geometry.geometry import Geometry
70
74
  from supervisely.imaging.color import get_predefined_colors
71
75
  from supervisely.io.fs import list_files
72
76
  from supervisely.nn.experiments import ExperimentInfo
@@ -75,7 +79,7 @@ from supervisely.nn.inference.inference_request import (
75
79
  InferenceRequest,
76
80
  InferenceRequestsManager,
77
81
  )
78
- from supervisely.nn.inference.uploader import Uploader
82
+ from supervisely.nn.inference.uploader import Downloader, Uploader
79
83
  from supervisely.nn.model.model_api import ModelAPI, Prediction
80
84
  from supervisely.nn.prediction_dto import Prediction as PredictionDTO
81
85
  from supervisely.nn.utils import (
@@ -94,6 +98,17 @@ from supervisely.project.project_meta import ProjectMeta
94
98
  from supervisely.sly_logger import logger
95
99
  from supervisely.task.progress import Progress
96
100
  from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
101
+ from supervisely.video_annotation.frame import Frame
102
+ from supervisely.video_annotation.frame_collection import FrameCollection
103
+ from supervisely.video_annotation.key_id_map import KeyIdMap
104
+ from supervisely.video_annotation.video_annotation import VideoAnnotation
105
+ from supervisely.video_annotation.video_figure import VideoFigure
106
+ from supervisely.video_annotation.video_object import VideoObject
107
+ from supervisely.video_annotation.video_object_collection import (
108
+ VideoObject,
109
+ VideoObjectCollection,
110
+ )
111
+ from supervisely.video_annotation.video_tag_collection import VideoTagCollection
97
112
 
98
113
  try:
99
114
  from typing import Literal
@@ -140,6 +155,7 @@ class Inference:
140
155
  """Default batch size for inference"""
141
156
  INFERENCE_SETTINGS: str = None
142
157
  """Path to file with custom inference settings"""
158
+ DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
143
159
 
144
160
  def __init__(
145
161
  self,
@@ -193,7 +209,6 @@ class Inference:
193
209
  self._task_id = None
194
210
  self._sliding_window_mode = sliding_window_mode
195
211
  self._autostart_delay_time = 5 * 60 # 5 min
196
- self._tracker = None
197
212
  self._hardware: str = None
198
213
  if custom_inference_settings is None:
199
214
  if self.INFERENCE_SETTINGS is not None:
@@ -427,7 +442,7 @@ class Inference:
427
442
 
428
443
  device = "cuda" if torch.cuda.is_available() else "cpu"
429
444
  except Exception as e:
430
- logger.warn(
445
+ logger.warning(
431
446
  f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
432
447
  )
433
448
  device = "cpu"
@@ -734,15 +749,15 @@ class Inference:
734
749
  for model in self.pretrained_models:
735
750
  model_meta = model.get("meta")
736
751
  if model_meta is not None:
737
- model_name = model_meta.get("model_name")
738
- if model_name is not None:
739
- if model_name.lower() == model_name.lower():
752
+ this_model_name = model_meta.get("model_name")
753
+ if this_model_name is not None:
754
+ if this_model_name.lower() == model_name.lower():
740
755
  selected_model = model
741
756
  break
742
757
  else:
743
- model_name = model.get("model_name")
744
- if model_name is not None:
745
- if model_name.lower() == model_name.lower():
758
+ this_model_name = model.get("model_name")
759
+ if this_model_name is not None:
760
+ if this_model_name.lower() == model_name.lower():
746
761
  selected_model = model
747
762
  break
748
763
 
@@ -1359,6 +1374,7 @@ class Inference:
1359
1374
 
1360
1375
  if tracker == "botsort":
1361
1376
  from supervisely.nn.tracker import BotSortTracker
1377
+
1362
1378
  device = tracker_settings.get("device", self.device)
1363
1379
  logger.debug(f"Initializing BotSort tracker with device: {device}")
1364
1380
  return BotSortTracker(settings=tracker_settings, device=device)
@@ -1375,15 +1391,15 @@ class Inference:
1375
1391
  if classes is not None:
1376
1392
  num_classes = len(classes)
1377
1393
  except NotImplementedError:
1378
- logger.warn(f"get_classes() function not implemented for {type(self)} object.")
1394
+ logger.warning(f"get_classes() function not implemented for {type(self)} object.")
1379
1395
  except AttributeError:
1380
- logger.warn("Probably, get_classes() function not working without model deploy.")
1396
+ logger.warning("Probably, get_classes() function not working without model deploy.")
1381
1397
  except Exception as exc:
1382
- logger.warn("Unknown exception. Please, contact support")
1398
+ logger.warning("Unknown exception. Please, contact support")
1383
1399
  logger.exception(exc)
1384
1400
 
1385
1401
  if num_classes is None:
1386
- logger.warn(f"get_classes() function return {classes}; skip classes processing.")
1402
+ logger.warning(f"get_classes() function return {classes}; skip classes processing.")
1387
1403
 
1388
1404
  return {
1389
1405
  "app_name": get_name_from_env(default="Neural Network Serving"),
@@ -1401,6 +1417,42 @@ class Inference:
1401
1417
 
1402
1418
  # pylint: enable=method-hidden
1403
1419
 
1420
+ def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
1421
+ """
1422
+ Get default parameters for all available tracking algorithms.
1423
+
1424
+ Returns:
1425
+ {"botsort": {"track_high_thresh": 0.6, ...}}
1426
+ Empty dict if tracking not supported.
1427
+ """
1428
+ info = self.get_info()
1429
+ trackers_params = {}
1430
+
1431
+ tracking_support = info.get("tracking_on_videos_support")
1432
+ if not tracking_support:
1433
+ return trackers_params
1434
+
1435
+ tracking_algorithms = info.get("tracking_algorithms", [])
1436
+
1437
+ for tracker_name in tracking_algorithms:
1438
+ try:
1439
+ if tracker_name == "botsort":
1440
+ from supervisely.nn.tracker import BotSortTracker
1441
+
1442
+ trackers_params[tracker_name] = BotSortTracker.get_default_params()
1443
+ # Add other trackers here as elif blocks
1444
+ else:
1445
+ logger.debug(f"Tracker '{tracker_name}' not implemented")
1446
+ except Exception as e:
1447
+ logger.warning(f"Failed to get params for '{tracker_name}': {e}")
1448
+
1449
+ INTERNAL_FIELDS = {"device", "fps"}
1450
+ for tracker_name, params in trackers_params.items():
1451
+ trackers_params[tracker_name] = {
1452
+ k: v for k, v in params.items() if k not in INTERNAL_FIELDS
1453
+ }
1454
+ return trackers_params
1455
+
1404
1456
  def get_human_readable_info(self, replace_none_with: Optional[str] = None):
1405
1457
  hr_info = {}
1406
1458
  info = self.get_info()
@@ -1952,7 +2004,7 @@ class Inference:
1952
2004
  else:
1953
2005
  n_frames = frames_reader.frames_count()
1954
2006
 
1955
- self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2007
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
1956
2008
 
1957
2009
  progress_total = (n_frames + step - 1) // step
1958
2010
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
@@ -1978,8 +2030,8 @@ class Inference:
1978
2030
  settings=inference_settings,
1979
2031
  )
1980
2032
 
1981
- if self._tracker is not None:
1982
- anns = self._apply_tracker_to_anns(frames, anns)
2033
+ if inference_request.tracker is not None:
2034
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
1983
2035
 
1984
2036
  predictions = [
1985
2037
  Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
@@ -1994,10 +2046,9 @@ class Inference:
1994
2046
  inference_request.done(len(batch_results))
1995
2047
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
1996
2048
  video_ann_json = None
1997
- if self._tracker is not None:
2049
+ if inference_request.tracker is not None:
1998
2050
  inference_request.set_stage("Postprocess...", 0, 1)
1999
-
2000
- video_ann_json = self._tracker.video_annotation.to_json()
2051
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
2001
2052
  inference_request.done()
2002
2053
  result = {"ann": results, "video_ann": video_ann_json}
2003
2054
  inference_request.final_result = result.copy()
@@ -2029,7 +2080,7 @@ class Inference:
2029
2080
  upload_mode = state.get("upload_mode", None)
2030
2081
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
2031
2082
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
2032
- iou_merge_threshold = 0.7
2083
+ iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD # TODO: change to 0.9
2033
2084
 
2034
2085
  images_infos = api.image.get_info_by_id_batch(image_ids)
2035
2086
  images_infos_dict = {im_info.id: im_info for im_info in images_infos}
@@ -2071,14 +2122,9 @@ class Inference:
2071
2122
  output_dataset_id
2072
2123
  ] = output_dataset_info
2073
2124
 
2074
- # start download to cache in background
2075
- dataset_image_infos: Dict[int, List[ImageInfo]] = defaultdict(list)
2076
- for image_info in images_infos:
2077
- dataset_image_infos[image_info.dataset_id].append(image_info)
2078
- for dataset_id, ds_image_infos in dataset_image_infos.items():
2079
- self.cache.run_cache_task_manually(
2080
- api, [info.id for info in ds_image_infos], dataset_id=dataset_id
2081
- )
2125
+ def download_f(item: int):
2126
+ self.cache.download_image(api, item)
2127
+ return item
2082
2128
 
2083
2129
  _upload_predictions = partial(
2084
2130
  self.upload_predictions,
@@ -2094,7 +2140,9 @@ class Inference:
2094
2140
  )
2095
2141
 
2096
2142
  _add_results_to_request = partial(
2097
- self.add_results_to_request, inference_request=inference_request
2143
+ self.add_results_to_request,
2144
+ inference_request=inference_request,
2145
+ progress_cb=inference_request.done,
2098
2146
  )
2099
2147
 
2100
2148
  if upload_mode is None:
@@ -2103,40 +2151,60 @@ class Inference:
2103
2151
  upload_f = _upload_predictions
2104
2152
 
2105
2153
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(image_ids))
2154
+ download_workers = max(8, min(batch_size, 64))
2106
2155
  with Uploader(upload_f, logger=logger) as uploader:
2107
- for image_ids_batch in batched(image_ids, batch_size=batch_size):
2108
- if uploader.has_exception():
2109
- exception = uploader.exception
2110
- raise exception
2111
- if inference_request.is_stopped():
2112
- logger.debug(
2113
- f"Cancelling inference project...",
2114
- extra={"inference_request_uuid": inference_request.uuid},
2115
- )
2116
- break
2117
-
2118
- images_nps = [self.cache.download_image(api, img_id) for img_id in image_ids_batch]
2119
- anns, slides_data = self._inference_auto(
2120
- source=images_nps,
2121
- settings=inference_settings,
2122
- )
2156
+ with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
2157
+ for image_id in image_ids:
2158
+ downloader.put(image_id)
2159
+ downloader.next(100)
2160
+ for image_ids_batch in batched(image_ids, batch_size=batch_size):
2161
+ if uploader.has_exception():
2162
+ exception = uploader.exception
2163
+ raise exception
2164
+ if inference_request.is_stopped():
2165
+ logger.debug(
2166
+ f"Cancelling inference...",
2167
+ extra={"inference_request_uuid": inference_request.uuid},
2168
+ )
2169
+ break
2170
+ if inference_request.is_paused():
2171
+ logger.info("Inference request is paused. Waiting...")
2172
+ while inference_request.is_paused():
2173
+ if (
2174
+ inference_request.paused_for()
2175
+ > inference_request.PAUSE_SLEEP_MAX_WAIT
2176
+ ):
2177
+ logger.info(
2178
+ "Inference request has been paused for too long. Cancelling..."
2179
+ )
2180
+ raise RuntimeError("Inference request cancelled due to long pause.")
2181
+ time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
2123
2182
 
2124
- batch_predictions = []
2125
- for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
2126
- image_info: ImageInfo = images_infos_dict[image_id]
2127
- dataset_info = dataset_infos_dict[image_info.dataset_id]
2128
- prediction = Prediction(
2129
- ann,
2130
- model_meta=self.model_meta,
2131
- name=image_info.name,
2132
- image_id=image_info.id,
2133
- dataset_id=image_info.dataset_id,
2134
- project_id=dataset_info.project_id,
2183
+ images_nps = [
2184
+ self.cache.download_image(api, img_id) for img_id in image_ids_batch
2185
+ ]
2186
+ downloader.next(len(image_ids_batch))
2187
+ anns, slides_data = self._inference_auto(
2188
+ source=images_nps,
2189
+ settings=inference_settings,
2135
2190
  )
2136
- prediction.extra_data["slides_data"] = this_slides_data
2137
- batch_predictions.append(prediction)
2138
2191
 
2139
- uploader.put(batch_predictions)
2192
+ batch_predictions = []
2193
+ for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
2194
+ image_info: ImageInfo = images_infos_dict[image_id]
2195
+ dataset_info = dataset_infos_dict[image_info.dataset_id]
2196
+ prediction = Prediction(
2197
+ ann,
2198
+ model_meta=self.model_meta,
2199
+ name=image_info.name,
2200
+ image_id=image_info.id,
2201
+ dataset_id=image_info.dataset_id,
2202
+ project_id=dataset_info.project_id,
2203
+ )
2204
+ prediction.extra_data["slides_data"] = this_slides_data
2205
+ batch_predictions.append(prediction)
2206
+
2207
+ uploader.put(batch_predictions)
2140
2208
 
2141
2209
  def _inference_video_id(
2142
2210
  self,
@@ -2181,7 +2249,7 @@ class Inference:
2181
2249
  else:
2182
2250
  n_frames = video_info.frames_count
2183
2251
 
2184
- self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2252
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2185
2253
 
2186
2254
  logger.debug(
2187
2255
  f"Video info:",
@@ -2218,8 +2286,8 @@ class Inference:
2218
2286
  settings=inference_settings,
2219
2287
  )
2220
2288
 
2221
- if self._tracker is not None:
2222
- anns = self._apply_tracker_to_anns(frames, anns)
2289
+ if inference_request.tracker is not None:
2290
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
2223
2291
 
2224
2292
  predictions = [
2225
2293
  Prediction(
@@ -2228,8 +2296,8 @@ class Inference:
2228
2296
  frame_index=frame_index,
2229
2297
  video_id=video_info.id,
2230
2298
  dataset_id=video_info.dataset_id,
2231
- project_id=video_info.project_id,
2232
- )
2299
+ project_id=video_info.project_id,
2300
+ )
2233
2301
  for ann, frame_index in zip(anns, batch)
2234
2302
  ]
2235
2303
  for pred, this_slides_data in zip(predictions, slides_data):
@@ -2240,9 +2308,169 @@ class Inference:
2240
2308
  inference_request.done(len(batch_results))
2241
2309
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
2242
2310
  video_ann_json = None
2243
- if self._tracker is not None:
2311
+ if inference_request.tracker is not None:
2312
+ inference_request.set_stage("Postprocess...", 0, progress_total)
2313
+
2314
+ video_ann_json = inference_request.tracker.create_video_annotation(
2315
+ video_info.frames_count,
2316
+ start_frame_index,
2317
+ step=step,
2318
+ progress_cb=inference_request.done,
2319
+ ).to_json()
2320
+ inference_request.final_result = {"video_ann": video_ann_json}
2321
+ return video_ann_json
2322
+
2323
+ def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
2324
+ logger.debug("Inferring video_id...", extra={"state": state})
2325
+ inference_settings = self._get_inference_settings(state)
2326
+ logger.debug(f"Inference settings:", extra=inference_settings)
2327
+ batch_size = self._get_batch_size_from_state(state)
2328
+ video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
2329
+ if video_id is None:
2330
+ raise ValueError("Video id is not provided")
2331
+ video_info = api.video.get_info_by_id(video_id)
2332
+ start_frame_index = get_value_for_keys(
2333
+ state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
2334
+ )
2335
+ if start_frame_index is None:
2336
+ start_frame_index = 0
2337
+ step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
2338
+ if step is None:
2339
+ step = 1
2340
+ end_frame_index = get_value_for_keys(
2341
+ state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
2342
+ )
2343
+ duration = state.get("duration", None)
2344
+ frames_count = get_value_for_keys(
2345
+ state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
2346
+ )
2347
+ tracking = state.get("tracker", None)
2348
+ direction = state.get("direction", "forward")
2349
+ direction = 1 if direction == "forward" else -1
2350
+ track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
2351
+
2352
+ if frames_count is not None:
2353
+ n_frames = frames_count
2354
+ elif end_frame_index is not None:
2355
+ n_frames = end_frame_index - start_frame_index
2356
+ elif duration is not None:
2357
+ fps = video_info.frames_count / video_info.duration
2358
+ n_frames = int(duration * fps)
2359
+ else:
2360
+ n_frames = video_info.frames_count
2361
+
2362
+ inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2363
+
2364
+ logger.debug(
2365
+ f"Video info:",
2366
+ extra=dict(
2367
+ w=video_info.frame_width,
2368
+ h=video_info.frame_height,
2369
+ start_frame_index=start_frame_index,
2370
+ n_frames=n_frames,
2371
+ ),
2372
+ )
2373
+
2374
+ # start downloading video in background
2375
+ self.cache.run_cache_task_manually(api, None, video_id=video_id)
2376
+
2377
+ progress_total = (n_frames + step - 1) // step
2378
+ inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
2379
+
2380
+ _upload_f = partial(
2381
+ self.upload_predictions_to_video,
2382
+ api=api,
2383
+ video_info=video_info,
2384
+ track_id=track_id,
2385
+ context=inference_request.context,
2386
+ progress_cb=inference_request.done,
2387
+ inference_request=inference_request,
2388
+ )
2389
+
2390
+ _range = (start_frame_index, start_frame_index + direction * n_frames)
2391
+ if _range[0] > _range[1]:
2392
+ _range = (_range[1], _range[0])
2393
+
2394
+ def _notify_f(predictions: List[Prediction]):
2395
+ logger.debug(
2396
+ "Notifying tracking progress...",
2397
+ extra={
2398
+ "track_id": track_id,
2399
+ "range": _range,
2400
+ "current": inference_request.progress.current,
2401
+ "total": inference_request.progress.total,
2402
+ },
2403
+ )
2404
+ stopped = self.api.video.notify_progress(
2405
+ track_id=track_id,
2406
+ video_id=video_info.id,
2407
+ frame_start=_range[0],
2408
+ frame_end=_range[1],
2409
+ current=inference_request.progress.current,
2410
+ total=inference_request.progress.total,
2411
+ )
2412
+ if stopped:
2413
+ inference_request.stop()
2414
+ logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
2415
+
2416
+ def _exception_handler(e: Exception):
2417
+ self.api.video.notify_tracking_error(
2418
+ track_id=track_id,
2419
+ error=str(type(e)),
2420
+ message=str(e),
2421
+ )
2422
+ raise e
2423
+
2424
+ with Uploader(
2425
+ upload_f=_upload_f,
2426
+ notify_f=_notify_f,
2427
+ exception_handler=_exception_handler,
2428
+ logger=logger,
2429
+ ) as uploader:
2430
+ for batch in batched(
2431
+ range(
2432
+ start_frame_index, start_frame_index + direction * n_frames, direction * step
2433
+ ),
2434
+ batch_size,
2435
+ ):
2436
+ if inference_request.is_stopped():
2437
+ logger.debug(
2438
+ f"Cancelling inference video...",
2439
+ extra={"inference_request_uuid": inference_request.uuid},
2440
+ )
2441
+ break
2442
+ logger.debug(
2443
+ f"Inferring frames {batch[0]}-{batch[-1]}:",
2444
+ )
2445
+ frames = self.cache.download_frames(
2446
+ api, video_info.id, batch, redownload_video=True
2447
+ )
2448
+ anns, slides_data = self._inference_auto(
2449
+ source=frames,
2450
+ settings=inference_settings,
2451
+ )
2452
+
2453
+ if inference_request.tracker is not None:
2454
+ anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
2455
+
2456
+ predictions = [
2457
+ Prediction(
2458
+ ann,
2459
+ model_meta=self.model_meta,
2460
+ frame_index=frame_index,
2461
+ video_id=video_info.id,
2462
+ dataset_id=video_info.dataset_id,
2463
+ project_id=video_info.project_id,
2464
+ )
2465
+ for ann, frame_index in zip(anns, batch)
2466
+ ]
2467
+ for pred, this_slides_data in zip(predictions, slides_data):
2468
+ pred.extra_data["slides_data"] = this_slides_data
2469
+ uploader.put(predictions)
2470
+ video_ann_json = None
2471
+ if inference_request.tracker is not None:
2244
2472
  inference_request.set_stage("Postprocess...", 0, 1)
2245
- video_ann_json = self._tracker.video_annotation.to_json()
2473
+ video_ann_json = inference_request.tracker.video_annotation.to_json()
2246
2474
  inference_request.done()
2247
2475
  inference_request.final_result = {"video_ann": video_ann_json}
2248
2476
  return video_ann_json
@@ -2268,10 +2496,9 @@ class Inference:
2268
2496
  upload_mode = state.get("upload_mode", None)
2269
2497
  iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
2270
2498
  if upload_mode == "iou_merge" and iou_merge_threshold is None:
2271
- iou_merge_threshold = 0.7
2499
+ iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
2272
2500
  cache_project_on_model = state.get("cache_project_on_model", False)
2273
2501
 
2274
- project_info = api.project.get_info_by_id(project_id)
2275
2502
  inference_request.context.setdefault("project_info", {})[project_id] = project_info
2276
2503
  dataset_ids = state.get("dataset_ids", None)
2277
2504
  if dataset_ids is None:
@@ -2306,7 +2533,11 @@ class Inference:
2306
2533
 
2307
2534
  if cache_project_on_model:
2308
2535
  download_to_cache(
2309
- api, project_info.id, datasets_infos, progress_cb=inference_request.done
2536
+ api,
2537
+ project_info.id,
2538
+ datasets_infos,
2539
+ progress_cb=inference_request.done,
2540
+ skip_create_readme=True,
2310
2541
  )
2311
2542
 
2312
2543
  images_infos_dict = {}
@@ -2315,20 +2546,9 @@ class Inference:
2315
2546
  if not cache_project_on_model:
2316
2547
  inference_request.done(dataset_info.items_count)
2317
2548
 
2318
- def _download_images(datasets_infos: List[DatasetInfo]):
2319
- for dataset_info in datasets_infos:
2320
- image_ids = [image_info.id for image_info in images_infos_dict[dataset_info.id]]
2321
- with ThreadPoolExecutor(max(8, min(batch_size, 64))) as executor:
2322
- for image_id in image_ids:
2323
- executor.submit(
2324
- self.cache.download_image,
2325
- api,
2326
- image_id,
2327
- )
2328
-
2329
- if not cache_project_on_model:
2330
- # start downloading in parallel
2331
- threading.Thread(target=_download_images, args=[datasets_infos], daemon=True).start()
2549
+ def download_f(item: int):
2550
+ self.cache.download_image(api, item)
2551
+ return item
2332
2552
 
2333
2553
  _upload_predictions = partial(
2334
2554
  self.upload_predictions,
@@ -2343,7 +2563,9 @@ class Inference:
2343
2563
  )
2344
2564
 
2345
2565
  _add_results_to_request = partial(
2346
- self.add_results_to_request, inference_request=inference_request
2566
+ self.add_results_to_request,
2567
+ inference_request=inference_request,
2568
+ progress_cb=inference_request.done,
2347
2569
  )
2348
2570
 
2349
2571
  if upload_mode is None:
@@ -2351,57 +2573,78 @@ class Inference:
2351
2573
  else:
2352
2574
  upload_f = _upload_predictions
2353
2575
 
2576
+ download_workers = max(8, min(batch_size, 64))
2354
2577
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, inference_progress_total)
2355
2578
  with Uploader(upload_f, logger=logger) as uploader:
2356
- for dataset_info in datasets_infos:
2357
- for images_infos_batch in batched(
2358
- images_infos_dict[dataset_info.id], batch_size=batch_size
2359
- ):
2360
- if inference_request.is_stopped():
2361
- logger.debug(
2362
- f"Cancelling inference project...",
2363
- extra={"inference_request_uuid": inference_request.uuid},
2364
- )
2365
- return
2366
- if uploader.has_exception():
2367
- exception = uploader.exception
2368
- raise exception
2369
- if cache_project_on_model:
2370
- images_paths, _ = zip(
2371
- *read_from_cached_project(
2372
- project_info.id,
2373
- dataset_info.name,
2374
- [ii.name for ii in images_infos_batch],
2579
+ with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
2580
+ for images in images_infos_dict.values():
2581
+ for image in images:
2582
+ downloader.put(image.id)
2583
+ downloader.next(100)
2584
+ for dataset_info in datasets_infos:
2585
+ for images_infos_batch in batched(
2586
+ images_infos_dict[dataset_info.id], batch_size=batch_size
2587
+ ):
2588
+ if uploader.has_exception():
2589
+ exception = uploader.exception
2590
+ raise exception
2591
+ if inference_request.is_stopped():
2592
+ logger.debug(
2593
+ f"Cancelling inference project...",
2594
+ extra={"inference_request_uuid": inference_request.uuid},
2375
2595
  )
2596
+ return
2597
+ if inference_request.is_paused():
2598
+ logger.info("Inference request is paused. Waiting...")
2599
+ while inference_request.is_paused():
2600
+ if (
2601
+ inference_request.paused_for()
2602
+ > inference_request.PAUSE_SLEEP_MAX_WAIT
2603
+ ):
2604
+ logger.info(
2605
+ "Inference request has been paused for too long. Cancelling..."
2606
+ )
2607
+ raise RuntimeError(
2608
+ "Inference request cancelled due to long pause."
2609
+ )
2610
+ time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
2611
+ if cache_project_on_model:
2612
+ images_paths, _ = zip(
2613
+ *read_from_cached_project(
2614
+ project_info.id,
2615
+ dataset_info.name,
2616
+ [ii.name for ii in images_infos_batch],
2617
+ )
2618
+ )
2619
+ images_nps = [sly_image.read(img_path) for img_path in images_paths]
2620
+ else:
2621
+ images_nps = self.cache.download_images(
2622
+ api,
2623
+ dataset_info.id,
2624
+ [info.id for info in images_infos_batch],
2625
+ return_images=True,
2626
+ )
2627
+ downloader.next(len(images_infos_batch))
2628
+ anns, slides_data = self._inference_auto(
2629
+ source=images_nps,
2630
+ settings=inference_settings,
2376
2631
  )
2377
- images_nps = [sly_image.read(img_path) for img_path in images_paths]
2378
- else:
2379
- images_nps = self.cache.download_images(
2380
- api,
2381
- dataset_info.id,
2382
- [info.id for info in images_infos_batch],
2383
- return_images=True,
2384
- )
2385
- anns, slides_data = self._inference_auto(
2386
- source=images_nps,
2387
- settings=inference_settings,
2388
- )
2389
- predictions = [
2390
- Prediction(
2391
- ann,
2392
- model_meta=self.model_meta,
2393
- image_id=image_info.id,
2394
- name=image_info.name,
2395
- dataset_id=dataset_info.id,
2396
- project_id=dataset_info.project_id,
2397
- image_name=image_info.name,
2398
- )
2399
- for ann, image_info in zip(anns, images_infos_batch)
2400
- ]
2401
- for pred, this_slides_data in zip(predictions, slides_data):
2402
- pred.extra_data["slides_data"] = this_slides_data
2632
+ predictions = [
2633
+ Prediction(
2634
+ ann,
2635
+ model_meta=self.model_meta,
2636
+ image_id=image_info.id,
2637
+ name=image_info.name,
2638
+ dataset_id=dataset_info.id,
2639
+ project_id=dataset_info.project_id,
2640
+ image_name=image_info.name,
2641
+ )
2642
+ for ann, image_info in zip(anns, images_infos_batch)
2643
+ ]
2644
+ for pred, this_slides_data in zip(predictions, slides_data):
2645
+ pred.extra_data["slides_data"] = this_slides_data
2403
2646
 
2404
- uploader.put(predictions)
2647
+ uploader.put(predictions)
2405
2648
 
2406
2649
  def _run_speedtest(
2407
2650
  self,
@@ -2444,7 +2687,13 @@ class Inference:
2444
2687
  inference_request.done()
2445
2688
 
2446
2689
  if cache_project_on_model:
2447
- download_to_cache(api, project_id, datasets_infos, progress_cb=inference_request.done)
2690
+ download_to_cache(
2691
+ api,
2692
+ project_id,
2693
+ datasets_infos,
2694
+ progress_cb=inference_request.done,
2695
+ skip_create_readme=True,
2696
+ )
2448
2697
 
2449
2698
  inference_request.set_stage("warmup", 0, num_warmup)
2450
2699
 
@@ -2565,6 +2814,11 @@ class Inference:
2565
2814
  def _freeze_model(self):
2566
2815
  if self._model_frozen or not self._model_served:
2567
2816
  return
2817
+
2818
+ if not self._deploy_params:
2819
+ logger.warning("Deploy params are not set, cannot freeze the model.")
2820
+ return
2821
+
2568
2822
  logger.debug("Freezing model...")
2569
2823
  runtime = self._deploy_params.get("runtime")
2570
2824
  if runtime and runtime.lower() != RuntimeType.PYTORCH.lower():
@@ -2907,11 +3161,89 @@ class Inference:
2907
3161
  inference_request.add_results(results)
2908
3162
 
2909
3163
  def add_results_to_request(
2910
- self, predictions: List[Prediction], inference_request: InferenceRequest
3164
+ self, predictions: List[Prediction], inference_request: InferenceRequest, progress_cb=None
2911
3165
  ):
2912
3166
  results = self._format_output(predictions)
2913
3167
  inference_request.add_results(results)
2914
- inference_request.done(len(results))
3168
+ if progress_cb:
3169
+ progress_cb(len(results))
3170
+
3171
+ def upload_predictions_to_video(
3172
+ self,
3173
+ predictions: List[Prediction],
3174
+ api: Api,
3175
+ video_info: VideoInfo,
3176
+ track_id: str,
3177
+ context: Dict,
3178
+ progress_cb=None,
3179
+ inference_request: InferenceRequest = None,
3180
+ ):
3181
+ key_id_map = KeyIdMap()
3182
+ project_meta = context.get("project_meta", None)
3183
+ if project_meta is None:
3184
+ project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
3185
+ context["project_meta"] = project_meta
3186
+ meta_changed = False
3187
+ for prediction in predictions:
3188
+ project_meta, ann, meta_changed_ = update_meta_and_ann(
3189
+ project_meta, prediction.annotation, None
3190
+ )
3191
+ prediction.annotation = ann
3192
+ meta_changed = meta_changed or meta_changed_
3193
+ if meta_changed:
3194
+ project_meta = api.project.update_meta(video_info.project_id, project_meta)
3195
+ context["project_meta"] = project_meta
3196
+
3197
+ figure_data_by_object_id = defaultdict(list)
3198
+
3199
+ tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
3200
+ new_tracks: Dict[int, VideoObject] = {}
3201
+ for prediction in predictions:
3202
+ annotation = prediction.annotation
3203
+ tracks = annotation.custom_data
3204
+ for track, label in zip(tracks, annotation.labels):
3205
+ if track not in tracks_to_object_ids and track not in new_tracks:
3206
+ video_object = VideoObject(obj_class=label.obj_class)
3207
+ new_tracks[track] = video_object
3208
+ if new_tracks:
3209
+ tracks, video_objects = zip(*new_tracks.items())
3210
+ added_object_ids = api.video.object.append_bulk(
3211
+ video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
3212
+ )
3213
+ for track, object_id in zip(tracks, added_object_ids):
3214
+ tracks_to_object_ids[track] = object_id
3215
+ for prediction in predictions:
3216
+ annotation = prediction.annotation
3217
+ tracks = annotation.custom_data
3218
+ for track, label in zip(tracks, annotation.labels):
3219
+ object_id = tracks_to_object_ids[track]
3220
+ figure_data_by_object_id[object_id].append(
3221
+ {
3222
+ ApiField.OBJECT_ID: object_id,
3223
+ ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
3224
+ ApiField.GEOMETRY: label.geometry.to_json(),
3225
+ ApiField.META: {ApiField.FRAME: prediction.frame_index},
3226
+ ApiField.TRACK_ID: track_id,
3227
+ }
3228
+ )
3229
+
3230
+ for object_id, figures_data in figure_data_by_object_id.items():
3231
+ figures_keys = [uuid.uuid4() for _ in figures_data]
3232
+ api.video.figure._append_bulk(
3233
+ entity_id=video_info.id,
3234
+ figures_json=figures_data,
3235
+ figures_keys=figures_keys,
3236
+ key_id_map=key_id_map,
3237
+ )
3238
+ logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
3239
+ if progress_cb:
3240
+ progress_cb(len(predictions))
3241
+ if inference_request is not None:
3242
+ results = self._format_output(predictions)
3243
+ for result in results:
3244
+ result["annotation"] = None
3245
+ result["data"] = None
3246
+ inference_request.add_results(results)
2915
3247
 
2916
3248
  def serve(self):
2917
3249
  if not self._use_gui and not self._is_cli_deploy:
@@ -2995,7 +3327,7 @@ class Inference:
2995
3327
 
2996
3328
  if not self._use_gui:
2997
3329
  Progress("Model deployed", 1).iter_done_report()
2998
- else:
3330
+ elif self.api is not None:
2999
3331
  autostart_func()
3000
3332
 
3001
3333
  @server.exception_handler(HTTPException)
@@ -3022,6 +3354,11 @@ class Inference:
3022
3354
  def get_session_info(response: Response):
3023
3355
  return self.get_info()
3024
3356
 
3357
+ @server.post("/get_tracking_settings")
3358
+ @self._check_serve_before_call
3359
+ def get_tracking_settings(response: Response):
3360
+ return self.get_tracking_settings()
3361
+
3025
3362
  @server.post("/get_custom_inference_settings")
3026
3363
  def get_custom_inference_settings():
3027
3364
  return {"settings": self.custom_inference_settings}
@@ -3305,6 +3642,22 @@ class Inference:
3305
3642
  "inference_request_uuid": inference_request.uuid,
3306
3643
  }
3307
3644
 
3645
+ @server.post("/tracking_by_detection")
3646
+ def tracking_by_detection(response: Response, request: Request):
3647
+ state = request.state.state
3648
+ context = request.state.context
3649
+ state.update(context)
3650
+ if state.get("tracker") is None:
3651
+ state["tracker"] = "botsort"
3652
+
3653
+ logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
3654
+ self.validate_inference_state(state)
3655
+ api = self.api_from_request(request)
3656
+ inference_request, future = self.inference_requests_manager.schedule_task(
3657
+ self._tracking_by_detection, api, state
3658
+ )
3659
+ return {"message": "Track task started."}
3660
+
3308
3661
  @server.post("/inference_project_id_async")
3309
3662
  def inference_project_id_async(response: Response, request: Request):
3310
3663
  state = request.state.state
@@ -3368,10 +3721,7 @@ class Inference:
3368
3721
  data = {**inference_request.to_json(), **log_extra}
3369
3722
  if inference_request.stage != InferenceRequest.Stage.INFERENCE:
3370
3723
  data["progress"] = {"current": 0, "total": 1}
3371
- logger.debug(
3372
- f"Sending inference progress with uuid:",
3373
- extra=data,
3374
- )
3724
+ logger.debug(f"Sending inference progress with uuid:", extra=data)
3375
3725
  return data
3376
3726
 
3377
3727
  @server.post(f"/pop_inference_results")
@@ -4228,10 +4578,10 @@ class Inference:
4228
4578
  self._args.draw,
4229
4579
  )
4230
4580
 
4231
- def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
4581
+ def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
4232
4582
  updated_anns = []
4233
4583
  for frame, ann in zip(frames, anns):
4234
- matches = self._tracker.update(frame, ann)
4584
+ matches = tracker.update(frame, ann)
4235
4585
  track_ids = [match["track_id"] for match in matches]
4236
4586
  tracked_labels = [match["label"] for match in matches]
4237
4587
 
@@ -4297,61 +4647,72 @@ class Inference:
4297
4647
  def export_tensorrt(self, deploy_params: dict):
4298
4648
  raise NotImplementedError("Have to be implemented in child class after inheritance")
4299
4649
 
4300
- def _exclude_duplicated_predictions(
4301
- api: Api,
4302
- pred_anns: List[Annotation],
4303
- dataset_id: int,
4304
- gt_image_ids: List[int],
4305
- iou: float = None,
4306
- meta: Optional[ProjectMeta] = None,
4650
+
4651
+ def _filter_duplicated_predictions_from_ann_cpu(
4652
+ gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
4307
4653
  ):
4308
4654
  """
4309
- Filter out predictions that significantly overlap with ground truth (GT) objects.
4655
+ Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
4656
+ Uses Shapely for geometric operations.
4310
4657
 
4311
- This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
4312
- - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
4313
- - Gets ProjectMeta object if not provided
4314
- - Downloads GT annotations for the specified image IDs
4315
- - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
4658
+ Args:
4659
+ pred_ann: Predicted annotation object
4660
+ gt_ann: Ground truth annotation object
4661
+ iou_threshold: IoU threshold for filtering
4316
4662
 
4317
- :param api: Supervisely API object
4318
- :type api: Api
4319
- :param pred_anns: List of Annotation objects containing predictions
4320
- :type pred_anns: List[Annotation]
4321
- :param dataset_id: ID of the dataset containing the images
4322
- :type dataset_id: int
4323
- :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
4324
- :type gt_image_ids: List[int]
4325
- :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
4326
- ground truth box of the same class will be removed. None if no filtering is needed
4327
- :type iou: Optional[float]
4328
- :param meta: ProjectMeta object
4329
- :type meta: Optional[ProjectMeta]
4330
- :return: List of Annotation objects containing filtered predictions
4331
- :rtype: List[Annotation]
4332
-
4333
- Notes:
4334
- ------
4335
- - Requires PyTorch and torchvision for IoU calculations
4336
- - This method is useful for identifying new objects that aren't already annotated in the ground truth
4663
+ Returns:
4664
+ New annotation with filtered labels
4337
4665
  """
4338
- if isinstance(iou, float) and 0 < iou <= 1:
4339
- if meta is None:
4340
- ds = api.dataset.get_info_by_id(dataset_id)
4341
- meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
4342
- gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
4343
- gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
4344
- for i in range(0, len(pred_anns)):
4345
- before = len(pred_anns[i].labels)
4346
- with Timer() as timer:
4347
- pred_anns[i] = _filter_duplicated_predictions_from_ann(
4348
- gt_anns[i], pred_anns[i], iou
4349
- )
4350
- after = len(pred_anns[i].labels)
4351
- logger.debug(
4352
- f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
4353
- )
4354
- return pred_anns
4666
+ if not iou_threshold:
4667
+ return pred_ann
4668
+
4669
+ from shapely.geometry import box
4670
+
4671
+ def calculate_iou(geom1: Geometry, geom2: Geometry):
4672
+ """Calculate IoU between two geometries using Shapely."""
4673
+ bbox1 = geom1.to_bbox()
4674
+ bbox2 = geom2.to_bbox()
4675
+
4676
+ box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
4677
+ box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
4678
+
4679
+ intersection = box1.intersection(box2).area
4680
+ union = box1.union(box2).area
4681
+
4682
+ return intersection / union if union > 0 else 0.0
4683
+
4684
+ new_labels = []
4685
+ pred_cls_bboxes = defaultdict(list)
4686
+ for label in pred_ann.labels:
4687
+ name_shape = (label.obj_class.name, label.geometry.name())
4688
+ pred_cls_bboxes[name_shape].append(label)
4689
+
4690
+ gt_cls_bboxes = defaultdict(list)
4691
+ for label in gt_ann.labels:
4692
+ name_shape = (label.obj_class.name, label.geometry.name())
4693
+ if name_shape not in pred_cls_bboxes:
4694
+ continue
4695
+ gt_cls_bboxes[name_shape].append(label)
4696
+
4697
+ for name_shape, pred in pred_cls_bboxes.items():
4698
+ gt = gt_cls_bboxes[name_shape]
4699
+ if len(gt) == 0:
4700
+ new_labels.extend(pred)
4701
+ continue
4702
+
4703
+ for pred_label in pred:
4704
+ # Check if this prediction has IoU < threshold with ALL GT boxes
4705
+ keep = True
4706
+ for gt_label in gt:
4707
+ iou = calculate_iou(pred_label.geometry, gt_label.geometry)
4708
+ if iou >= iou_threshold:
4709
+ keep = False
4710
+ break
4711
+
4712
+ if keep:
4713
+ new_labels.append(pred_label)
4714
+
4715
+ return pred_ann.clone(labels=new_labels)
4355
4716
 
4356
4717
 
4357
4718
  def _filter_duplicated_predictions_from_ann(
@@ -4382,13 +4743,15 @@ def _filter_duplicated_predictions_from_ann(
4382
4743
  - Predictions with classes not present in ground truth will be kept
4383
4744
  - Requires PyTorch and torchvision for IoU calculations
4384
4745
  """
4746
+ if not iou_threshold:
4747
+ return pred_ann
4385
4748
 
4386
4749
  try:
4387
4750
  import torch
4388
4751
  from torchvision.ops import box_iou
4389
4752
 
4390
4753
  except ImportError:
4391
- raise ImportError("Please install PyTorch and torchvision to use this feature.")
4754
+ return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)
4392
4755
 
4393
4756
  def _to_tensor(geom):
4394
4757
  return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
@@ -4396,16 +4759,18 @@ def _filter_duplicated_predictions_from_ann(
4396
4759
  new_labels = []
4397
4760
  pred_cls_bboxes = defaultdict(list)
4398
4761
  for label in pred_ann.labels:
4399
- pred_cls_bboxes[label.obj_class.name].append(label)
4762
+ name_shape = (label.obj_class.name, label.geometry.name())
4763
+ pred_cls_bboxes[name_shape].append(label)
4400
4764
 
4401
4765
  gt_cls_bboxes = defaultdict(list)
4402
4766
  for label in gt_ann.labels:
4403
- if label.obj_class.name not in pred_cls_bboxes:
4767
+ name_shape = (label.obj_class.name, label.geometry.name())
4768
+ if name_shape not in pred_cls_bboxes:
4404
4769
  continue
4405
- gt_cls_bboxes[label.obj_class.name].append(label)
4770
+ gt_cls_bboxes[name_shape].append(label)
4406
4771
 
4407
- for name, pred in pred_cls_bboxes.items():
4408
- gt = gt_cls_bboxes[name]
4772
+ for name_shape, pred in pred_cls_bboxes.items():
4773
+ gt = gt_cls_bboxes[name_shape]
4409
4774
  if len(gt) == 0:
4410
4775
  new_labels.extend(pred)
4411
4776
  continue
@@ -4419,6 +4784,63 @@ def _filter_duplicated_predictions_from_ann(
4419
4784
  return pred_ann.clone(labels=new_labels)
4420
4785
 
4421
4786
 
4787
+ def _exclude_duplicated_predictions(
4788
+ api: Api,
4789
+ pred_anns: List[Annotation],
4790
+ dataset_id: int,
4791
+ gt_image_ids: List[int],
4792
+ iou: float = None,
4793
+ meta: Optional[ProjectMeta] = None,
4794
+ ):
4795
+ """
4796
+ Filter out predictions that significantly overlap with ground truth (GT) objects.
4797
+
4798
+ This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
4799
+ - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
4800
+ - Gets ProjectMeta object if not provided
4801
+ - Downloads GT annotations for the specified image IDs
4802
+ - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
4803
+
4804
+ :param api: Supervisely API object
4805
+ :type api: Api
4806
+ :param pred_anns: List of Annotation objects containing predictions
4807
+ :type pred_anns: List[Annotation]
4808
+ :param dataset_id: ID of the dataset containing the images
4809
+ :type dataset_id: int
4810
+ :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
4811
+ :type gt_image_ids: List[int]
4812
+ :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
4813
+ ground truth box of the same class will be removed. None if no filtering is needed
4814
+ :type iou: Optional[float]
4815
+ :param meta: ProjectMeta object
4816
+ :type meta: Optional[ProjectMeta]
4817
+ :return: List of Annotation objects containing filtered predictions
4818
+ :rtype: List[Annotation]
4819
+
4820
+ Notes:
4821
+ ------
4822
+ - Requires PyTorch and torchvision for IoU calculations
4823
+ - This method is useful for identifying new objects that aren't already annotated in the ground truth
4824
+ """
4825
+ if isinstance(iou, float) and 0 < iou <= 1:
4826
+ if meta is None:
4827
+ ds = api.dataset.get_info_by_id(dataset_id)
4828
+ meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
4829
+ gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
4830
+ gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
4831
+ for i in range(0, len(pred_anns)):
4832
+ before = len(pred_anns[i].labels)
4833
+ with Timer() as timer:
4834
+ pred_anns[i] = _filter_duplicated_predictions_from_ann(
4835
+ gt_anns[i], pred_anns[i], iou
4836
+ )
4837
+ after = len(pred_anns[i].labels)
4838
+ logger.debug(
4839
+ f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
4840
+ )
4841
+ return pred_anns
4842
+
4843
+
4422
4844
  def _get_log_extra_for_inference_request(
4423
4845
  inference_request_uuid, inference_request: Union[InferenceRequest, dict]
4424
4846
  ):
@@ -4526,7 +4948,7 @@ def get_gpu_count():
4526
4948
  gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
4527
4949
  return gpu_count
4528
4950
  except (subprocess.CalledProcessError, FileNotFoundError) as exc:
4529
- logger.warn("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
4951
+ logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
4530
4952
  return 0
4531
4953
 
4532
4954
 
@@ -4706,7 +5128,180 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suf
4706
5128
  img_tags = None
4707
5129
  if not any_label_updated:
4708
5130
  labels = None
4709
- ann = ann.clone(img_tags=TagCollection(img_tags))
5131
+ ann = ann.clone(img_tags=img_tags)
5132
+ return meta, ann, meta_changed
5133
+
5134
+
5135
+ def update_meta_and_ann_for_video_annotation(
5136
+ meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
5137
+ ):
5138
+ """Update project meta and annotation to match each other
5139
+ If obj class or tag meta from annotation conflicts with project meta
5140
+ add suffix to obj class or tag meta.
5141
+ Return tuple of updated project meta, annotation and boolean flag if meta was changed.
5142
+ """
5143
+ obj_classes_suffixes = ["_nn"]
5144
+ tag_meta_suffixes = ["_nn"]
5145
+ if model_prediction_suffix is not None:
5146
+ obj_classes_suffixes = [model_prediction_suffix]
5147
+ tag_meta_suffixes = [model_prediction_suffix]
5148
+ logger.debug(
5149
+ f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
5150
+ )
5151
+ logger.debug("source meta", extra={"meta": meta.to_json()})
5152
+ meta_changed = False
5153
+
5154
+ # meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
5155
+ # if replaced_classes_in_meta:
5156
+ # meta_changed = True
5157
+ # logger.warning(
5158
+ # "Some classes names were fixed in project meta",
5159
+ # extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
5160
+ # )
5161
+
5162
+ new_objects: List[VideoObject] = []
5163
+ new_figures: List[VideoFigure] = []
5164
+ any_object_updated = False
5165
+ for video_object in ann.objects:
5166
+ this_object_figures = [
5167
+ figure for figure in ann.figures if figure.video_object.key() == video_object.key()
5168
+ ]
5169
+ this_object_changed = False
5170
+ original_obj_class_name = video_object.obj_class.name
5171
+ suffix_found = False
5172
+ for suffix in ["", *obj_classes_suffixes]:
5173
+ obj_class = video_object.obj_class
5174
+ obj_class_name = obj_class.name + suffix
5175
+ if suffix:
5176
+ obj_class = obj_class.clone(name=obj_class_name)
5177
+ video_object = video_object.clone(obj_class=obj_class)
5178
+ any_object_updated = True
5179
+ this_object_changed = True
5180
+ meta_obj_class = meta.get_obj_class(obj_class_name)
5181
+ if meta_obj_class is None:
5182
+ # obj class is not in meta, add it with suffix
5183
+ meta = meta.add_obj_class(obj_class)
5184
+ new_objects.append(video_object)
5185
+ meta_changed = True
5186
+ suffix_found = True
5187
+ break
5188
+ elif (
5189
+ meta_obj_class.geometry_type.geometry_name()
5190
+ == video_object.obj_class.geometry_type.geometry_name()
5191
+ ):
5192
+ # if object geometry is the same as in meta, use meta obj class
5193
+ video_object = video_object.clone(obj_class=meta_obj_class)
5194
+ new_objects.append(video_object)
5195
+ suffix_found = True
5196
+ any_object_updated = True
5197
+ this_object_changed = True
5198
+ break
5199
+ elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
5200
+ # if meta obj class is AnyGeometry, use it in object
5201
+ video_object = video_object.clone(obj_class=meta_obj_class)
5202
+ new_objects.append(video_object)
5203
+ suffix_found = True
5204
+ any_object_updated = True
5205
+ this_object_changed = True
5206
+ break
5207
+ if not suffix_found:
5208
+ # if no suffix found, raise error
5209
+ raise ValueError(
5210
+ f"Can't add obj class {original_obj_class_name} to project meta. "
5211
+ "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
5212
+ "Please check if model geometry type is compatible with existing obj classes."
5213
+ )
5214
+ elif this_object_changed:
5215
+ this_object_figures = [
5216
+ figure.clone(video_object=video_object) for figure in this_object_figures
5217
+ ]
5218
+ new_figures.extend(this_object_figures)
5219
+ if any_object_updated:
5220
+ frames_figures = {}
5221
+ for figure in new_figures:
5222
+ frames_figures.setdefault(figure.frame_index, []).append(figure)
5223
+ new_frames = FrameCollection(
5224
+ [
5225
+ Frame(index=frame_index, figures=figures)
5226
+ for frame_index, figures in frames_figures.items()
5227
+ ]
5228
+ )
5229
+ ann = ann.clone(objects=new_objects, frames=new_frames)
5230
+
5231
+ # check if tag metas are in project meta
5232
+ # if not, add them with suffix
5233
+ ann_tag_metas: Dict[str, TagMeta] = {}
5234
+ for video_object in ann.objects:
5235
+ for tag in video_object.tags:
5236
+ tag_name = tag.meta.name
5237
+ if tag_name not in ann_tag_metas:
5238
+ ann_tag_metas[tag_name] = tag.meta
5239
+ for tag in ann.tags:
5240
+ tag_name = tag.meta.name
5241
+ if tag_name not in ann_tag_metas:
5242
+ ann_tag_metas[tag_name] = tag.meta
5243
+
5244
+ changed_tag_metas = {}
5245
+ for ann_tag_meta in ann_tag_metas.values():
5246
+ meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
5247
+ if meta_tag_meta is None:
5248
+ meta = meta.add_tag_meta(ann_tag_meta)
5249
+ meta_changed = True
5250
+ elif not meta_tag_meta.is_compatible(ann_tag_meta):
5251
+ suffix_found = False
5252
+ for suffix in tag_meta_suffixes:
5253
+ new_tag_meta_name = ann_tag_meta.name + suffix
5254
+ meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
5255
+ if meta_tag_meta is None:
5256
+ new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
5257
+ meta = meta.add_tag_meta(new_tag_meta)
5258
+ changed_tag_metas[ann_tag_meta.name] = new_tag_meta
5259
+ meta_changed = True
5260
+ suffix_found = True
5261
+ break
5262
+ if meta_tag_meta.is_compatible(ann_tag_meta):
5263
+ changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
5264
+ suffix_found = True
5265
+ break
5266
+ if not suffix_found:
5267
+ raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
5268
+
5269
+ if changed_tag_metas:
5270
+ objects = []
5271
+ any_object_updated = False
5272
+ for video_object in ann.objects:
5273
+ any_tag_updated = False
5274
+ object_tags = []
5275
+ for tag in video_object.tags:
5276
+ if tag.meta.name in changed_tag_metas:
5277
+ object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
5278
+ any_tag_updated = True
5279
+ else:
5280
+ object_tags.append(tag)
5281
+ if any_tag_updated:
5282
+ video_object = video_object.clone(tags=TagCollection(object_tags))
5283
+ any_object_updated = True
5284
+ objects.append(video_object)
5285
+
5286
+ video_tags = []
5287
+ any_tag_updated = False
5288
+ for tag in ann.tags:
5289
+ if tag.meta.name in changed_tag_metas:
5290
+ video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
5291
+ any_tag_updated = True
5292
+ else:
5293
+ video_tags.append(tag)
5294
+ if any_tag_updated or any_object_updated:
5295
+ if any_tag_updated:
5296
+ video_tags = VideoTagCollection(video_tags)
5297
+ else:
5298
+ video_tags = None
5299
+ if any_object_updated:
5300
+ objects = VideoObjectCollection(objects)
5301
+ else:
5302
+ objects = None
5303
+ ann = ann.clone(tags=video_tags, objects=objects)
5304
+
4710
5305
  return meta, ann, meta_changed
4711
5306
 
4712
5307
 
@@ -4820,7 +5415,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
4820
5415
  return data[key]
4821
5416
  return None
4822
5417
 
4823
- def torch_load_safe(checkpoint_path: str, device:str = "cpu"):
5418
+
5419
+ def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
4824
5420
  import torch # pylint: disable=import-error
4825
5421
 
4826
5422
  # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0