supervisely-6.73.444-py3-none-any.whl → supervisely-6.73.468-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +24 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/json_geometries_map.py +2 -0
- supervisely/api/dataset_api.py +74 -12
- supervisely/api/entity_annotation/figure_api.py +8 -5
- supervisely/api/image_api.py +4 -0
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +41 -1
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +14 -6
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +88 -0
- supervisely/app/fastapi/subapp.py +88 -42
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +121 -31
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/video/video_converter.py +2 -2
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/io/env.py +76 -1
- supervisely/nn/inference/cache.py +37 -17
- supervisely/nn/inference/inference.py +667 -114
- supervisely/nn/inference/inference_request.py +15 -8
- supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/inference/session.py +43 -35
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/model/prediction_session.py +8 -7
- supervisely/nn/prediction_dto.py +7 -0
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort_tracker.py +14 -7
- supervisely/nn/tracker/visualize.py +70 -72
- supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
- supervisely/nn/training/train_app.py +10 -5
- supervisely/project/project.py +9 -1
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +41 -12
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/METADATA +14 -11
- {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/RECORD +68 -66
- {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/LICENSE +0 -0
- {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/WHEEL +0 -0
- {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.444.dist-info → supervisely-6.73.468.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py

@@ -11,6 +11,7 @@ import subprocess
 import tempfile
 import threading
 import time
+import uuid
 from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict, dataclass

@@ -52,6 +53,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api, ApiField
 from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
+from supervisely.api.video.video_api import VideoInfo
 from supervisely.app.content import get_data_dir
 from supervisely.app.fastapi.subapp import (
     Application,

@@ -67,6 +69,7 @@ from supervisely.decorators.inference import (
     process_images_batch_sliding_window,
 )
 from supervisely.geometry.any_geometry import AnyGeometry
+from supervisely.geometry.geometry import Geometry
 from supervisely.imaging.color import get_predefined_colors
 from supervisely.io.fs import list_files
 from supervisely.nn.experiments import ExperimentInfo

@@ -94,6 +97,18 @@ from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
 from supervisely.task.progress import Progress
 from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
+from supervisely.video_annotation.frame import Frame
+from supervisely.video_annotation.frame_collection import FrameCollection
+from supervisely.video_annotation.video_annotation import VideoAnnotation
+from supervisely.video_annotation.video_figure import VideoFigure
+from supervisely.video_annotation.video_object import VideoObject
+from supervisely.video_annotation.video_object_collection import VideoObjectCollection
+from supervisely.video_annotation.video_tag_collection import VideoTagCollection
+from supervisely.video_annotation.key_id_map import KeyIdMap
+from supervisely.video_annotation.video_object_collection import (
+    VideoObject,
+    VideoObjectCollection,
+)
 
 try:
     from typing import Literal

@@ -140,6 +155,7 @@ class Inference:
     """Default batch size for inference"""
     INFERENCE_SETTINGS: str = None
     """Path to file with custom inference settings"""
+    DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
 
     def __init__(
         self,

@@ -193,7 +209,6 @@ class Inference:
         self._task_id = None
         self._sliding_window_mode = sliding_window_mode
         self._autostart_delay_time = 5 * 60  # 5 min
-        self._tracker = None
         self._hardware: str = None
         if custom_inference_settings is None:
             if self.INFERENCE_SETTINGS is not None:

@@ -427,7 +442,7 @@ class Inference:
 
             device = "cuda" if torch.cuda.is_available() else "cpu"
         except Exception as e:
-            logger.…
+            logger.warning(
                 f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
             )
             device = "cpu"

@@ -1105,31 +1120,37 @@ class Inference:
         self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
         self._hardware = get_hardware_info(self.device)
 
-… (12 removed lines not shown in this rendering)
+        model_files = deploy_params.get("model_files", None)
+        if model_files is not None:
+            checkpoint_path = deploy_params["model_files"]["checkpoint"]
+            checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+            if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
+                try:
+                    self.load_model(**deploy_params)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}"
+                    )
+                    checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
+                    deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                    logger.info("Exporting PyTorch model to TensorRT...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_tensorrt(deploy_params)
+                    deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                    self.load_model(**deploy_params)
+            if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
+                if self.runtime == RuntimeType.ONNXRUNTIME:
+                    logger.info("Exporting PyTorch model to ONNX...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_onnx(deploy_params)
+                elif self.runtime == RuntimeType.TENSORRT:
+                    logger.info("Exporting PyTorch model to TensorRT...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_tensorrt(deploy_params)
                 deploy_params["model_files"]["checkpoint"] = checkpoint_path
                 self.load_model(**deploy_params)
-… (2 removed lines not shown in this rendering)
-            logger.info("Exporting PyTorch model to ONNX...")
-            self._remove_exported_checkpoints(checkpoint_path)
-            checkpoint_path = self.export_onnx(deploy_params)
-        elif self.runtime == RuntimeType.TENSORRT:
-            logger.info("Exporting PyTorch model to TensorRT...")
-            self._remove_exported_checkpoints(checkpoint_path)
-            checkpoint_path = self.export_tensorrt(deploy_params)
-            deploy_params["model_files"]["checkpoint"] = checkpoint_path
-            self.load_model(**deploy_params)
+            else:
+                self.load_model(**deploy_params)
         else:
             self.load_model(**deploy_params)
 
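
Note: the new deploy path dispatches on the checkpoint extension and the selected runtime: a ready ".engine" file is loaded directly (re-exporting from a freshly downloaded PyTorch checkpoint if loading fails), while ".pt"/".pth" checkpoints are first exported to ONNX or TensorRT. A condensed stand-alone sketch of that dispatch; the runtime strings and the helper function below are illustrative, not the library's API:

    import os

    def choose_deploy_action(checkpoint_path: str, runtime: str) -> str:
        # Mirrors the branching in the hunk above, with plain strings.
        ext = os.path.splitext(checkpoint_path)[1]
        if runtime == "TensorRT" and ext == ".engine":
            return "load engine; on failure re-download .pt and re-export"
        if ext in (".pt", ".pth") and runtime == "ONNXRuntime":
            return "export to ONNX, then load"
        if ext in (".pt", ".pth") and runtime == "TensorRT":
            return "export to TensorRT, then load"
        return "load directly"

    print(choose_deploy_action("model.pt", "ONNXRuntime"))  # export to ONNX, then load
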
@@ -1253,7 +1274,6 @@ class Inference:
         if self._model_meta is None:
             self._set_model_meta_from_classes()
 
-
     def _set_model_meta_custom_model(self, model_info: dict):
         model_meta = model_info.get("model_meta")
         if model_meta is None:

@@ -1354,6 +1374,7 @@ class Inference:
 
         if tracker == "botsort":
             from supervisely.nn.tracker import BotSortTracker
+
             device = tracker_settings.get("device", self.device)
             logger.debug(f"Initializing BotSort tracker with device: {device}")
             return BotSortTracker(settings=tracker_settings, device=device)

@@ -1370,15 +1391,15 @@ class Inference:
             if classes is not None:
                 num_classes = len(classes)
         except NotImplementedError:
-            logger.…
+            logger.warning(f"get_classes() function not implemented for {type(self)} object.")
         except AttributeError:
-            logger.…
+            logger.warning("Probably, get_classes() function not working without model deploy.")
         except Exception as exc:
-            logger.…
+            logger.warning("Unknown exception. Please, contact support")
             logger.exception(exc)
 
         if num_classes is None:
-            logger.…
+            logger.warning(f"get_classes() function return {classes}; skip classes processing.")
 
         return {
             "app_name": get_name_from_env(default="Neural Network Serving"),

@@ -1396,6 +1417,42 @@ class Inference:
 
     # pylint: enable=method-hidden
 
+    def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
+        """
+        Get default parameters for all available tracking algorithms.
+
+        Returns:
+            {"botsort": {"track_high_thresh": 0.6, ...}}
+            Empty dict if tracking not supported.
+        """
+        info = self.get_info()
+        trackers_params = {}
+
+        tracking_support = info.get("tracking_on_videos_support")
+        if not tracking_support:
+            return trackers_params
+
+        tracking_algorithms = info.get("tracking_algorithms", [])
+
+        for tracker_name in tracking_algorithms:
+            try:
+                if tracker_name == "botsort":
+                    from supervisely.nn.tracker import BotSortTracker
+
+                    trackers_params[tracker_name] = BotSortTracker.get_default_params()
+                # Add other trackers here as elif blocks
+                else:
+                    logger.debug(f"Tracker '{tracker_name}' not implemented")
+            except Exception as e:
+                logger.warning(f"Failed to get params for '{tracker_name}': {e}")
+
+        INTERNAL_FIELDS = {"device", "fps"}
+        for tracker_name, params in trackers_params.items():
+            trackers_params[tracker_name] = {
+                k: v for k, v in params.items() if k not in INTERNAL_FIELDS
+            }
+        return trackers_params
+
     def get_human_readable_info(self, replace_none_with: Optional[str] = None):
         hr_info = {}
         info = self.get_info()
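
Note: `get_tracking_settings()` is also exposed over HTTP later in this diff (`POST /get_tracking_settings`). A minimal client-side sketch, assuming a locally served task; the URL is a placeholder:

    import requests

    resp = requests.post("http://localhost:8000/get_tracking_settings", json={})
    resp.raise_for_status()
    # e.g. {"botsort": {"track_high_thresh": 0.6, ...}}, or {} if tracking is unsupported
    print(resp.json())
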
@@ -1947,7 +2004,7 @@ class Inference:
         else:
             n_frames = frames_reader.frames_count()
 
-        …
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
 
         progress_total = (n_frames + step - 1) // step
         inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)

@@ -1973,8 +2030,8 @@ class Inference:
                     settings=inference_settings,
                 )
 
-                if …
-                    anns = self._apply_tracker_to_anns(frames, anns)
+                if inference_request.tracker is not None:
+                    anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
 
                 predictions = [
                     Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)

@@ -1989,10 +2046,9 @@ class Inference:
                 inference_request.done(len(batch_results))
                 logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
         video_ann_json = None
-        if …
+        if inference_request.tracker is not None:
             inference_request.set_stage("Postprocess...", 0, 1)
-            …
-            video_ann_json = self._tracker.video_annotation.to_json()
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
         inference_request.done()
         result = {"ann": results, "video_ann": video_ann_json}
         inference_request.final_result = result.copy()

@@ -2024,7 +2080,7 @@ class Inference:
         upload_mode = state.get("upload_mode", None)
         iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
         if upload_mode == "iou_merge" and iou_merge_threshold is None:
-            iou_merge_threshold = 0.…
+            iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD  # TODO: change to 0.9
 
         images_infos = api.image.get_info_by_id_batch(image_ids)
         images_infos_dict = {im_info.id: im_info for im_info in images_infos}

@@ -2146,7 +2202,7 @@ class Inference:
         video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
         if video_id is None:
             raise ValueError("Video id is not provided")
-        video_info = api.video.get_info_by_id(video_id)
+        video_info = api.video.get_info_by_id(video_id, force_metadata_for_links=True)
         start_frame_index = get_value_for_keys(
             state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
         )

@@ -2176,7 +2232,7 @@ class Inference:
         else:
             n_frames = video_info.frames_count
 
-        …
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
 
         logger.debug(
             f"Video info:",

@@ -2213,8 +2269,8 @@ class Inference:
                     settings=inference_settings,
                 )
 
-                if …
-                    anns = self._apply_tracker_to_anns(frames, anns)
+                if inference_request.tracker is not None:
+                    anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
 
                 predictions = [
                     Prediction(

@@ -2223,8 +2279,8 @@ class Inference:
                         frame_index=frame_index,
                         video_id=video_info.id,
                         dataset_id=video_info.dataset_id,
-                        …
-                        …
+                        project_id=video_info.project_id,
+                    )
                     for ann, frame_index in zip(anns, batch)
                 ]
                 for pred, this_slides_data in zip(predictions, slides_data):

@@ -2235,13 +2291,169 @@ class Inference:
                 inference_request.done(len(batch_results))
                 logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
         video_ann_json = None
-        if …
+        if inference_request.tracker is not None:
+            inference_request.set_stage("Postprocess...", 0, 1)
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
+        inference_request.done()
+        inference_request.final_result = {"video_ann": video_ann_json}
+        return video_ann_json
+
+    def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
+        logger.debug("Inferring video_id...", extra={"state": state})
+        inference_settings = self._get_inference_settings(state)
+        logger.debug(f"Inference settings:", extra=inference_settings)
+        batch_size = self._get_batch_size_from_state(state)
+        video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
+        if video_id is None:
+            raise ValueError("Video id is not provided")
+        video_info = api.video.get_info_by_id(video_id)
+        start_frame_index = get_value_for_keys(
+            state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
+        )
+        if start_frame_index is None:
+            start_frame_index = 0
+        step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
+        if step is None:
+            step = 1
+        end_frame_index = get_value_for_keys(
+            state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
+        )
+        duration = state.get("duration", None)
+        frames_count = get_value_for_keys(
+            state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
+        )
+        tracking = state.get("tracker", None)
+        direction = state.get("direction", "forward")
+        direction = 1 if direction == "forward" else -1
+        track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
+
+        if frames_count is not None:
+            n_frames = frames_count
+        elif end_frame_index is not None:
+            n_frames = end_frame_index - start_frame_index
+        elif duration is not None:
+            fps = video_info.frames_count / video_info.duration
+            n_frames = int(duration * fps)
+        else:
+            n_frames = video_info.frames_count
+
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+
+        logger.debug(
+            f"Video info:",
+            extra=dict(
+                w=video_info.frame_width,
+                h=video_info.frame_height,
+                start_frame_index=start_frame_index,
+                n_frames=n_frames,
+            ),
+        )
+
+        # start downloading video in background
+        self.cache.run_cache_task_manually(api, None, video_id=video_id)
+
+        progress_total = (n_frames + step - 1) // step
+        inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
+
+        _upload_f = partial(
+            self.upload_predictions_to_video,
+            api=api,
+            video_info=video_info,
+            track_id=track_id,
+            context=inference_request.context,
+            progress_cb=inference_request.done,
+            inference_request=inference_request,
+        )
+
+        _range = (start_frame_index, start_frame_index + direction * n_frames)
+        if _range[0] > _range[1]:
+            _range = (_range[1], _range[0])
+
+        def _notify_f(predictions: List[Prediction]):
+            logger.debug(
+                "Notifying tracking progress...",
+                extra={
+                    "track_id": track_id,
+                    "range": _range,
+                    "current": inference_request.progress.current,
+                    "total": inference_request.progress.total,
+                },
+            )
+            stopped = self.api.video.notify_progress(
+                track_id=track_id,
+                video_id=video_info.id,
+                frame_start=_range[0],
+                frame_end=_range[1],
+                current=inference_request.progress.current,
+                total=inference_request.progress.total,
+            )
+            if stopped:
+                inference_request.stop()
+                logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
+
+        def _exception_handler(e: Exception):
+            self.api.video.notify_tracking_error(
+                track_id=track_id,
+                error=str(type(e)),
+                message=str(e),
+            )
+            raise e
+
+        with Uploader(
+            upload_f=_upload_f,
+            notify_f=_notify_f,
+            exception_handler=_exception_handler,
+            logger=logger,
+        ) as uploader:
+            for batch in batched(
+                range(
+                    start_frame_index, start_frame_index + direction * n_frames, direction * step
+                ),
+                batch_size,
+            ):
+                if inference_request.is_stopped():
+                    logger.debug(
+                        f"Cancelling inference video...",
+                        extra={"inference_request_uuid": inference_request.uuid},
+                    )
+                    break
+                logger.debug(
+                    f"Inferring frames {batch[0]}-{batch[-1]}:",
+                )
+                frames = self.cache.download_frames(
+                    api, video_info.id, batch, redownload_video=True
+                )
+                anns, slides_data = self._inference_auto(
+                    source=frames,
+                    settings=inference_settings,
+                )
+
+                if inference_request.tracker is not None:
+                    anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
+
+                predictions = [
+                    Prediction(
+                        ann,
+                        model_meta=self.model_meta,
+                        frame_index=frame_index,
+                        video_id=video_info.id,
+                        dataset_id=video_info.dataset_id,
+                        project_id=video_info.project_id,
+                    )
+                    for ann, frame_index in zip(anns, batch)
+                ]
+                for pred, this_slides_data in zip(predictions, slides_data):
+                    pred.extra_data["slides_data"] = this_slides_data
+                uploader.put(predictions)
+        video_ann_json = None
+        if inference_request.tracker is not None:
             inference_request.set_stage("Postprocess...", 0, 1)
-            video_ann_json = …
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
         inference_request.done()
         inference_request.final_result = {"video_ann": video_ann_json}
         return video_ann_json
 
+
     def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
         """Inference project images.
         If "output_project_id" in state, upload images and annotations to the output project.
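
Note: `_tracking_by_detection` walks frames with `range(start, start + direction * n_frames, direction * step)`, so the same expression covers forward and backward tracking. A quick stand-alone check of that arithmetic, with a local stand-in for supervisely's `batched` helper:

    from itertools import islice

    def batched(iterable, n):  # stand-in for the batched() used above
        it = iter(iterable)
        while chunk := list(islice(it, n)):
            yield chunk

    start, n_frames, step, direction = 10, 6, 2, -1  # backward, every 2nd frame
    frames = range(start, start + direction * n_frames, direction * step)
    print(list(batched(frames, 2)))  # [[10, 8], [6]]
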
@@ -2263,7 +2475,7 @@ class Inference:
         upload_mode = state.get("upload_mode", None)
         iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
         if upload_mode == "iou_merge" and iou_merge_threshold is None:
-            iou_merge_threshold = …
+            iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
         cache_project_on_model = state.get("cache_project_on_model", False)
 
         project_info = api.project.get_info_by_id(project_id)

@@ -2747,10 +2959,10 @@ class Inference:
             context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
             return created_dataset.id
 
-        created_names = []
         if context is None:
             context = {}
         for dataset_id, preds in ds_predictions.items():
+            created_names = set()
             if dst_project_id is not None:
                 # upload to the destination project
                 dst_dataset_id = _get_or_create_dataset(

@@ -2826,7 +3038,7 @@ class Inference:
                     with_annotations=False,
                     save_source_date=False,
                 )
-                created_names.…
+                created_names.update([image_info.name for image_info in dst_image_infos])
                 api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
             else:
                 # upload to the source dataset

@@ -2908,6 +3120,83 @@ class Inference:
             inference_request.add_results(results)
             inference_request.done(len(results))
 
+    def upload_predictions_to_video(
+        self,
+        predictions: List[Prediction],
+        api: Api,
+        video_info: VideoInfo,
+        track_id: str,
+        context: Dict,
+        progress_cb=None,
+        inference_request: InferenceRequest = None,
+    ):
+        key_id_map = KeyIdMap()
+        project_meta = context.get("project_meta", None)
+        if project_meta is None:
+            project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
+            context["project_meta"] = project_meta
+        meta_changed = False
+        for prediction in predictions:
+            project_meta, ann, meta_changed_ = update_meta_and_ann(
+                project_meta, prediction.annotation, None
+            )
+            prediction.annotation = ann
+            meta_changed = meta_changed or meta_changed_
+        if meta_changed:
+            project_meta = api.project.update_meta(video_info.project_id, project_meta)
+            context["project_meta"] = project_meta
+
+        figure_data_by_object_id = defaultdict(list)
+
+        tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
+        new_tracks: Dict[int, VideoObject] = {}
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                if track not in tracks_to_object_ids and track not in new_tracks:
+                    video_object = VideoObject(obj_class=label.obj_class)
+                    new_tracks[track] = video_object
+        if new_tracks:
+            tracks, video_objects = zip(*new_tracks.items())
+            added_object_ids = api.video.object.append_bulk(
+                video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
+            )
+            for track, object_id in zip(tracks, added_object_ids):
+                tracks_to_object_ids[track] = object_id
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                object_id = tracks_to_object_ids[track]
+                figure_data_by_object_id[object_id].append(
+                    {
+                        ApiField.OBJECT_ID: object_id,
+                        ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
+                        ApiField.GEOMETRY: label.geometry.to_json(),
+                        ApiField.META: {ApiField.FRAME: prediction.frame_index},
+                        ApiField.TRACK_ID: track_id,
+                    }
+                )
+
+        for object_id, figures_data in figure_data_by_object_id.items():
+            figures_keys = [uuid.uuid4() for _ in figures_data]
+            api.video.figure._append_bulk(
+                entity_id=video_info.id,
+                figures_json=figures_data,
+                figures_keys=figures_keys,
+                key_id_map=key_id_map,
+            )
+            logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
+        if progress_cb:
+            progress_cb(len(predictions))
+        if inference_request is not None:
+            results = self._format_output(predictions)
+            for result in results:
+                result["annotation"] = None
+                result["data"] = None
+            inference_request.add_results(results)
+
     def serve(self):
         if not self._use_gui and not self._is_cli_deploy:
             Progress("Deploying model ...", 1)
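
Note: the upload method resolves tracker-assigned track ids to server-side object ids in two passes: create a `VideoObject` once per unseen track, then group per-frame figures by the resolved object id. A self-contained illustration of that bookkeeping with fake ids:

    from collections import defaultdict

    tracks_to_object_ids = {}               # persisted in inference_request.context
    new_tracks = [7, 9]                     # track ids seen for the first time

    for i, track in enumerate(new_tracks):  # pass 1: create objects (fake server ids)
        tracks_to_object_ids[track] = 1000 + i

    figures = defaultdict(list)             # pass 2: group figures by object id
    for frame_index, track in [(0, 7), (1, 7), (1, 9)]:
        figures[tracks_to_object_ids[track]].append({"frame": frame_index})

    print(dict(figures))  # {1000: [{'frame': 0}, {'frame': 1}], 1001: [{'frame': 1}]}
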
@@ -3017,6 +3306,11 @@ class Inference:
         def get_session_info(response: Response):
             return self.get_info()
 
+        @server.post("/get_tracking_settings")
+        @self._check_serve_before_call
+        def get_tracking_settings(response: Response):
+            return self.get_tracking_settings()
+
         @server.post("/get_custom_inference_settings")
         def get_custom_inference_settings():
             return {"settings": self.custom_inference_settings}

@@ -3300,6 +3594,22 @@ class Inference:
                 "inference_request_uuid": inference_request.uuid,
             }
 
+        @server.post("/tracking_by_detection")
+        def tracking_by_detection(response: Response, request: Request):
+            state = request.state.state
+            context = request.state.context
+            state.update(context)
+            if state.get("tracker") is None:
+                state["tracker"] = "botsort"
+
+            logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
+            self.validate_inference_state(state)
+            api = self.api_from_request(request)
+            inference_request, future = self.inference_requests_manager.schedule_task(
+                self._tracking_by_detection, api, state
+            )
+            return {"message": "Track task started."}
+
         @server.post("/inference_project_id_async")
         def inference_project_id_async(response: Response, request: Request):
             state = request.state.state
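
Note: the endpoint falls back to the "botsort" tracker when none is given and reads its inputs from the request state; the key aliases match the `get_value_for_keys` calls in `_tracking_by_detection`. A hedged request sketch (URL and ids are placeholders):

    import requests

    payload = {
        "state": {
            "videoId": 123456,
            "startFrameIndex": 0,
            "framesCount": 300,        # or endFrameIndex / duration
            "stride": 1,
            "direction": "forward",
            "trackId": "my-track-id",  # used for progress notifications
            "tracker": "botsort",      # default when omitted
            "tracker_settings": {},
        },
        "context": {},
    }
    resp = requests.post("http://localhost:8000/tracking_by_detection", json=payload)
    print(resp.json())  # {"message": "Track task started."}
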
@@ -3363,10 +3673,7 @@ class Inference:
             data = {**inference_request.to_json(), **log_extra}
             if inference_request.stage != InferenceRequest.Stage.INFERENCE:
                 data["progress"] = {"current": 0, "total": 1}
-            logger.debug(
-                f"Sending inference progress with uuid:",
-                extra=data,
-            )
+            logger.debug(f"Sending inference progress with uuid:", extra=data)
             return data
 
         @server.post(f"/pop_inference_results")

@@ -4223,10 +4530,10 @@ class Inference:
             self._args.draw,
         )
 
-    def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
+    def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
         updated_anns = []
         for frame, ann in zip(frames, anns):
-            matches = …
+            matches = tracker.update(frame, ann)
             track_ids = [match["track_id"] for match in matches]
             tracked_labels = [match["label"] for match in matches]
 
@@ -4292,61 +4599,72 @@ class Inference:
     def export_tensorrt(self, deploy_params: dict):
         raise NotImplementedError("Have to be implemented in child class after inheritance")
 
-… (3 removed lines not shown in this rendering)
-    dataset_id: int,
-    gt_image_ids: List[int],
-    iou: float = None,
-    meta: Optional[ProjectMeta] = None,
+
+def _filter_duplicated_predictions_from_ann_cpu(
+    gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
 ):
     """
-    Filter out …
-
-    This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
-    - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
-    - Gets ProjectMeta object if not provided
-    - Downloads GT annotations for the specified image IDs
-    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+    Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
+    Uses Shapely for geometric operations.
 
-    :…
-… (3 removed lines not shown in this rendering)
-    :param dataset_id: ID of the dataset containing the images
-    :type dataset_id: int
-    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
-    :type gt_image_ids: List[int]
-    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
-        ground truth box of the same class will be removed. None if no filtering is needed
-    :type iou: Optional[float]
-    :param meta: ProjectMeta object
-    :type meta: Optional[ProjectMeta]
-    :return: List of Annotation objects containing filtered predictions
-    :rtype: List[Annotation]
+    Args:
+        pred_ann: Predicted annotation object
+        gt_ann: Ground truth annotation object
+        iou_threshold: IoU threshold for filtering
 
-… (2 removed lines not shown in this rendering)
-    - Requires PyTorch and torchvision for IoU calculations
-    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    Returns:
+        New annotation with filtered labels
     """
-    if …
-… (16 removed lines not shown in this rendering)
+    if not iou_threshold:
+        return pred_ann
+
+    from shapely.geometry import box
+
+    def calculate_iou(geom1: Geometry, geom2: Geometry):
+        """Calculate IoU between two geometries using Shapely."""
+        bbox1 = geom1.to_bbox()
+        bbox2 = geom2.to_bbox()
+
+        box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
+        box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
+
+        intersection = box1.intersection(box2).area
+        union = box1.union(box2).area
+
+        return intersection / union if union > 0 else 0.0
+
+    new_labels = []
+    pred_cls_bboxes = defaultdict(list)
+    for label in pred_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)
+
+    gt_cls_bboxes = defaultdict(list)
+    for label in gt_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
+            continue
+        gt_cls_bboxes[name_shape].append(label)
+
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
+        if len(gt) == 0:
+            new_labels.extend(pred)
+            continue
+
+        for pred_label in pred:
+            # Check if this prediction has IoU < threshold with ALL GT boxes
+            keep = True
+            for gt_label in gt:
+                iou = calculate_iou(pred_label.geometry, gt_label.geometry)
+                if iou >= iou_threshold:
+                    keep = False
+                    break
+
+            if keep:
+                new_labels.append(pred_label)
+
+    return pred_ann.clone(labels=new_labels)
 
 
 def _filter_duplicated_predictions_from_ann(
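
Note: the new CPU fallback reduces every geometry to its axis-aligned bounding box and computes IoU with Shapely. The same arithmetic, checked stand-alone:

    from shapely.geometry import box

    b1 = box(0, 0, 10, 10)            # (left, top, right, bottom), as in the fallback
    b2 = box(5, 5, 15, 15)

    inter = b1.intersection(b2).area  # 25.0
    union = b1.union(b2).area         # 175.0
    print(inter / union if union > 0 else 0.0)  # ~0.1429 -> kept under a 0.9 threshold
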
@@ -4377,13 +4695,15 @@ def _filter_duplicated_predictions_from_ann(
     - Predictions with classes not present in ground truth will be kept
     - Requires PyTorch and torchvision for IoU calculations
     """
+    if not iou_threshold:
+        return pred_ann
 
     try:
         import torch
         from torchvision.ops import box_iou
 
     except ImportError:
-        …
+        return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)
 
     def _to_tensor(geom):
         return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()

@@ -4391,16 +4711,18 @@ def _filter_duplicated_predictions_from_ann(
     new_labels = []
     pred_cls_bboxes = defaultdict(list)
     for label in pred_ann.labels:
-        …
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)
 
     gt_cls_bboxes = defaultdict(list)
     for label in gt_ann.labels:
-        …
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
             continue
-        gt_cls_bboxes[…
+        gt_cls_bboxes[name_shape].append(label)
 
-    for …
-    gt = gt_cls_bboxes[…
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
         if len(gt) == 0:
             new_labels.extend(pred)
             continue

@@ -4414,6 +4736,63 @@ def _filter_duplicated_predictions_from_ann(
     return pred_ann.clone(labels=new_labels)
 
 
+def _exclude_duplicated_predictions(
+    api: Api,
+    pred_anns: List[Annotation],
+    dataset_id: int,
+    gt_image_ids: List[int],
+    iou: float = None,
+    meta: Optional[ProjectMeta] = None,
+):
+    """
+    Filter out predictions that significantly overlap with ground truth (GT) objects.
+
+    This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
+    - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
+    - Gets ProjectMeta object if not provided
+    - Downloads GT annotations for the specified image IDs
+    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+
+    :param api: Supervisely API object
+    :type api: Api
+    :param pred_anns: List of Annotation objects containing predictions
+    :type pred_anns: List[Annotation]
+    :param dataset_id: ID of the dataset containing the images
+    :type dataset_id: int
+    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
+    :type gt_image_ids: List[int]
+    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
+        ground truth box of the same class will be removed. None if no filtering is needed
+    :type iou: Optional[float]
+    :param meta: ProjectMeta object
+    :type meta: Optional[ProjectMeta]
+    :return: List of Annotation objects containing filtered predictions
+    :rtype: List[Annotation]
+
+    Notes:
+    ------
+    - Requires PyTorch and torchvision for IoU calculations
+    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    """
+    if isinstance(iou, float) and 0 < iou <= 1:
+        if meta is None:
+            ds = api.dataset.get_info_by_id(dataset_id)
+            meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
+        gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
+        gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
+        for i in range(0, len(pred_anns)):
+            before = len(pred_anns[i].labels)
+            with Timer() as timer:
+                pred_anns[i] = _filter_duplicated_predictions_from_ann(
+                    gt_anns[i], pred_anns[i], iou
+                )
+            after = len(pred_anns[i].labels)
+            logger.debug(
+                f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
+            )
+    return pred_anns
+
+
 def _get_log_extra_for_inference_request(
     inference_request_uuid, inference_request: Union[InferenceRequest, dict]
 ):
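
Note on `_exclude_duplicated_predictions` above: the guard `isinstance(iou, float) and 0 < iou <= 1` means that `iou=None`, `iou=0`, or an `int` value skips filtering entirely. A hypothetical call site (the `api` object and ids are placeholders from surrounding code):

    # pred_anns: List[Annotation] produced by the model for gt_image_ids
    filtered = _exclude_duplicated_predictions(
        api=api,
        pred_anns=pred_anns,
        dataset_id=111,
        gt_image_ids=[1, 2, 3],
        iou=0.9,   # anything outside (0, 1] disables filtering
    )
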
@@ -4440,8 +4819,8 @@ def _get_log_extra_for_inference_request(
         "has_result": inference_request.final_result is not None,
         "pending_results": inference_request.pending_num(),
         "exception": inference_request.exception_json(),
-        "result": inference_request._final_result,
         "preparing_progress": progress,
+        "result": inference_request.final_result is not None,  # for backward compatibility
     }
     return log_extra

@@ -4521,7 +4900,7 @@ def get_gpu_count():
         gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
         return gpu_count
     except (subprocess.CalledProcessError, FileNotFoundError) as exc:
-        logger.…
+        logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
         return 0
 
 
@@ -4701,7 +5080,180 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suffix
         img_tags = None
     if not any_label_updated:
         labels = None
-    ann = ann.clone(img_tags=…
+    ann = ann.clone(img_tags=img_tags)
+    return meta, ann, meta_changed
+
+
+def update_meta_and_ann_for_video_annotation(
+    meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
+):
+    """Update project meta and annotation to match each other
+    If obj class or tag meta from annotation conflicts with project meta
+    add suffix to obj class or tag meta.
+    Return tuple of updated project meta, annotation and boolean flag if meta was changed.
+    """
+    obj_classes_suffixes = ["_nn"]
+    tag_meta_suffixes = ["_nn"]
+    if model_prediction_suffix is not None:
+        obj_classes_suffixes = [model_prediction_suffix]
+        tag_meta_suffixes = [model_prediction_suffix]
+        logger.debug(
+            f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
+        )
+    logger.debug("source meta", extra={"meta": meta.to_json()})
+    meta_changed = False
+
+    # meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
+    # if replaced_classes_in_meta:
+    #     meta_changed = True
+    #     logger.warning(
+    #         "Some classes names were fixed in project meta",
+    #         extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
+    #     )
+
+    new_objects: List[VideoObject] = []
+    new_figures: List[VideoFigure] = []
+    any_object_updated = False
+    for video_object in ann.objects:
+        this_object_figures = [
+            figure for figure in ann.figures if figure.video_object.key() == video_object.key()
+        ]
+        this_object_changed = False
+        original_obj_class_name = video_object.obj_class.name
+        suffix_found = False
+        for suffix in ["", *obj_classes_suffixes]:
+            obj_class = video_object.obj_class
+            obj_class_name = obj_class.name + suffix
+            if suffix:
+                obj_class = obj_class.clone(name=obj_class_name)
+                video_object = video_object.clone(obj_class=obj_class)
+                any_object_updated = True
+                this_object_changed = True
+            meta_obj_class = meta.get_obj_class(obj_class_name)
+            if meta_obj_class is None:
+                # obj class is not in meta, add it with suffix
+                meta = meta.add_obj_class(obj_class)
+                new_objects.append(video_object)
+                meta_changed = True
+                suffix_found = True
+                break
+            elif (
+                meta_obj_class.geometry_type.geometry_name()
+                == video_object.obj_class.geometry_type.geometry_name()
+            ):
+                # if object geometry is the same as in meta, use meta obj class
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+            elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
+                # if meta obj class is AnyGeometry, use it in object
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+        if not suffix_found:
+            # if no suffix found, raise error
+            raise ValueError(
+                f"Can't add obj class {original_obj_class_name} to project meta. "
+                "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
+                "Please check if model geometry type is compatible with existing obj classes."
+            )
+        elif this_object_changed:
+            this_object_figures = [
+                figure.clone(video_object=video_object) for figure in this_object_figures
+            ]
+        new_figures.extend(this_object_figures)
+    if any_object_updated:
+        frames_figures = {}
+        for figure in new_figures:
+            frames_figures.setdefault(figure.frame_index, []).append(figure)
+        new_frames = FrameCollection(
+            [
+                Frame(index=frame_index, figures=figures)
+                for frame_index, figures in frames_figures.items()
+            ]
+        )
+        ann = ann.clone(objects=new_objects, frames=new_frames)
+
+    # check if tag metas are in project meta
+    # if not, add them with suffix
+    ann_tag_metas: Dict[str, TagMeta] = {}
+    for video_object in ann.objects:
+        for tag in video_object.tags:
+            tag_name = tag.meta.name
+            if tag_name not in ann_tag_metas:
+                ann_tag_metas[tag_name] = tag.meta
+    for tag in ann.tags:
+        tag_name = tag.meta.name
+        if tag_name not in ann_tag_metas:
+            ann_tag_metas[tag_name] = tag.meta
+
+    changed_tag_metas = {}
+    for ann_tag_meta in ann_tag_metas.values():
+        meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
+        if meta_tag_meta is None:
+            meta = meta.add_tag_meta(ann_tag_meta)
+            meta_changed = True
+        elif not meta_tag_meta.is_compatible(ann_tag_meta):
+            suffix_found = False
+            for suffix in tag_meta_suffixes:
+                new_tag_meta_name = ann_tag_meta.name + suffix
+                meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
+                if meta_tag_meta is None:
+                    new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
+                    meta = meta.add_tag_meta(new_tag_meta)
+                    changed_tag_metas[ann_tag_meta.name] = new_tag_meta
+                    meta_changed = True
+                    suffix_found = True
+                    break
+                if meta_tag_meta.is_compatible(ann_tag_meta):
+                    changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
+                    suffix_found = True
+                    break
+            if not suffix_found:
+                raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
+
+    if changed_tag_metas:
+        objects = []
+        any_object_updated = False
+        for video_object in ann.objects:
+            any_tag_updated = False
+            object_tags = []
+            for tag in video_object.tags:
+                if tag.meta.name in changed_tag_metas:
+                    object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                    any_tag_updated = True
+                else:
+                    object_tags.append(tag)
+            if any_tag_updated:
+                video_object = video_object.clone(tags=TagCollection(object_tags))
+                any_object_updated = True
+            objects.append(video_object)
+
+        video_tags = []
+        any_tag_updated = False
+        for tag in ann.tags:
+            if tag.meta.name in changed_tag_metas:
+                video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                any_tag_updated = True
+            else:
+                video_tags.append(tag)
+        if any_tag_updated or any_object_updated:
+            if any_tag_updated:
+                video_tags = VideoTagCollection(video_tags)
+            else:
+                video_tags = None
+            if any_object_updated:
+                objects = VideoObjectCollection(objects)
+            else:
+                objects = None
+            ann = ann.clone(tags=video_tags, objects=objects)
+
     return meta, ann, meta_changed
 
 
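
Note: the obj-class reconciliation above tries the bare name first, then each configured suffix, reusing an existing class when geometry matches (or is AnyGeometry) and otherwise adding a suffixed copy. A condensed stand-alone model of that search, using plain dicts instead of `ProjectMeta`:

    meta = {"person": "rectangle"}  # class name -> geometry

    def resolve(name: str, geometry: str, suffixes=("", "_nn")):
        for suffix in suffixes:
            candidate = name + suffix
            existing = meta.get(candidate)
            if existing is None:
                meta[candidate] = geometry   # add new (suffixed) class
                return candidate
            if existing in (geometry, "any"):
                return candidate             # reuse compatible class
        raise ValueError(f"Can't add obj class {name} to project meta")

    print(resolve("person", "bitmap"))  # -> "person_nn"
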
@@ -4815,7 +5367,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
             return data[key]
     return None
 
-…
+
+def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
     import torch  # pylint: disable=import-error
 
     # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0