supervisely 6.73.418__py3-none-any.whl → 6.73.419__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. supervisely/api/entity_annotation/figure_api.py +89 -45
  2. supervisely/nn/inference/inference.py +61 -45
  3. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  4. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  5. supervisely/nn/inference/session.py +4 -4
  6. supervisely/nn/model/model_api.py +31 -20
  7. supervisely/nn/model/prediction.py +11 -0
  8. supervisely/nn/model/prediction_session.py +33 -6
  9. supervisely/nn/tracker/__init__.py +1 -2
  10. supervisely/nn/tracker/base_tracker.py +44 -0
  11. supervisely/nn/tracker/botsort/__init__.py +1 -0
  12. supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
  13. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  14. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  15. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  16. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  17. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  18. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  19. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  20. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  21. supervisely/nn/tracker/botsort_tracker.py +259 -0
  22. supervisely/project/project.py +1 -1
  23. {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/METADATA +3 -1
  24. {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/RECORD +29 -42
  25. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  26. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  27. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  28. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  29. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  30. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  31. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  32. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  33. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  34. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  35. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  36. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  37. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  38. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  39. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  40. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  41. supervisely/nn/tracker/tracker.py +0 -285
  42. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  43. supervisely/nn/tracking/__init__.py +0 -1
  44. supervisely/nn/tracking/boxmot.py +0 -114
  45. supervisely/nn/tracking/tracking.py +0 -24
  46. supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
  47. {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/LICENSE +0 -0
  48. {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/WHEEL +0 -0
  49. {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/entry_points.txt +0 -0
  50. {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/top_level.txt +0 -0
@@ -800,6 +800,7 @@ class FigureApi(RemoveableBulkModuleApi):
800
800
  skip_geometry: bool = False,
801
801
  semaphore: Optional[asyncio.Semaphore] = None,
802
802
  log_progress: bool = True,
803
+ batch_size: int = 300,
803
804
  ) -> Dict[int, List[FigureInfo]]:
804
805
  """
805
806
  Asynchronously download figures for the given dataset ID. Can be filtered by image IDs.
@@ -815,6 +816,10 @@ class FigureApi(RemoveableBulkModuleApi):
815
816
  :type semaphore: Optional[asyncio.Semaphore], optional
816
817
  :param log_progress: If True, log the progress of the download.
817
818
  :type log_progress: bool, optional
819
+ :param batch_size: Size of the batch for downloading figures per 1 request. Default is 300.
820
+ Used for batching image_ids when filtering by specific images.
821
+ Adjust this value for optimal performance, value cannot exceed 500.
822
+ :type batch_size: int, optional
818
823
  :return: A dictionary where keys are image IDs and values are lists of figures.
819
824
  :rtype: Dict[int, List[FigureInfo]]
820
825
 
@@ -853,71 +858,104 @@ class FigureApi(RemoveableBulkModuleApi):
853
858
  if skip_geometry is True:
854
859
  fields = [x for x in fields if x != ApiField.GEOMETRY]
855
860
 
856
- if image_ids is None:
857
- filters = []
858
- else:
859
- filters = [
860
- {
861
- ApiField.FIELD: ApiField.ENTITY_ID,
862
- ApiField.OPERATOR: "in",
863
- ApiField.VALUE: image_ids,
864
- }
865
- ]
866
-
867
- data = {
861
+ # Base data setup
862
+ base_data = {
868
863
  ApiField.DATASET_ID: dataset_id,
869
864
  ApiField.FIELDS: fields,
870
- ApiField.FILTER: filters,
871
865
  }
872
866
 
873
- # Get first page to determine total pages
874
867
  if semaphore is None:
875
868
  semaphore = self._api.get_default_semaphore()
876
- images_figures = defaultdict(list)
877
- pages_count = None
878
- total = 0
879
- tasks = []
880
869
 
881
- async def _get_page(page_data, page_num):
870
+ async def _get_page_figures(page_data, semaphore, progress_cb: tqdm = None):
871
+ """Helper function to get figures from a single page"""
882
872
  async with semaphore:
883
873
  response = await self._api.post_async("figures.list", page_data)
884
874
  response_json = response.json()
885
- nonlocal pages_count, total
886
- pages_count = response_json["pagesCount"]
887
- if page_num == 1:
888
- total = response_json["total"]
889
875
 
890
876
  page_figures = []
891
877
  for info in response_json["entities"]:
892
878
  figure_info = self._convert_json_info(info, True)
893
879
  page_figures.append(figure_info)
880
+ if progress_cb is not None:
881
+ progress_cb.update(len(response_json["entities"]))
894
882
  return page_figures
895
883
 
896
- # Get first page
897
- data[ApiField.PAGE] = 1
898
- first_page_figures = await _get_page(data, 1)
884
+ async def _get_all_pages(ids_filter, progress_cb: tqdm = None):
885
+ """Internal function to process all pages for given filter"""
886
+ data = base_data.copy()
887
+ data[ApiField.FILTER] = ids_filter
888
+
889
+ # Get first page to determine pagination
890
+ data[ApiField.PAGE] = 1
891
+ async with semaphore:
892
+ response = await self._api.post_async("figures.list", data)
893
+ response_json = response.json()
894
+
895
+ pages_count = response_json["pagesCount"]
896
+ all_figures = []
897
+
898
+ # Process first page
899
+ for info in response_json["entities"]:
900
+ figure_info = self._convert_json_info(info, True)
901
+ all_figures.append(figure_info)
902
+ if progress_cb is not None:
903
+ progress_cb.update(len(response_json["entities"]))
904
+
905
+ # Process remaining pages in parallel if needed
906
+ if pages_count > 1:
907
+ tasks = []
908
+ for page in range(2, pages_count + 1):
909
+ page_data = data.copy()
910
+ page_data[ApiField.PAGE] = page
911
+ tasks.append(
912
+ asyncio.create_task(
913
+ _get_page_figures(page_data, semaphore, progress_cb=progress_cb)
914
+ )
915
+ )
916
+
917
+ if tasks:
918
+ page_results = await asyncio.gather(*tasks)
919
+ for page_figures in page_results:
920
+ all_figures.extend(page_figures)
921
+
922
+ return all_figures
899
923
 
900
924
  if log_progress:
901
- progress_cb = tqdm(total=total, desc="Downloading figures")
925
+ progress_cb = tqdm(desc="Downloading figures", unit="figure", total=0)
926
+ else:
927
+ progress_cb = None
902
928
 
903
- for figure in first_page_figures:
904
- images_figures[figure.entity_id].append(figure)
905
- if log_progress:
906
- progress_cb.update(1)
907
-
908
- # Get rest of the pages in parallel
909
- if pages_count > 1:
910
- for page in range(2, pages_count + 1):
911
- page_data = data.copy()
912
- page_data[ApiField.PAGE] = page
913
- tasks.append(asyncio.create_task(_get_page(page_data, page)))
914
-
915
- for task in asyncio.as_completed(tasks):
916
- page_figures = await task
917
- for figure in page_figures:
918
- images_figures[figure.entity_id].append(figure)
919
- if log_progress:
920
- progress_cb.update(1)
929
+ # Strategy: batch processing based on image_ids
930
+ tasks = []
931
+
932
+ if image_ids is None:
933
+ # Single task for all figures in dataset
934
+ filters = []
935
+ tasks.append(_get_all_pages(filters, progress_cb=progress_cb))
936
+ else:
937
+ # Batch image_ids and create tasks for each batch
938
+ for batch_ids in batched(image_ids, batch_size):
939
+ filters = [
940
+ {
941
+ ApiField.FIELD: ApiField.ENTITY_ID,
942
+ ApiField.OPERATOR: "in",
943
+ ApiField.VALUE: list(batch_ids),
944
+ }
945
+ ]
946
+ tasks.append(_get_all_pages(filters, progress_cb=progress_cb))
947
+ # Small delay between batches to reduce server load
948
+ await asyncio.sleep(0.02)
949
+
950
+ # Execute all tasks in parallel and collect results
951
+ all_results = await asyncio.gather(*tasks)
952
+
953
+ # Combine results from all batches
954
+ images_figures = defaultdict(list)
955
+
956
+ for batch_figures in all_results:
957
+ for figure in batch_figures:
958
+ images_figures[figure.entity_id].append(figure)
921
959
 
922
960
  return dict(images_figures)
923
961
 
@@ -928,6 +966,7 @@ class FigureApi(RemoveableBulkModuleApi):
928
966
  skip_geometry: bool = False,
929
967
  semaphore: Optional[asyncio.Semaphore] = None,
930
968
  log_progress: bool = True,
969
+ batch_size: int = 300,
931
970
  ) -> Dict[int, List[FigureInfo]]:
932
971
  """
933
972
  Download figures for the given dataset ID. Can be filtered by image IDs.
@@ -945,6 +984,10 @@ class FigureApi(RemoveableBulkModuleApi):
945
984
  :type semaphore: Optional[asyncio.Semaphore], optional
946
985
  :param log_progress: If True, log the progress of the download.
947
986
  :type log_progress: bool, optional
987
+ :param batch_size: Size of the batch for downloading figures per 1 request. Default is 300.
988
+ Used for batching image_ids when filtering by specific images.
989
+ Adjust this value for optimal performance, value cannot exceed 500.
990
+ :type batch_size: int, optional
948
991
 
949
992
  :return: A dictionary where keys are image IDs and values are lists of figures.
950
993
  :rtype: Dict[int, List[FigureInfo]]
@@ -970,6 +1013,7 @@ class FigureApi(RemoveableBulkModuleApi):
970
1013
  skip_geometry=skip_geometry,
971
1014
  semaphore=semaphore,
972
1015
  log_progress=log_progress,
1016
+ batch_size=batch_size,
973
1017
  )
974
1018
  )
975
1019
  except Exception:
@@ -1265,6 +1265,26 @@ class Inference:
1265
1265
 
1266
1266
  def get_classes(self) -> List[str]:
1267
1267
  return self.classes
1268
+
1269
+ def _tracker_init(self, tracker: str, tracker_settings: dict):
1270
+ # Check if tracking is supported for this model
1271
+ info = self.get_info()
1272
+ tracking_support = info.get("tracking_on_videos_support", False)
1273
+
1274
+ if not tracking_support:
1275
+ logger.debug("Tracking is not supported for this model")
1276
+ return None
1277
+
1278
+ if tracker == "botsort":
1279
+ from supervisely.nn.tracker import BotSortTracker
1280
+ device = tracker_settings.get("device", self.device)
1281
+ logger.debug(f"Initializing BotSort tracker with device: {device}")
1282
+ return BotSortTracker(settings=tracker_settings, device=device)
1283
+ else:
1284
+ if tracker is not None:
1285
+ logger.warning(f"Unknown tracking type: {tracker}. Tracking is disabled.")
1286
+ return None
1287
+
1268
1288
 
1269
1289
  def get_info(self) -> Dict[str, Any]:
1270
1290
  num_classes = None
@@ -1291,9 +1311,9 @@ class Inference:
1291
1311
  "sliding_window_support": self.sliding_window_mode,
1292
1312
  "videos_support": True,
1293
1313
  "async_video_inference_support": True,
1294
- "tracking_on_videos_support": True,
1314
+ "tracking_on_videos_support": False,
1295
1315
  "async_image_inference_support": True,
1296
- "tracking_algorithms": ["bot", "deepsort"],
1316
+ "tracking_algorithms": ["botsort"],
1297
1317
  "batch_inference_support": self.is_batch_inference_supported(),
1298
1318
  "max_batch_size": self.max_batch_size,
1299
1319
  }
@@ -1847,24 +1867,12 @@ class Inference:
1847
1867
  else:
1848
1868
  n_frames = frames_reader.frames_count()
1849
1869
 
1850
- if tracking == "bot":
1851
- from supervisely.nn.tracker import BoTTracker
1852
-
1853
- tracker = BoTTracker(state)
1854
- elif tracking == "deepsort":
1855
- from supervisely.nn.tracker import DeepSortTracker
1856
-
1857
- tracker = DeepSortTracker(state)
1858
- else:
1859
- if tracking is not None:
1860
- logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
1861
- tracker = None
1862
-
1870
+ self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
1871
+
1863
1872
  progress_total = (n_frames + step - 1) // step
1864
1873
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
1865
1874
 
1866
1875
  results = []
1867
- tracks_data = {}
1868
1876
  for batch in batched(
1869
1877
  range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
1870
1878
  batch_size,
@@ -1884,28 +1892,32 @@ class Inference:
1884
1892
  source=frames,
1885
1893
  settings=inference_settings,
1886
1894
  )
1895
+
1896
+ if self._tracker is not None:
1897
+ anns = self._apply_tracker_to_anns(frames, anns)
1898
+
1887
1899
  predictions = [
1888
1900
  Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
1889
1901
  for ann, frame_index in zip(anns, batch)
1890
1902
  ]
1903
+
1891
1904
  for pred, this_slides_data in zip(predictions, slides_data):
1892
1905
  pred.extra_data["slides_data"] = this_slides_data
1893
1906
  batch_results = self._format_output(predictions)
1894
- if tracker is not None:
1895
- for frame_index, frame, ann in zip(batch, frames, anns):
1896
- tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
1907
+
1897
1908
  inference_request.add_results(batch_results)
1898
1909
  inference_request.done(len(batch_results))
1899
1910
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
1900
1911
  video_ann_json = None
1901
- if tracker is not None:
1912
+ if self._tracker is not None:
1902
1913
  inference_request.set_stage("Postprocess...", 0, 1)
1903
- video_ann_json = tracker.get_annotation(
1904
- tracks_data, (video_height, video_witdth), n_frames
1905
- ).to_json()
1914
+
1915
+ video_ann_json = self._tracker.video_annotation.to_json()
1906
1916
  inference_request.done()
1907
1917
  result = {"ann": results, "video_ann": video_ann_json}
1908
1918
  inference_request.final_result = result.copy()
1919
+ return video_ann_json
1920
+
1909
1921
 
1910
1922
  def _inference_image_ids(
1911
1923
  self,
@@ -2083,18 +2095,8 @@ class Inference:
2083
2095
  else:
2084
2096
  n_frames = video_info.frames_count
2085
2097
 
2086
- if tracking == "bot":
2087
- from supervisely.nn.tracker import BoTTracker
2088
-
2089
- tracker = BoTTracker(state)
2090
- elif tracking == "deepsort":
2091
- from supervisely.nn.tracker import DeepSortTracker
2092
-
2093
- tracker = DeepSortTracker(state)
2094
- else:
2095
- if tracking is not None:
2096
- logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
2097
- tracker = None
2098
+ self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
2099
+
2098
2100
  logger.debug(
2099
2101
  f"Video info:",
2100
2102
  extra=dict(
@@ -2111,7 +2113,6 @@ class Inference:
2111
2113
  progress_total = (n_frames + step - 1) // step
2112
2114
  inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
2113
2115
 
2114
- tracks_data = {}
2115
2116
  for batch in batched(
2116
2117
  range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
2117
2118
  batch_size,
@@ -2130,6 +2131,10 @@ class Inference:
2130
2131
  source=frames,
2131
2132
  settings=inference_settings,
2132
2133
  )
2134
+
2135
+ if self._tracker is not None:
2136
+ anns = self._apply_tracker_to_anns(frames, anns)
2137
+
2133
2138
  predictions = [
2134
2139
  Prediction(
2135
2140
  ann,
@@ -2137,27 +2142,24 @@ class Inference:
2137
2142
  frame_index=frame_index,
2138
2143
  video_id=video_info.id,
2139
2144
  dataset_id=video_info.dataset_id,
2140
- project_id=video_info.project_id,
2141
- )
2145
+ project_id=video_info.project_id,
2146
+ )
2142
2147
  for ann, frame_index in zip(anns, batch)
2143
2148
  ]
2144
2149
  for pred, this_slides_data in zip(predictions, slides_data):
2145
2150
  pred.extra_data["slides_data"] = this_slides_data
2146
2151
  batch_results = self._format_output(predictions)
2147
- if tracker is not None:
2148
- for frame_index, frame, ann in zip(batch, frames, anns):
2149
- tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
2152
+
2150
2153
  inference_request.add_results(batch_results)
2151
2154
  inference_request.done(len(batch_results))
2152
2155
  logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
2153
2156
  video_ann_json = None
2154
- if tracker is not None:
2157
+ if self._tracker is not None:
2155
2158
  inference_request.set_stage("Postprocess...", 0, 1)
2156
- video_ann_json = tracker.get_annotation(
2157
- tracks_data, (video_info.frame_height, video_info.frame_width), n_frames
2158
- ).to_json()
2159
+ video_ann_json = self._tracker.video_annotation.to_json()
2159
2160
  inference_request.done()
2160
2161
  inference_request.final_result = {"video_ann": video_ann_json}
2162
+ return video_ann_json
2161
2163
 
2162
2164
  def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
2163
2165
  """Inference project images.
@@ -4117,6 +4119,20 @@ class Inference:
4117
4119
  self._args.draw,
4118
4120
  )
4119
4121
 
4122
+ def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
4123
+ updated_anns = []
4124
+ for frame, ann in zip(frames, anns):
4125
+ matches = self._tracker.update(frame, ann)
4126
+ track_ids = [match["track_id"] for match in matches]
4127
+ tracked_labels = [match["label"] for match in matches]
4128
+
4129
+ filtered_annotation = ann.clone(
4130
+ labels=tracked_labels,
4131
+ custom_data=track_ids
4132
+ )
4133
+ updated_anns.append(filtered_annotation)
4134
+ return updated_anns
4135
+
4120
4136
  def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
4121
4137
  if model_source == ModelSource.PRETRAINED:
4122
4138
  checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
@@ -12,6 +12,7 @@ class InstanceSegmentation(Inference):
12
12
  def get_info(self) -> dict:
13
13
  info = super().get_info()
14
14
  info["task type"] = "instance segmentation"
15
+ info["tracking_on_videos_support"] = True
15
16
  # recommended parameters:
16
17
  # info["model_name"] = ""
17
18
  # info["checkpoint_name"] = ""
@@ -17,6 +17,7 @@ class ObjectDetection(Inference):
17
17
  def get_info(self) -> dict:
18
18
  info = super().get_info()
19
19
  info["task type"] = "object detection"
20
+ info["tracking_on_videos_support"] = True
20
21
  # recommended parameters:
21
22
  # info["model_name"] = ""
22
23
  # info["checkpoint_name"] = ""
@@ -271,7 +271,7 @@ class SessionJSON:
271
271
  start_frame_index: int = None,
272
272
  frames_count: int = None,
273
273
  frames_direction: Literal["forward", "backward"] = None,
274
- tracker: Literal["bot", "deepsort"] = None,
274
+ tracker: Literal["botsort"] = None,
275
275
  batch_size: int = None,
276
276
  ) -> Dict[str, Any]:
277
277
  endpoint = "inference_video_id"
@@ -295,7 +295,7 @@ class SessionJSON:
295
295
  frames_direction: Literal["forward", "backward"] = None,
296
296
  process_fn=None,
297
297
  preparing_cb=None,
298
- tracker: Literal["bot", "deepsort"] = None,
298
+ tracker: Literal["botsort"] = None,
299
299
  batch_size: int = None,
300
300
  ) -> Iterator:
301
301
  if self._async_inference_uuid:
@@ -795,7 +795,7 @@ class Session(SessionJSON):
795
795
  start_frame_index: int = None,
796
796
  frames_count: int = None,
797
797
  frames_direction: Literal["forward", "backward"] = None,
798
- tracker: Literal["bot", "deepsort"] = None,
798
+ tracker: Literal["botsort"] = None,
799
799
  batch_size: int = None,
800
800
  ) -> List[sly.Annotation]:
801
801
  pred_list_raw = super().inference_video_id(
@@ -811,7 +811,7 @@ class Session(SessionJSON):
811
811
  start_frame_index: int = None,
812
812
  frames_count: int = None,
813
813
  frames_direction: Literal["forward", "backward"] = None,
814
- tracker: Literal["bot", "deepsort"] = None,
814
+ tracker: Literal["botsort"] = None,
815
815
  batch_size: int = None,
816
816
  preparing_cb=None,
817
817
  ) -> AsyncInferenceIterator:
@@ -211,12 +211,15 @@ class ModelAPI:
211
211
  project_id: int = None,
212
212
  batch_size: int = None,
213
213
  conf: float = None,
214
+ img_size: int = None,
214
215
  classes: List[str] = None,
215
216
  upload_mode: str = None,
217
+ recursive: bool = False,
218
+ tracking: bool = None,
219
+ tracking_config: dict = None,
216
220
  **kwargs,
217
221
  ) -> PredictionSession:
218
- if upload_mode is not None:
219
- kwargs["upload_mode"] = upload_mode
222
+
220
223
  return PredictionSession(
221
224
  self.url,
222
225
  input=input,
@@ -227,7 +230,12 @@ class ModelAPI:
227
230
  api=self.api,
228
231
  batch_size=batch_size,
229
232
  conf=conf,
233
+ img_size=img_size,
230
234
  classes=classes,
235
+ upload_mode=upload_mode,
236
+ recursive=recursive,
237
+ tracking=tracking,
238
+ tracking_config=tracking_config,
231
239
  **kwargs,
232
240
  )
233
241
 
@@ -243,28 +251,31 @@ class ModelAPI:
243
251
  img_size: int = None,
244
252
  classes: List[str] = None,
245
253
  upload_mode: str = None,
246
- recursive: bool = None,
254
+ recursive: bool = False,
255
+ tracking: bool = None,
256
+ tracking_config: dict = None,
247
257
  **kwargs,
248
258
  ) -> List[Prediction]:
249
259
  if "show_progress" not in kwargs:
250
260
  kwargs["show_progress"] = True
251
- if recursive is not None:
252
- kwargs["recursive"] = recursive
253
- if img_size is not None:
254
- kwargs["img_size"] = img_size
255
- return list(
256
- self.predict_detached(
257
- input,
258
- image_id,
259
- video_id,
260
- dataset_id,
261
- project_id,
262
- batch_size,
263
- conf,
264
- classes,
265
- upload_mode,
266
- **kwargs,
267
- )
261
+ session = PredictionSession(
262
+ self.url,
263
+ input=input,
264
+ image_id=image_id,
265
+ video_id=video_id,
266
+ dataset_id=dataset_id,
267
+ project_id=project_id,
268
+ api=self.api,
269
+ batch_size=batch_size,
270
+ conf=conf,
271
+ img_size=img_size,
272
+ classes=classes,
273
+ upload_mode=upload_mode,
274
+ recursive=recursive,
275
+ tracking=tracking,
276
+ tracking_config=tracking_config,
277
+ **kwargs,
268
278
  )
279
+ return list(session)
269
280
 
270
281
  # ------------------------------------ #
@@ -82,6 +82,7 @@ class Prediction:
82
82
  self._masks = None
83
83
  self._classes = None
84
84
  self._scores = None
85
+ self._track_ids = None
85
86
 
86
87
  if self.path is None and isinstance(self.source, (str, PathLike)):
87
88
  self.path = str(self.source)
@@ -125,6 +126,10 @@ class Prediction:
125
126
  )
126
127
  self._boxes = np.array(self._boxes)
127
128
  self._masks = np.array(self._masks)
129
+
130
+ custom_data = self.annotation.custom_data
131
+ if custom_data and isinstance(custom_data, list) and len(custom_data) == len(self.annotation.labels):
132
+ self._track_ids = np.array(custom_data)
128
133
 
129
134
  @property
130
135
  def boxes(self):
@@ -178,6 +183,12 @@ class Prediction:
178
183
  obj_class.name: i for i, obj_class in enumerate(self.model_meta.obj_classes)
179
184
  }
180
185
  return np.array([cls_name_to_idx[class_name] for class_name in self.classes])
186
+ @property
187
+ def track_ids(self):
188
+ """Get track IDs for each detection. Returns None for detections without tracking."""
189
+ if self._track_ids is None:
190
+ self._init_geometries()
191
+ return self._track_ids
181
192
 
182
193
  @classmethod
183
194
  def from_json(cls, json_data: Dict, **kwargs) -> "Prediction":
@@ -67,8 +67,11 @@ class PredictionSession:
67
67
  dataset_id: Union[List[int], int] = None,
68
68
  project_id: Union[List[int], int] = None,
69
69
  api: "Api" = None,
70
+ tracking: bool = None,
71
+ tracking_config: dict = None,
70
72
  **kwargs: dict,
71
- ):
73
+ ):
74
+
72
75
  extra_input_args = ["image_ids", "video_ids", "dataset_ids", "project_ids"]
73
76
  assert (
74
77
  sum(
@@ -87,6 +90,7 @@ class PredictionSession:
87
90
  == 1
88
91
  ), "Exactly one of input, image_ids, video_id, dataset_id, project_id or image_id must be provided."
89
92
 
93
+
90
94
  self._iterator = None
91
95
  self._base_url = url
92
96
  self.inference_request_uuid = None
@@ -111,6 +115,22 @@ class PredictionSession:
111
115
  self.inference_settings = {
112
116
  k: v for k, v in kwargs.items() if isinstance(v, (str, int, float))
113
117
  }
118
+
119
+ if tracking is True:
120
+ model_info = self._get_session_info()
121
+ if not model_info.get("tracking_on_videos_support", False):
122
+ raise ValueError("Tracking is not supported by this model")
123
+
124
+ if tracking_config is None:
125
+ self.tracker = "botsort"
126
+ self.tracker_settings = {}
127
+ else:
128
+ cfg = dict(tracking_config)
129
+ self.tracker = cfg.pop("tracker", "botsort")
130
+ self.tracker_settings = cfg
131
+ else:
132
+ self.tracker = None
133
+ self.tracker_settings = None
114
134
 
115
135
  # extra input args
116
136
  image_ids = self._set_var_from_kwargs("image_ids", kwargs, image_id)
@@ -180,7 +200,7 @@ class PredictionSession:
180
200
  self._iterator = self._predict_images(input, **kwargs)
181
201
  elif ext.lower() in ALLOWED_VIDEO_EXTENSIONS:
182
202
  kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
183
- self._iterator = self._predict_videos(input, **kwargs)
203
+ self._iterator = self._predict_videos(input, tracker=self.tracker, tracker_settings=self.tracker_settings, **kwargs)
184
204
  else:
185
205
  raise ValueError(
186
206
  f"Unsupported file extension: {ext}. Supported extensions are: {SUPPORTED_IMG_EXTS + ALLOWED_VIDEO_EXTENSIONS}"
@@ -193,7 +213,7 @@ class PredictionSession:
193
213
  if len(video_ids) > 1:
194
214
  raise ValueError("Only one video id can be provided.")
195
215
  kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
196
- self._iterator = self._predict_videos(video_ids, **kwargs)
216
+ self._iterator = self._predict_videos(video_ids, tracker=self.tracker, tracker_settings=self.tracker_settings, **kwargs)
197
217
  elif dataset_ids is not None:
198
218
  kwargs = get_valid_kwargs(
199
219
  kwargs,
@@ -259,7 +279,7 @@ class PredictionSession:
259
279
  if self.api is not None:
260
280
  return self.api.token
261
281
  return env.api_token(raise_not_found=False)
262
-
282
+
263
283
  def _get_json_body(self):
264
284
  body = {"state": {}, "context": {}}
265
285
  if self.inference_request_uuid is not None:
@@ -269,7 +289,7 @@ class PredictionSession:
269
289
  if self.api_token is not None:
270
290
  body["api_token"] = self.api_token
271
291
  return body
272
-
292
+
273
293
  def _post(self, method, *args, retries=5, **kwargs) -> requests.Response:
274
294
  if kwargs.get("headers") is None:
275
295
  kwargs["headers"] = {}
@@ -303,6 +323,11 @@ class PredictionSession:
303
323
  if retry_idx + 1 == retries:
304
324
  raise exc
305
325
 
326
+ def _get_session_info(self) -> Dict[str, Any]:
327
+ method = "get_session_info"
328
+ r = self._post(method, json=self._get_json_body())
329
+ return r.json()
330
+
306
331
  def _get_inference_progress(self):
307
332
  method = "get_inference_progress"
308
333
  r = self._post(method, json=self._get_json_body())
@@ -558,7 +583,8 @@ class PredictionSession:
558
583
  end_frame=None,
559
584
  duration=None,
560
585
  direction: Literal["forward", "backward"] = None,
561
- tracker: Literal["bot", "deepsort"] = None,
586
+ tracker: Literal["botsort"] = None,
587
+ tracker_settings: dict = None,
562
588
  batch_size: int = None,
563
589
  ):
564
590
  if len(videos) != 1:
@@ -573,6 +599,7 @@ class PredictionSession:
573
599
  ("duration", duration),
574
600
  ("direction", direction),
575
601
  ("tracker", tracker),
602
+ ("tracker_settings", tracker_settings),
576
603
  ("batch_size", batch_size),
577
604
  ):
578
605
  if value is not None: