supervisely 6.73.417__py3-none-any.whl → 6.73.419__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/entity_annotation/figure_api.py +89 -45
- supervisely/nn/inference/inference.py +61 -45
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/session.py +4 -4
- supervisely/nn/model/model_api.py +31 -20
- supervisely/nn/model/prediction.py +11 -0
- supervisely/nn/model/prediction_session.py +33 -6
- supervisely/nn/tracker/__init__.py +1 -2
- supervisely/nn/tracker/base_tracker.py +44 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +259 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/METADATA +5 -3
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/RECORD +29 -42
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/LICENSE +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/WHEEL +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/top_level.txt +0 -0
|
@@ -800,6 +800,7 @@ class FigureApi(RemoveableBulkModuleApi):
|
|
|
800
800
|
skip_geometry: bool = False,
|
|
801
801
|
semaphore: Optional[asyncio.Semaphore] = None,
|
|
802
802
|
log_progress: bool = True,
|
|
803
|
+
batch_size: int = 300,
|
|
803
804
|
) -> Dict[int, List[FigureInfo]]:
|
|
804
805
|
"""
|
|
805
806
|
Asynchronously download figures for the given dataset ID. Can be filtered by image IDs.
|
|
@@ -815,6 +816,10 @@ class FigureApi(RemoveableBulkModuleApi):
|
|
|
815
816
|
:type semaphore: Optional[asyncio.Semaphore], optional
|
|
816
817
|
:param log_progress: If True, log the progress of the download.
|
|
817
818
|
:type log_progress: bool, optional
|
|
819
|
+
:param batch_size: Size of the batch for downloading figures per 1 request. Default is 300.
|
|
820
|
+
Used for batching image_ids when filtering by specific images.
|
|
821
|
+
Adjust this value for optimal performance, value cannot exceed 500.
|
|
822
|
+
:type batch_size: int, optional
|
|
818
823
|
:return: A dictionary where keys are image IDs and values are lists of figures.
|
|
819
824
|
:rtype: Dict[int, List[FigureInfo]]
|
|
820
825
|
|
|
@@ -853,71 +858,104 @@ class FigureApi(RemoveableBulkModuleApi):
|
|
|
853
858
|
if skip_geometry is True:
|
|
854
859
|
fields = [x for x in fields if x != ApiField.GEOMETRY]
|
|
855
860
|
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
else:
|
|
859
|
-
filters = [
|
|
860
|
-
{
|
|
861
|
-
ApiField.FIELD: ApiField.ENTITY_ID,
|
|
862
|
-
ApiField.OPERATOR: "in",
|
|
863
|
-
ApiField.VALUE: image_ids,
|
|
864
|
-
}
|
|
865
|
-
]
|
|
866
|
-
|
|
867
|
-
data = {
|
|
861
|
+
# Base data setup
|
|
862
|
+
base_data = {
|
|
868
863
|
ApiField.DATASET_ID: dataset_id,
|
|
869
864
|
ApiField.FIELDS: fields,
|
|
870
|
-
ApiField.FILTER: filters,
|
|
871
865
|
}
|
|
872
866
|
|
|
873
|
-
# Get first page to determine total pages
|
|
874
867
|
if semaphore is None:
|
|
875
868
|
semaphore = self._api.get_default_semaphore()
|
|
876
|
-
images_figures = defaultdict(list)
|
|
877
|
-
pages_count = None
|
|
878
|
-
total = 0
|
|
879
|
-
tasks = []
|
|
880
869
|
|
|
881
|
-
async def
|
|
870
|
+
async def _get_page_figures(page_data, semaphore, progress_cb: tqdm = None):
    """
    Request a single page of ``figures.list`` and convert it to figure infos.

    :param page_data: Request payload for the page to fetch.
    :param semaphore: Semaphore acquired for the duration of the request.
    :param progress_cb: Optional progress bar; advanced by the number of
        entities returned in this page.
    :return: Converted figure infos, in the order the entities were returned.
    """
    async with semaphore:
        reply = await self._api.post_async("figures.list", page_data)
        entities = reply.json()["entities"]

    # Nothing below awaits, so the conversion cannot interleave with other tasks.
    converted = [self._convert_json_info(entity, True) for entity in entities]
    if progress_cb is not None:
        progress_cb.update(len(entities))
    return converted
|
|
895
883
|
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
884
|
+
async def _get_all_pages(ids_filter, progress_cb: tqdm = None):
    """
    Collect the figures of every page that matches the given filter.

    The first page is requested on its own because its response carries
    ``pagesCount``; the remaining pages are then requested concurrently
    through ``_get_page_figures``. ``base_data`` and ``semaphore`` are taken
    from the enclosing scope.

    :param ids_filter: Value sent as the request filter.
    :param progress_cb: Optional progress bar; advanced by the number of
        entities received.
    :return: Figure infos of the first page followed by those of pages 2..N.
    """
    first_request = {**base_data, ApiField.FILTER: ids_filter, ApiField.PAGE: 1}

    async with semaphore:
        first_reply = await self._api.post_async("figures.list", first_request)
        first_json = first_reply.json()

    total_pages = first_json["pagesCount"]
    first_entities = first_json["entities"]
    collected = [self._convert_json_info(entity, True) for entity in first_entities]
    if progress_cb is not None:
        progress_cb.update(len(first_entities))

    # A single page means everything has already been collected.
    if total_pages <= 1:
        return collected

    pending = [
        asyncio.create_task(
            _get_page_figures(
                {**first_request, ApiField.PAGE: page_no},
                semaphore,
                progress_cb=progress_cb,
            )
        )
        for page_no in range(2, total_pages + 1)
    ]
    # gather() keeps the task order, so pages are appended in ascending order.
    for page_chunk in await asyncio.gather(*pending):
        collected.extend(page_chunk)
    return collected
|
|
899
923
|
|
|
900
924
|
if log_progress:
|
|
901
|
-
progress_cb = tqdm(
|
|
925
|
+
progress_cb = tqdm(desc="Downloading figures", unit="figure", total=0)
|
|
926
|
+
else:
|
|
927
|
+
progress_cb = None
|
|
902
928
|
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
929
|
+
# Strategy: batch processing based on image_ids
|
|
930
|
+
tasks = []
|
|
931
|
+
|
|
932
|
+
if image_ids is None:
|
|
933
|
+
# Single task for all figures in dataset
|
|
934
|
+
filters = []
|
|
935
|
+
tasks.append(_get_all_pages(filters, progress_cb=progress_cb))
|
|
936
|
+
else:
|
|
937
|
+
# Batch image_ids and create tasks for each batch
|
|
938
|
+
for batch_ids in batched(image_ids, batch_size):
|
|
939
|
+
filters = [
|
|
940
|
+
{
|
|
941
|
+
ApiField.FIELD: ApiField.ENTITY_ID,
|
|
942
|
+
ApiField.OPERATOR: "in",
|
|
943
|
+
ApiField.VALUE: list(batch_ids),
|
|
944
|
+
}
|
|
945
|
+
]
|
|
946
|
+
tasks.append(_get_all_pages(filters, progress_cb=progress_cb))
|
|
947
|
+
# Small delay between batches to reduce server load
|
|
948
|
+
await asyncio.sleep(0.02)
|
|
949
|
+
|
|
950
|
+
# Execute all tasks in parallel and collect results
|
|
951
|
+
all_results = await asyncio.gather(*tasks)
|
|
952
|
+
|
|
953
|
+
# Combine results from all batches
|
|
954
|
+
images_figures = defaultdict(list)
|
|
955
|
+
|
|
956
|
+
for batch_figures in all_results:
|
|
957
|
+
for figure in batch_figures:
|
|
958
|
+
images_figures[figure.entity_id].append(figure)
|
|
921
959
|
|
|
922
960
|
return dict(images_figures)
|
|
923
961
|
|
|
@@ -928,6 +966,7 @@ class FigureApi(RemoveableBulkModuleApi):
|
|
|
928
966
|
skip_geometry: bool = False,
|
|
929
967
|
semaphore: Optional[asyncio.Semaphore] = None,
|
|
930
968
|
log_progress: bool = True,
|
|
969
|
+
batch_size: int = 300,
|
|
931
970
|
) -> Dict[int, List[FigureInfo]]:
|
|
932
971
|
"""
|
|
933
972
|
Download figures for the given dataset ID. Can be filtered by image IDs.
|
|
@@ -945,6 +984,10 @@ class FigureApi(RemoveableBulkModuleApi):
|
|
|
945
984
|
:type semaphore: Optional[asyncio.Semaphore], optional
|
|
946
985
|
:param log_progress: If True, log the progress of the download.
|
|
947
986
|
:type log_progress: bool, optional
|
|
987
|
+
:param batch_size: Size of the batch for downloading figures per 1 request. Default is 300.
|
|
988
|
+
Used for batching image_ids when filtering by specific images.
|
|
989
|
+
Adjust this value for optimal performance, value cannot exceed 500.
|
|
990
|
+
:type batch_size: int, optional
|
|
948
991
|
|
|
949
992
|
:return: A dictionary where keys are image IDs and values are lists of figures.
|
|
950
993
|
:rtype: Dict[int, List[FigureInfo]]
|
|
@@ -970,6 +1013,7 @@ class FigureApi(RemoveableBulkModuleApi):
|
|
|
970
1013
|
skip_geometry=skip_geometry,
|
|
971
1014
|
semaphore=semaphore,
|
|
972
1015
|
log_progress=log_progress,
|
|
1016
|
+
batch_size=batch_size,
|
|
973
1017
|
)
|
|
974
1018
|
)
|
|
975
1019
|
except Exception:
|
|
@@ -1265,6 +1265,26 @@ class Inference:
|
|
|
1265
1265
|
|
|
1266
1266
|
def get_classes(self) -> List[str]:
|
|
1267
1267
|
return self.classes
|
|
1268
|
+
|
|
1269
|
+
def _tracker_init(self, tracker: str, tracker_settings: dict):
|
|
1270
|
+
# Check if tracking is supported for this model
|
|
1271
|
+
info = self.get_info()
|
|
1272
|
+
tracking_support = info.get("tracking_on_videos_support", False)
|
|
1273
|
+
|
|
1274
|
+
if not tracking_support:
|
|
1275
|
+
logger.debug("Tracking is not supported for this model")
|
|
1276
|
+
return None
|
|
1277
|
+
|
|
1278
|
+
if tracker == "botsort":
|
|
1279
|
+
from supervisely.nn.tracker import BotSortTracker
|
|
1280
|
+
device = tracker_settings.get("device", self.device)
|
|
1281
|
+
logger.debug(f"Initializing BotSort tracker with device: {device}")
|
|
1282
|
+
return BotSortTracker(settings=tracker_settings, device=device)
|
|
1283
|
+
else:
|
|
1284
|
+
if tracker is not None:
|
|
1285
|
+
logger.warning(f"Unknown tracking type: {tracker}. Tracking is disabled.")
|
|
1286
|
+
return None
|
|
1287
|
+
|
|
1268
1288
|
|
|
1269
1289
|
def get_info(self) -> Dict[str, Any]:
|
|
1270
1290
|
num_classes = None
|
|
@@ -1291,9 +1311,9 @@ class Inference:
|
|
|
1291
1311
|
"sliding_window_support": self.sliding_window_mode,
|
|
1292
1312
|
"videos_support": True,
|
|
1293
1313
|
"async_video_inference_support": True,
|
|
1294
|
-
"tracking_on_videos_support":
|
|
1314
|
+
"tracking_on_videos_support": False,
|
|
1295
1315
|
"async_image_inference_support": True,
|
|
1296
|
-
"tracking_algorithms": ["
|
|
1316
|
+
"tracking_algorithms": ["botsort"],
|
|
1297
1317
|
"batch_inference_support": self.is_batch_inference_supported(),
|
|
1298
1318
|
"max_batch_size": self.max_batch_size,
|
|
1299
1319
|
}
|
|
@@ -1847,24 +1867,12 @@ class Inference:
|
|
|
1847
1867
|
else:
|
|
1848
1868
|
n_frames = frames_reader.frames_count()
|
|
1849
1869
|
|
|
1850
|
-
|
|
1851
|
-
|
|
1852
|
-
|
|
1853
|
-
tracker = BoTTracker(state)
|
|
1854
|
-
elif tracking == "deepsort":
|
|
1855
|
-
from supervisely.nn.tracker import DeepSortTracker
|
|
1856
|
-
|
|
1857
|
-
tracker = DeepSortTracker(state)
|
|
1858
|
-
else:
|
|
1859
|
-
if tracking is not None:
|
|
1860
|
-
logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
|
|
1861
|
-
tracker = None
|
|
1862
|
-
|
|
1870
|
+
self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
1871
|
+
|
|
1863
1872
|
progress_total = (n_frames + step - 1) // step
|
|
1864
1873
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
1865
1874
|
|
|
1866
1875
|
results = []
|
|
1867
|
-
tracks_data = {}
|
|
1868
1876
|
for batch in batched(
|
|
1869
1877
|
range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
|
|
1870
1878
|
batch_size,
|
|
@@ -1884,28 +1892,32 @@ class Inference:
|
|
|
1884
1892
|
source=frames,
|
|
1885
1893
|
settings=inference_settings,
|
|
1886
1894
|
)
|
|
1895
|
+
|
|
1896
|
+
if self._tracker is not None:
|
|
1897
|
+
anns = self._apply_tracker_to_anns(frames, anns)
|
|
1898
|
+
|
|
1887
1899
|
predictions = [
|
|
1888
1900
|
Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
|
|
1889
1901
|
for ann, frame_index in zip(anns, batch)
|
|
1890
1902
|
]
|
|
1903
|
+
|
|
1891
1904
|
for pred, this_slides_data in zip(predictions, slides_data):
|
|
1892
1905
|
pred.extra_data["slides_data"] = this_slides_data
|
|
1893
1906
|
batch_results = self._format_output(predictions)
|
|
1894
|
-
|
|
1895
|
-
for frame_index, frame, ann in zip(batch, frames, anns):
|
|
1896
|
-
tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
|
|
1907
|
+
|
|
1897
1908
|
inference_request.add_results(batch_results)
|
|
1898
1909
|
inference_request.done(len(batch_results))
|
|
1899
1910
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
1900
1911
|
video_ann_json = None
|
|
1901
|
-
if
|
|
1912
|
+
if self._tracker is not None:
|
|
1902
1913
|
inference_request.set_stage("Postprocess...", 0, 1)
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
).to_json()
|
|
1914
|
+
|
|
1915
|
+
video_ann_json = self._tracker.video_annotation.to_json()
|
|
1906
1916
|
inference_request.done()
|
|
1907
1917
|
result = {"ann": results, "video_ann": video_ann_json}
|
|
1908
1918
|
inference_request.final_result = result.copy()
|
|
1919
|
+
return video_ann_json
|
|
1920
|
+
|
|
1909
1921
|
|
|
1910
1922
|
def _inference_image_ids(
|
|
1911
1923
|
self,
|
|
@@ -2083,18 +2095,8 @@ class Inference:
|
|
|
2083
2095
|
else:
|
|
2084
2096
|
n_frames = video_info.frames_count
|
|
2085
2097
|
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2089
|
-
tracker = BoTTracker(state)
|
|
2090
|
-
elif tracking == "deepsort":
|
|
2091
|
-
from supervisely.nn.tracker import DeepSortTracker
|
|
2092
|
-
|
|
2093
|
-
tracker = DeepSortTracker(state)
|
|
2094
|
-
else:
|
|
2095
|
-
if tracking is not None:
|
|
2096
|
-
logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
|
|
2097
|
-
tracker = None
|
|
2098
|
+
self._tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
2099
|
+
|
|
2098
2100
|
logger.debug(
|
|
2099
2101
|
f"Video info:",
|
|
2100
2102
|
extra=dict(
|
|
@@ -2111,7 +2113,6 @@ class Inference:
|
|
|
2111
2113
|
progress_total = (n_frames + step - 1) // step
|
|
2112
2114
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
2113
2115
|
|
|
2114
|
-
tracks_data = {}
|
|
2115
2116
|
for batch in batched(
|
|
2116
2117
|
range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
|
|
2117
2118
|
batch_size,
|
|
@@ -2130,6 +2131,10 @@ class Inference:
|
|
|
2130
2131
|
source=frames,
|
|
2131
2132
|
settings=inference_settings,
|
|
2132
2133
|
)
|
|
2134
|
+
|
|
2135
|
+
if self._tracker is not None:
|
|
2136
|
+
anns = self._apply_tracker_to_anns(frames, anns)
|
|
2137
|
+
|
|
2133
2138
|
predictions = [
|
|
2134
2139
|
Prediction(
|
|
2135
2140
|
ann,
|
|
@@ -2137,27 +2142,24 @@ class Inference:
|
|
|
2137
2142
|
frame_index=frame_index,
|
|
2138
2143
|
video_id=video_info.id,
|
|
2139
2144
|
dataset_id=video_info.dataset_id,
|
|
2140
|
-
|
|
2141
|
-
|
|
2145
|
+
project_id=video_info.project_id,
|
|
2146
|
+
)
|
|
2142
2147
|
for ann, frame_index in zip(anns, batch)
|
|
2143
2148
|
]
|
|
2144
2149
|
for pred, this_slides_data in zip(predictions, slides_data):
|
|
2145
2150
|
pred.extra_data["slides_data"] = this_slides_data
|
|
2146
2151
|
batch_results = self._format_output(predictions)
|
|
2147
|
-
|
|
2148
|
-
for frame_index, frame, ann in zip(batch, frames, anns):
|
|
2149
|
-
tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
|
|
2152
|
+
|
|
2150
2153
|
inference_request.add_results(batch_results)
|
|
2151
2154
|
inference_request.done(len(batch_results))
|
|
2152
2155
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
2153
2156
|
video_ann_json = None
|
|
2154
|
-
if
|
|
2157
|
+
if self._tracker is not None:
|
|
2155
2158
|
inference_request.set_stage("Postprocess...", 0, 1)
|
|
2156
|
-
video_ann_json =
|
|
2157
|
-
tracks_data, (video_info.frame_height, video_info.frame_width), n_frames
|
|
2158
|
-
).to_json()
|
|
2159
|
+
video_ann_json = self._tracker.video_annotation.to_json()
|
|
2159
2160
|
inference_request.done()
|
|
2160
2161
|
inference_request.final_result = {"video_ann": video_ann_json}
|
|
2162
|
+
return video_ann_json
|
|
2161
2163
|
|
|
2162
2164
|
def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
|
|
2163
2165
|
"""Inference project images.
|
|
@@ -4117,6 +4119,20 @@ class Inference:
|
|
|
4117
4119
|
self._args.draw,
|
|
4118
4120
|
)
|
|
4119
4121
|
|
|
4122
|
+
def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
    """
    Run the tracker over a batch and keep only the labels it matched.

    Each frame/annotation pair is passed to ``self._tracker.update``. The
    annotation is then cloned with the matched labels, and the track IDs are
    stored (in the same order) as the clone's ``custom_data``.

    :param frames: Frames of the batch, aligned with ``anns``.
    :type frames: List[np.ndarray]
    :param anns: Predicted annotations, one per frame.
    :type anns: List[Annotation]
    :return: One cloned annotation per processed frame.
    """
    tracked_anns = []
    for image, prediction in zip(frames, anns):
        # The tracker returns one entry per matched object, each with
        # "track_id" and "label" keys.
        matched = self._tracker.update(image, prediction)
        ids = [entry["track_id"] for entry in matched]
        kept_labels = [entry["label"] for entry in matched]
        tracked_anns.append(prediction.clone(labels=kept_labels, custom_data=ids))
    return tracked_anns
|
|
4135
|
+
|
|
4120
4136
|
def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
|
|
4121
4137
|
if model_source == ModelSource.PRETRAINED:
|
|
4122
4138
|
checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
|
|
@@ -12,6 +12,7 @@ class InstanceSegmentation(Inference):
|
|
|
12
12
|
def get_info(self) -> dict:
|
|
13
13
|
info = super().get_info()
|
|
14
14
|
info["task type"] = "instance segmentation"
|
|
15
|
+
info["tracking_on_videos_support"] = True
|
|
15
16
|
# recommended parameters:
|
|
16
17
|
# info["model_name"] = ""
|
|
17
18
|
# info["checkpoint_name"] = ""
|
|
@@ -17,6 +17,7 @@ class ObjectDetection(Inference):
|
|
|
17
17
|
def get_info(self) -> dict:
|
|
18
18
|
info = super().get_info()
|
|
19
19
|
info["task type"] = "object detection"
|
|
20
|
+
info["tracking_on_videos_support"] = True
|
|
20
21
|
# recommended parameters:
|
|
21
22
|
# info["model_name"] = ""
|
|
22
23
|
# info["checkpoint_name"] = ""
|
|
@@ -271,7 +271,7 @@ class SessionJSON:
|
|
|
271
271
|
start_frame_index: int = None,
|
|
272
272
|
frames_count: int = None,
|
|
273
273
|
frames_direction: Literal["forward", "backward"] = None,
|
|
274
|
-
tracker: Literal["
|
|
274
|
+
tracker: Literal["botsort"] = None,
|
|
275
275
|
batch_size: int = None,
|
|
276
276
|
) -> Dict[str, Any]:
|
|
277
277
|
endpoint = "inference_video_id"
|
|
@@ -295,7 +295,7 @@ class SessionJSON:
|
|
|
295
295
|
frames_direction: Literal["forward", "backward"] = None,
|
|
296
296
|
process_fn=None,
|
|
297
297
|
preparing_cb=None,
|
|
298
|
-
tracker: Literal["
|
|
298
|
+
tracker: Literal["botsort"] = None,
|
|
299
299
|
batch_size: int = None,
|
|
300
300
|
) -> Iterator:
|
|
301
301
|
if self._async_inference_uuid:
|
|
@@ -795,7 +795,7 @@ class Session(SessionJSON):
|
|
|
795
795
|
start_frame_index: int = None,
|
|
796
796
|
frames_count: int = None,
|
|
797
797
|
frames_direction: Literal["forward", "backward"] = None,
|
|
798
|
-
tracker: Literal["
|
|
798
|
+
tracker: Literal["botsort"] = None,
|
|
799
799
|
batch_size: int = None,
|
|
800
800
|
) -> List[sly.Annotation]:
|
|
801
801
|
pred_list_raw = super().inference_video_id(
|
|
@@ -811,7 +811,7 @@ class Session(SessionJSON):
|
|
|
811
811
|
start_frame_index: int = None,
|
|
812
812
|
frames_count: int = None,
|
|
813
813
|
frames_direction: Literal["forward", "backward"] = None,
|
|
814
|
-
tracker: Literal["
|
|
814
|
+
tracker: Literal["botsort"] = None,
|
|
815
815
|
batch_size: int = None,
|
|
816
816
|
preparing_cb=None,
|
|
817
817
|
) -> AsyncInferenceIterator:
|
|
@@ -211,12 +211,15 @@ class ModelAPI:
|
|
|
211
211
|
project_id: int = None,
|
|
212
212
|
batch_size: int = None,
|
|
213
213
|
conf: float = None,
|
|
214
|
+
img_size: int = None,
|
|
214
215
|
classes: List[str] = None,
|
|
215
216
|
upload_mode: str = None,
|
|
217
|
+
recursive: bool = False,
|
|
218
|
+
tracking: bool = None,
|
|
219
|
+
tracking_config: dict = None,
|
|
216
220
|
**kwargs,
|
|
217
221
|
) -> PredictionSession:
|
|
218
|
-
|
|
219
|
-
kwargs["upload_mode"] = upload_mode
|
|
222
|
+
|
|
220
223
|
return PredictionSession(
|
|
221
224
|
self.url,
|
|
222
225
|
input=input,
|
|
@@ -227,7 +230,12 @@ class ModelAPI:
|
|
|
227
230
|
api=self.api,
|
|
228
231
|
batch_size=batch_size,
|
|
229
232
|
conf=conf,
|
|
233
|
+
img_size=img_size,
|
|
230
234
|
classes=classes,
|
|
235
|
+
upload_mode=upload_mode,
|
|
236
|
+
recursive=recursive,
|
|
237
|
+
tracking=tracking,
|
|
238
|
+
tracking_config=tracking_config,
|
|
231
239
|
**kwargs,
|
|
232
240
|
)
|
|
233
241
|
|
|
@@ -243,28 +251,31 @@ class ModelAPI:
|
|
|
243
251
|
img_size: int = None,
|
|
244
252
|
classes: List[str] = None,
|
|
245
253
|
upload_mode: str = None,
|
|
246
|
-
recursive: bool =
|
|
254
|
+
recursive: bool = False,
|
|
255
|
+
tracking: bool = None,
|
|
256
|
+
tracking_config: dict = None,
|
|
247
257
|
**kwargs,
|
|
248
258
|
) -> List[Prediction]:
|
|
249
259
|
if "show_progress" not in kwargs:
|
|
250
260
|
kwargs["show_progress"] = True
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
261
|
+
session = PredictionSession(
|
|
262
|
+
self.url,
|
|
263
|
+
input=input,
|
|
264
|
+
image_id=image_id,
|
|
265
|
+
video_id=video_id,
|
|
266
|
+
dataset_id=dataset_id,
|
|
267
|
+
project_id=project_id,
|
|
268
|
+
api=self.api,
|
|
269
|
+
batch_size=batch_size,
|
|
270
|
+
conf=conf,
|
|
271
|
+
img_size=img_size,
|
|
272
|
+
classes=classes,
|
|
273
|
+
upload_mode=upload_mode,
|
|
274
|
+
recursive=recursive,
|
|
275
|
+
tracking=tracking,
|
|
276
|
+
tracking_config=tracking_config,
|
|
277
|
+
**kwargs,
|
|
268
278
|
)
|
|
279
|
+
return list(session)
|
|
269
280
|
|
|
270
281
|
# ------------------------------------ #
|
|
@@ -82,6 +82,7 @@ class Prediction:
|
|
|
82
82
|
self._masks = None
|
|
83
83
|
self._classes = None
|
|
84
84
|
self._scores = None
|
|
85
|
+
self._track_ids = None
|
|
85
86
|
|
|
86
87
|
if self.path is None and isinstance(self.source, (str, PathLike)):
|
|
87
88
|
self.path = str(self.source)
|
|
@@ -125,6 +126,10 @@ class Prediction:
|
|
|
125
126
|
)
|
|
126
127
|
self._boxes = np.array(self._boxes)
|
|
127
128
|
self._masks = np.array(self._masks)
|
|
129
|
+
|
|
130
|
+
custom_data = self.annotation.custom_data
|
|
131
|
+
if custom_data and isinstance(custom_data, list) and len(custom_data) == len(self.annotation.labels):
|
|
132
|
+
self._track_ids = np.array(custom_data)
|
|
128
133
|
|
|
129
134
|
@property
|
|
130
135
|
def boxes(self):
|
|
@@ -178,6 +183,12 @@ class Prediction:
|
|
|
178
183
|
obj_class.name: i for i, obj_class in enumerate(self.model_meta.obj_classes)
|
|
179
184
|
}
|
|
180
185
|
return np.array([cls_name_to_idx[class_name] for class_name in self.classes])
|
|
186
|
+
@property
def track_ids(self):
    """
    Track IDs of the prediction, one per label.

    The value is filled in by ``_init_geometries``. It stays ``None`` for the
    whole prediction (not per detection) when no track IDs are available; in
    that case ``_init_geometries`` is invoked again on every access.

    :return: Track IDs, or ``None`` when the prediction has no track IDs.
    """
    if self._track_ids is not None:
        return self._track_ids
    self._init_geometries()
    return self._track_ids
|
|
181
192
|
|
|
182
193
|
@classmethod
|
|
183
194
|
def from_json(cls, json_data: Dict, **kwargs) -> "Prediction":
|
|
@@ -67,8 +67,11 @@ class PredictionSession:
|
|
|
67
67
|
dataset_id: Union[List[int], int] = None,
|
|
68
68
|
project_id: Union[List[int], int] = None,
|
|
69
69
|
api: "Api" = None,
|
|
70
|
+
tracking: bool = None,
|
|
71
|
+
tracking_config: dict = None,
|
|
70
72
|
**kwargs: dict,
|
|
71
|
-
):
|
|
73
|
+
):
|
|
74
|
+
|
|
72
75
|
extra_input_args = ["image_ids", "video_ids", "dataset_ids", "project_ids"]
|
|
73
76
|
assert (
|
|
74
77
|
sum(
|
|
@@ -87,6 +90,7 @@ class PredictionSession:
|
|
|
87
90
|
== 1
|
|
88
91
|
), "Exactly one of input, image_ids, video_id, dataset_id, project_id or image_id must be provided."
|
|
89
92
|
|
|
93
|
+
|
|
90
94
|
self._iterator = None
|
|
91
95
|
self._base_url = url
|
|
92
96
|
self.inference_request_uuid = None
|
|
@@ -111,6 +115,22 @@ class PredictionSession:
|
|
|
111
115
|
self.inference_settings = {
|
|
112
116
|
k: v for k, v in kwargs.items() if isinstance(v, (str, int, float))
|
|
113
117
|
}
|
|
118
|
+
|
|
119
|
+
if tracking is True:
|
|
120
|
+
model_info = self._get_session_info()
|
|
121
|
+
if not model_info.get("tracking_on_videos_support", False):
|
|
122
|
+
raise ValueError("Tracking is not supported by this model")
|
|
123
|
+
|
|
124
|
+
if tracking_config is None:
|
|
125
|
+
self.tracker = "botsort"
|
|
126
|
+
self.tracker_settings = {}
|
|
127
|
+
else:
|
|
128
|
+
cfg = dict(tracking_config)
|
|
129
|
+
self.tracker = cfg.pop("tracker", "botsort")
|
|
130
|
+
self.tracker_settings = cfg
|
|
131
|
+
else:
|
|
132
|
+
self.tracker = None
|
|
133
|
+
self.tracker_settings = None
|
|
114
134
|
|
|
115
135
|
# extra input args
|
|
116
136
|
image_ids = self._set_var_from_kwargs("image_ids", kwargs, image_id)
|
|
@@ -180,7 +200,7 @@ class PredictionSession:
|
|
|
180
200
|
self._iterator = self._predict_images(input, **kwargs)
|
|
181
201
|
elif ext.lower() in ALLOWED_VIDEO_EXTENSIONS:
|
|
182
202
|
kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
|
|
183
|
-
self._iterator = self._predict_videos(input, **kwargs)
|
|
203
|
+
self._iterator = self._predict_videos(input, tracker=self.tracker, tracker_settings=self.tracker_settings, **kwargs)
|
|
184
204
|
else:
|
|
185
205
|
raise ValueError(
|
|
186
206
|
f"Unsupported file extension: {ext}. Supported extensions are: {SUPPORTED_IMG_EXTS + ALLOWED_VIDEO_EXTENSIONS}"
|
|
@@ -193,7 +213,7 @@ class PredictionSession:
|
|
|
193
213
|
if len(video_ids) > 1:
|
|
194
214
|
raise ValueError("Only one video id can be provided.")
|
|
195
215
|
kwargs = get_valid_kwargs(kwargs, self._predict_videos, exclude=["videos"])
|
|
196
|
-
self._iterator = self._predict_videos(video_ids, **kwargs)
|
|
216
|
+
self._iterator = self._predict_videos(video_ids, tracker=self.tracker, tracker_settings=self.tracker_settings, **kwargs)
|
|
197
217
|
elif dataset_ids is not None:
|
|
198
218
|
kwargs = get_valid_kwargs(
|
|
199
219
|
kwargs,
|
|
@@ -259,7 +279,7 @@ class PredictionSession:
|
|
|
259
279
|
if self.api is not None:
|
|
260
280
|
return self.api.token
|
|
261
281
|
return env.api_token(raise_not_found=False)
|
|
262
|
-
|
|
282
|
+
|
|
263
283
|
def _get_json_body(self):
|
|
264
284
|
body = {"state": {}, "context": {}}
|
|
265
285
|
if self.inference_request_uuid is not None:
|
|
@@ -269,7 +289,7 @@ class PredictionSession:
|
|
|
269
289
|
if self.api_token is not None:
|
|
270
290
|
body["api_token"] = self.api_token
|
|
271
291
|
return body
|
|
272
|
-
|
|
292
|
+
|
|
273
293
|
def _post(self, method, *args, retries=5, **kwargs) -> requests.Response:
|
|
274
294
|
if kwargs.get("headers") is None:
|
|
275
295
|
kwargs["headers"] = {}
|
|
@@ -303,6 +323,11 @@ class PredictionSession:
|
|
|
303
323
|
if retry_idx + 1 == retries:
|
|
304
324
|
raise exc
|
|
305
325
|
|
|
326
|
+
def _get_session_info(self) -> Dict[str, Any]:
    """
    Query the ``get_session_info`` endpoint of the serving app.

    :return: Decoded JSON body of the response.
    :rtype: Dict[str, Any]
    """
    return self._post("get_session_info", json=self._get_json_body()).json()
|
|
330
|
+
|
|
306
331
|
def _get_inference_progress(self):
|
|
307
332
|
method = "get_inference_progress"
|
|
308
333
|
r = self._post(method, json=self._get_json_body())
|
|
@@ -558,7 +583,8 @@ class PredictionSession:
|
|
|
558
583
|
end_frame=None,
|
|
559
584
|
duration=None,
|
|
560
585
|
direction: Literal["forward", "backward"] = None,
|
|
561
|
-
tracker: Literal["
|
|
586
|
+
tracker: Literal["botsort"] = None,
|
|
587
|
+
tracker_settings: dict = None,
|
|
562
588
|
batch_size: int = None,
|
|
563
589
|
):
|
|
564
590
|
if len(videos) != 1:
|
|
@@ -573,6 +599,7 @@ class PredictionSession:
|
|
|
573
599
|
("duration", duration),
|
|
574
600
|
("direction", direction),
|
|
575
601
|
("tracker", tracker),
|
|
602
|
+
("tracker_settings", tracker_settings),
|
|
576
603
|
("batch_size", batch_size),
|
|
577
604
|
):
|
|
578
605
|
if value is not None:
|