supervisely 6.73.456__py3-none-any.whl → 6.73.458__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +24 -1
- supervisely/api/image_api.py +4 -0
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +41 -1
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +45 -11
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +63 -7
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/nn/inference/cache.py +2 -2
- supervisely/nn/inference/inference.py +364 -73
- supervisely/nn/inference/inference_request.py +3 -2
- supervisely/nn/inference/predict_app/gui/classes_selector.py +81 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +178 -25
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort_tracker.py +14 -7
- supervisely/nn/tracker/visualize.py +70 -72
- supervisely/video/video.py +15 -1
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/METADATA +3 -2
- {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/RECORD +41 -41
- {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/LICENSE +0 -0
- {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/WHEEL +0 -0
- {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.456.dist-info → supervisely-6.73.458.dist-info}/top_level.txt +0 -0
@@ -67,6 +67,7 @@ from supervisely.decorators.inference import (
|
|
67
67
|
process_images_batch_sliding_window,
|
68
68
|
)
|
69
69
|
from supervisely.geometry.any_geometry import AnyGeometry
|
70
|
+
from supervisely.geometry.geometry import Geometry
|
70
71
|
from supervisely.imaging.color import get_predefined_colors
|
71
72
|
from supervisely.io.fs import list_files
|
72
73
|
from supervisely.nn.experiments import ExperimentInfo
|
@@ -94,6 +95,13 @@ from supervisely.project.project_meta import ProjectMeta
|
|
94
95
|
from supervisely.sly_logger import logger
|
95
96
|
from supervisely.task.progress import Progress
|
96
97
|
from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
|
98
|
+
from supervisely.video_annotation.frame import Frame
|
99
|
+
from supervisely.video_annotation.frame_collection import FrameCollection
|
100
|
+
from supervisely.video_annotation.video_annotation import VideoAnnotation
|
101
|
+
from supervisely.video_annotation.video_figure import VideoFigure
|
102
|
+
from supervisely.video_annotation.video_object import VideoObject
|
103
|
+
from supervisely.video_annotation.video_object_collection import VideoObjectCollection
|
104
|
+
from supervisely.video_annotation.video_tag_collection import VideoTagCollection
|
97
105
|
|
98
106
|
try:
|
99
107
|
from typing import Literal
|
@@ -140,6 +148,7 @@ class Inference:
|
|
140
148
|
"""Default batch size for inference"""
|
141
149
|
INFERENCE_SETTINGS: str = None
|
142
150
|
"""Path to file with custom inference settings"""
|
151
|
+
DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
|
143
152
|
|
144
153
|
def __init__(
|
145
154
|
self,
|
@@ -193,7 +202,6 @@ class Inference:
|
|
193
202
|
self._task_id = None
|
194
203
|
self._sliding_window_mode = sliding_window_mode
|
195
204
|
self._autostart_delay_time = 5 * 60 # 5 min
|
196
|
-
self._tracker = None
|
197
205
|
self._hardware: str = None
|
198
206
|
if custom_inference_settings is None:
|
199
207
|
if self.INFERENCE_SETTINGS is not None:
|
@@ -1401,6 +1409,41 @@ class Inference:
|
|
1401
1409
|
|
1402
1410
|
# pylint: enable=method-hidden
|
1403
1411
|
|
1412
|
+
def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
|
1413
|
+
"""
|
1414
|
+
Get default parameters for all available tracking algorithms.
|
1415
|
+
|
1416
|
+
Returns:
|
1417
|
+
{"botsort": {"track_high_thresh": 0.6, ...}}
|
1418
|
+
Empty dict if tracking not supported.
|
1419
|
+
"""
|
1420
|
+
info = self.get_info()
|
1421
|
+
trackers_params = {}
|
1422
|
+
|
1423
|
+
tracking_support = info.get("tracking_on_videos_support")
|
1424
|
+
if not tracking_support:
|
1425
|
+
return trackers_params
|
1426
|
+
|
1427
|
+
tracking_algorithms = info.get("tracking_algorithms", [])
|
1428
|
+
|
1429
|
+
for tracker_name in tracking_algorithms:
|
1430
|
+
try:
|
1431
|
+
if tracker_name == "botsort":
|
1432
|
+
from supervisely.nn.tracker import BotSortTracker
|
1433
|
+
trackers_params[tracker_name] = BotSortTracker.get_default_params()
|
1434
|
+
# Add other trackers here as elif blocks
|
1435
|
+
else:
|
1436
|
+
logger.debug(f"Tracker '{tracker_name}' not implemented")
|
1437
|
+
except Exception as e:
|
1438
|
+
logger.warning(f"Failed to get params for '{tracker_name}': {e}")
|
1439
|
+
|
1440
|
+
INTERNAL_FIELDS = {"device", "fps"}
|
1441
|
+
for tracker_name, params in trackers_params.items():
|
1442
|
+
trackers_params[tracker_name] = {
|
1443
|
+
k: v for k, v in params.items() if k not in INTERNAL_FIELDS
|
1444
|
+
}
|
1445
|
+
return trackers_params
|
1446
|
+
|
1404
1447
|
def get_human_readable_info(self, replace_none_with: Optional[str] = None):
|
1405
1448
|
hr_info = {}
|
1406
1449
|
info = self.get_info()
|
@@ -1952,7 +1995,7 @@ class Inference:
|
|
1952
1995
|
else:
|
1953
1996
|
n_frames = frames_reader.frames_count()
|
1954
1997
|
|
1955
|
-
|
1998
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
1956
1999
|
|
1957
2000
|
progress_total = (n_frames + step - 1) // step
|
1958
2001
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
@@ -1978,8 +2021,8 @@ class Inference:
|
|
1978
2021
|
settings=inference_settings,
|
1979
2022
|
)
|
1980
2023
|
|
1981
|
-
if
|
1982
|
-
anns = self._apply_tracker_to_anns(frames, anns)
|
2024
|
+
if inference_request.tracker is not None:
|
2025
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
1983
2026
|
|
1984
2027
|
predictions = [
|
1985
2028
|
Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
|
@@ -1994,10 +2037,9 @@ class Inference:
|
|
1994
2037
|
inference_request.done(len(batch_results))
|
1995
2038
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
1996
2039
|
video_ann_json = None
|
1997
|
-
if
|
2040
|
+
if inference_request.tracker is not None:
|
1998
2041
|
inference_request.set_stage("Postprocess...", 0, 1)
|
1999
|
-
|
2000
|
-
video_ann_json = self._tracker.video_annotation.to_json()
|
2042
|
+
video_ann_json = inference_request.tracker.video_annotation.to_json()
|
2001
2043
|
inference_request.done()
|
2002
2044
|
result = {"ann": results, "video_ann": video_ann_json}
|
2003
2045
|
inference_request.final_result = result.copy()
|
@@ -2029,7 +2071,7 @@ class Inference:
|
|
2029
2071
|
upload_mode = state.get("upload_mode", None)
|
2030
2072
|
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
2031
2073
|
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
2032
|
-
iou_merge_threshold = 0.
|
2074
|
+
iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD # TODO: change to 0.9
|
2033
2075
|
|
2034
2076
|
images_infos = api.image.get_info_by_id_batch(image_ids)
|
2035
2077
|
images_infos_dict = {im_info.id: im_info for im_info in images_infos}
|
@@ -2181,7 +2223,7 @@ class Inference:
|
|
2181
2223
|
else:
|
2182
2224
|
n_frames = video_info.frames_count
|
2183
2225
|
|
2184
|
-
|
2226
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
2185
2227
|
|
2186
2228
|
logger.debug(
|
2187
2229
|
f"Video info:",
|
@@ -2218,8 +2260,8 @@ class Inference:
|
|
2218
2260
|
settings=inference_settings,
|
2219
2261
|
)
|
2220
2262
|
|
2221
|
-
if
|
2222
|
-
anns = self._apply_tracker_to_anns(frames, anns)
|
2263
|
+
if inference_request.tracker is not None:
|
2264
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
2223
2265
|
|
2224
2266
|
predictions = [
|
2225
2267
|
Prediction(
|
@@ -2240,9 +2282,9 @@ class Inference:
|
|
2240
2282
|
inference_request.done(len(batch_results))
|
2241
2283
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
2242
2284
|
video_ann_json = None
|
2243
|
-
if
|
2285
|
+
if inference_request.tracker is not None:
|
2244
2286
|
inference_request.set_stage("Postprocess...", 0, 1)
|
2245
|
-
video_ann_json =
|
2287
|
+
video_ann_json = inference_request.tracker.video_annotation.to_json()
|
2246
2288
|
inference_request.done()
|
2247
2289
|
inference_request.final_result = {"video_ann": video_ann_json}
|
2248
2290
|
return video_ann_json
|
@@ -2268,7 +2310,7 @@ class Inference:
|
|
2268
2310
|
upload_mode = state.get("upload_mode", None)
|
2269
2311
|
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
2270
2312
|
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
2271
|
-
iou_merge_threshold =
|
2313
|
+
iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
|
2272
2314
|
cache_project_on_model = state.get("cache_project_on_model", False)
|
2273
2315
|
|
2274
2316
|
project_info = api.project.get_info_by_id(project_id)
|
@@ -3022,6 +3064,11 @@ class Inference:
|
|
3022
3064
|
def get_session_info(response: Response):
|
3023
3065
|
return self.get_info()
|
3024
3066
|
|
3067
|
+
@server.post("/get_tracking_settings")
|
3068
|
+
@self._check_serve_before_call
|
3069
|
+
def get_tracking_settings(response: Response):
|
3070
|
+
return self.get_tracking_settings()
|
3071
|
+
|
3025
3072
|
@server.post("/get_custom_inference_settings")
|
3026
3073
|
def get_custom_inference_settings():
|
3027
3074
|
return {"settings": self.custom_inference_settings}
|
@@ -4228,10 +4275,10 @@ class Inference:
|
|
4228
4275
|
self._args.draw,
|
4229
4276
|
)
|
4230
4277
|
|
4231
|
-
def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
|
4278
|
+
def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
|
4232
4279
|
updated_anns = []
|
4233
4280
|
for frame, ann in zip(frames, anns):
|
4234
|
-
matches =
|
4281
|
+
matches = tracker.update(frame, ann)
|
4235
4282
|
track_ids = [match["track_id"] for match in matches]
|
4236
4283
|
tracked_labels = [match["label"] for match in matches]
|
4237
4284
|
|
@@ -4297,62 +4344,72 @@ class Inference:
|
|
4297
4344
|
def export_tensorrt(self, deploy_params: dict):
|
4298
4345
|
raise NotImplementedError("Have to be implemented in child class after inheritance")
|
4299
4346
|
|
4300
|
-
|
4301
|
-
|
4302
|
-
|
4303
|
-
dataset_id: int,
|
4304
|
-
gt_image_ids: List[int],
|
4305
|
-
iou: float = None,
|
4306
|
-
meta: Optional[ProjectMeta] = None,
|
4347
|
+
|
4348
|
+
def _filter_duplicated_predictions_from_ann_cpu(
|
4349
|
+
gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
|
4307
4350
|
):
|
4308
4351
|
"""
|
4309
|
-
Filter out
|
4352
|
+
Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
|
4353
|
+
Uses Shapely for geometric operations.
|
4310
4354
|
|
4311
|
-
|
4312
|
-
|
4313
|
-
|
4314
|
-
|
4315
|
-
- Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
|
4355
|
+
Args:
|
4356
|
+
pred_ann: Predicted annotation object
|
4357
|
+
gt_ann: Ground truth annotation object
|
4358
|
+
iou_threshold: IoU threshold for filtering
|
4316
4359
|
|
4317
|
-
:
|
4318
|
-
|
4319
|
-
:param pred_anns: List of Annotation objects containing predictions
|
4320
|
-
:type pred_anns: List[Annotation]
|
4321
|
-
:param dataset_id: ID of the dataset containing the images
|
4322
|
-
:type dataset_id: int
|
4323
|
-
:param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
|
4324
|
-
:type gt_image_ids: List[int]
|
4325
|
-
:param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
|
4326
|
-
ground truth box of the same class will be removed. None if no filtering is needed
|
4327
|
-
:type iou: Optional[float]
|
4328
|
-
:param meta: ProjectMeta object
|
4329
|
-
:type meta: Optional[ProjectMeta]
|
4330
|
-
:return: List of Annotation objects containing filtered predictions
|
4331
|
-
:rtype: List[Annotation]
|
4332
|
-
|
4333
|
-
Notes:
|
4334
|
-
------
|
4335
|
-
- Requires PyTorch and torchvision for IoU calculations
|
4336
|
-
- This method is useful for identifying new objects that aren't already annotated in the ground truth
|
4360
|
+
Returns:
|
4361
|
+
New annotation with filtered labels
|
4337
4362
|
"""
|
4338
|
-
if
|
4339
|
-
|
4340
|
-
|
4341
|
-
|
4342
|
-
|
4343
|
-
|
4344
|
-
|
4345
|
-
|
4346
|
-
|
4347
|
-
|
4348
|
-
|
4349
|
-
|
4350
|
-
|
4351
|
-
|
4352
|
-
|
4353
|
-
|
4354
|
-
|
4363
|
+
if not iou_threshold:
|
4364
|
+
return pred_ann
|
4365
|
+
|
4366
|
+
from shapely.geometry import box
|
4367
|
+
|
4368
|
+
def calculate_iou(geom1: Geometry, geom2: Geometry):
|
4369
|
+
"""Calculate IoU between two geometries using Shapely."""
|
4370
|
+
bbox1 = geom1.to_bbox()
|
4371
|
+
bbox2 = geom2.to_bbox()
|
4372
|
+
|
4373
|
+
box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
|
4374
|
+
box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
|
4375
|
+
|
4376
|
+
intersection = box1.intersection(box2).area
|
4377
|
+
union = box1.union(box2).area
|
4378
|
+
|
4379
|
+
return intersection / union if union > 0 else 0.0
|
4380
|
+
|
4381
|
+
new_labels = []
|
4382
|
+
pred_cls_bboxes = defaultdict(list)
|
4383
|
+
for label in pred_ann.labels:
|
4384
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
4385
|
+
pred_cls_bboxes[name_shape].append(label)
|
4386
|
+
|
4387
|
+
gt_cls_bboxes = defaultdict(list)
|
4388
|
+
for label in gt_ann.labels:
|
4389
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
4390
|
+
if name_shape not in pred_cls_bboxes:
|
4391
|
+
continue
|
4392
|
+
gt_cls_bboxes[name_shape].append(label)
|
4393
|
+
|
4394
|
+
for name_shape, pred in pred_cls_bboxes.items():
|
4395
|
+
gt = gt_cls_bboxes[name_shape]
|
4396
|
+
if len(gt) == 0:
|
4397
|
+
new_labels.extend(pred)
|
4398
|
+
continue
|
4399
|
+
|
4400
|
+
for pred_label in pred:
|
4401
|
+
# Check if this prediction has IoU < threshold with ALL GT boxes
|
4402
|
+
keep = True
|
4403
|
+
for gt_label in gt:
|
4404
|
+
iou = calculate_iou(pred_label.geometry, gt_label.geometry)
|
4405
|
+
if iou >= iou_threshold:
|
4406
|
+
keep = False
|
4407
|
+
break
|
4408
|
+
|
4409
|
+
if keep:
|
4410
|
+
new_labels.append(pred_label)
|
4355
4411
|
|
4412
|
+
return pred_ann.clone(labels=new_labels)
|
4356
4413
|
|
4357
4414
|
def _filter_duplicated_predictions_from_ann(
|
4358
4415
|
gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
|
@@ -4382,13 +4439,15 @@ def _filter_duplicated_predictions_from_ann(
|
|
4382
4439
|
- Predictions with classes not present in ground truth will be kept
|
4383
4440
|
- Requires PyTorch and torchvision for IoU calculations
|
4384
4441
|
"""
|
4442
|
+
if not iou_threshold:
|
4443
|
+
return pred_ann
|
4385
4444
|
|
4386
4445
|
try:
|
4387
4446
|
import torch
|
4388
4447
|
from torchvision.ops import box_iou
|
4389
4448
|
|
4390
4449
|
except ImportError:
|
4391
|
-
|
4450
|
+
return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)
|
4392
4451
|
|
4393
4452
|
def _to_tensor(geom):
|
4394
4453
|
return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
|
@@ -4396,16 +4455,18 @@ def _filter_duplicated_predictions_from_ann(
|
|
4396
4455
|
new_labels = []
|
4397
4456
|
pred_cls_bboxes = defaultdict(list)
|
4398
4457
|
for label in pred_ann.labels:
|
4399
|
-
|
4458
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
4459
|
+
pred_cls_bboxes[name_shape].append(label)
|
4400
4460
|
|
4401
4461
|
gt_cls_bboxes = defaultdict(list)
|
4402
4462
|
for label in gt_ann.labels:
|
4403
|
-
|
4463
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
4464
|
+
if name_shape not in pred_cls_bboxes:
|
4404
4465
|
continue
|
4405
|
-
gt_cls_bboxes[
|
4466
|
+
gt_cls_bboxes[name_shape].append(label)
|
4406
4467
|
|
4407
|
-
for
|
4408
|
-
gt = gt_cls_bboxes[
|
4468
|
+
for name_shape, pred in pred_cls_bboxes.items():
|
4469
|
+
gt = gt_cls_bboxes[name_shape]
|
4409
4470
|
if len(gt) == 0:
|
4410
4471
|
new_labels.extend(pred)
|
4411
4472
|
continue
|
@@ -4419,6 +4480,63 @@ def _filter_duplicated_predictions_from_ann(
|
|
4419
4480
|
return pred_ann.clone(labels=new_labels)
|
4420
4481
|
|
4421
4482
|
|
4483
|
+
def _exclude_duplicated_predictions(
|
4484
|
+
api: Api,
|
4485
|
+
pred_anns: List[Annotation],
|
4486
|
+
dataset_id: int,
|
4487
|
+
gt_image_ids: List[int],
|
4488
|
+
iou: float = None,
|
4489
|
+
meta: Optional[ProjectMeta] = None,
|
4490
|
+
):
|
4491
|
+
"""
|
4492
|
+
Filter out predictions that significantly overlap with ground truth (GT) objects.
|
4493
|
+
|
4494
|
+
This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
|
4495
|
+
- Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
|
4496
|
+
- Gets ProjectMeta object if not provided
|
4497
|
+
- Downloads GT annotations for the specified image IDs
|
4498
|
+
- Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
|
4499
|
+
|
4500
|
+
:param api: Supervisely API object
|
4501
|
+
:type api: Api
|
4502
|
+
:param pred_anns: List of Annotation objects containing predictions
|
4503
|
+
:type pred_anns: List[Annotation]
|
4504
|
+
:param dataset_id: ID of the dataset containing the images
|
4505
|
+
:type dataset_id: int
|
4506
|
+
:param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
|
4507
|
+
:type gt_image_ids: List[int]
|
4508
|
+
:param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
|
4509
|
+
ground truth box of the same class will be removed. None if no filtering is needed
|
4510
|
+
:type iou: Optional[float]
|
4511
|
+
:param meta: ProjectMeta object
|
4512
|
+
:type meta: Optional[ProjectMeta]
|
4513
|
+
:return: List of Annotation objects containing filtered predictions
|
4514
|
+
:rtype: List[Annotation]
|
4515
|
+
|
4516
|
+
Notes:
|
4517
|
+
------
|
4518
|
+
- Requires PyTorch and torchvision for IoU calculations
|
4519
|
+
- This method is useful for identifying new objects that aren't already annotated in the ground truth
|
4520
|
+
"""
|
4521
|
+
if isinstance(iou, float) and 0 < iou <= 1:
|
4522
|
+
if meta is None:
|
4523
|
+
ds = api.dataset.get_info_by_id(dataset_id)
|
4524
|
+
meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
|
4525
|
+
gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
|
4526
|
+
gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
|
4527
|
+
for i in range(0, len(pred_anns)):
|
4528
|
+
before = len(pred_anns[i].labels)
|
4529
|
+
with Timer() as timer:
|
4530
|
+
pred_anns[i] = _filter_duplicated_predictions_from_ann(
|
4531
|
+
gt_anns[i], pred_anns[i], iou
|
4532
|
+
)
|
4533
|
+
after = len(pred_anns[i].labels)
|
4534
|
+
logger.debug(
|
4535
|
+
f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
|
4536
|
+
)
|
4537
|
+
return pred_anns
|
4538
|
+
|
4539
|
+
|
4422
4540
|
def _get_log_extra_for_inference_request(
|
4423
4541
|
inference_request_uuid, inference_request: Union[InferenceRequest, dict]
|
4424
4542
|
):
|
@@ -4706,7 +4824,180 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suf
|
|
4706
4824
|
img_tags = None
|
4707
4825
|
if not any_label_updated:
|
4708
4826
|
labels = None
|
4709
|
-
ann = ann.clone(img_tags=
|
4827
|
+
ann = ann.clone(img_tags=img_tags)
|
4828
|
+
return meta, ann, meta_changed
|
4829
|
+
|
4830
|
+
|
4831
|
+
def update_meta_and_ann_for_video_annotation(
|
4832
|
+
meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
|
4833
|
+
):
|
4834
|
+
"""Update project meta and annotation to match each other
|
4835
|
+
If obj class or tag meta from annotation conflicts with project meta
|
4836
|
+
add suffix to obj class or tag meta.
|
4837
|
+
Return tuple of updated project meta, annotation and boolean flag if meta was changed.
|
4838
|
+
"""
|
4839
|
+
obj_classes_suffixes = ["_nn"]
|
4840
|
+
tag_meta_suffixes = ["_nn"]
|
4841
|
+
if model_prediction_suffix is not None:
|
4842
|
+
obj_classes_suffixes = [model_prediction_suffix]
|
4843
|
+
tag_meta_suffixes = [model_prediction_suffix]
|
4844
|
+
logger.debug(
|
4845
|
+
f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
|
4846
|
+
)
|
4847
|
+
logger.debug("source meta", extra={"meta": meta.to_json()})
|
4848
|
+
meta_changed = False
|
4849
|
+
|
4850
|
+
# meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
|
4851
|
+
# if replaced_classes_in_meta:
|
4852
|
+
# meta_changed = True
|
4853
|
+
# logger.warning(
|
4854
|
+
# "Some classes names were fixed in project meta",
|
4855
|
+
# extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
|
4856
|
+
# )
|
4857
|
+
|
4858
|
+
new_objects: List[VideoObject] = []
|
4859
|
+
new_figures: List[VideoFigure] = []
|
4860
|
+
any_object_updated = False
|
4861
|
+
for video_object in ann.objects:
|
4862
|
+
this_object_figures = [
|
4863
|
+
figure for figure in ann.figures if figure.video_object.key() == video_object.key()
|
4864
|
+
]
|
4865
|
+
this_object_changed = False
|
4866
|
+
original_obj_class_name = video_object.obj_class.name
|
4867
|
+
suffix_found = False
|
4868
|
+
for suffix in ["", *obj_classes_suffixes]:
|
4869
|
+
obj_class = video_object.obj_class
|
4870
|
+
obj_class_name = obj_class.name + suffix
|
4871
|
+
if suffix:
|
4872
|
+
obj_class = obj_class.clone(name=obj_class_name)
|
4873
|
+
video_object = video_object.clone(obj_class=obj_class)
|
4874
|
+
any_object_updated = True
|
4875
|
+
this_object_changed = True
|
4876
|
+
meta_obj_class = meta.get_obj_class(obj_class_name)
|
4877
|
+
if meta_obj_class is None:
|
4878
|
+
# obj class is not in meta, add it with suffix
|
4879
|
+
meta = meta.add_obj_class(obj_class)
|
4880
|
+
new_objects.append(video_object)
|
4881
|
+
meta_changed = True
|
4882
|
+
suffix_found = True
|
4883
|
+
break
|
4884
|
+
elif (
|
4885
|
+
meta_obj_class.geometry_type.geometry_name()
|
4886
|
+
== video_object.obj_class.geometry_type.geometry_name()
|
4887
|
+
):
|
4888
|
+
# if object geometry is the same as in meta, use meta obj class
|
4889
|
+
video_object = video_object.clone(obj_class=meta_obj_class)
|
4890
|
+
new_objects.append(video_object)
|
4891
|
+
suffix_found = True
|
4892
|
+
any_object_updated = True
|
4893
|
+
this_object_changed = True
|
4894
|
+
break
|
4895
|
+
elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
|
4896
|
+
# if meta obj class is AnyGeometry, use it in object
|
4897
|
+
video_object = video_object.clone(obj_class=meta_obj_class)
|
4898
|
+
new_objects.append(video_object)
|
4899
|
+
suffix_found = True
|
4900
|
+
any_object_updated = True
|
4901
|
+
this_object_changed = True
|
4902
|
+
break
|
4903
|
+
if not suffix_found:
|
4904
|
+
# if no suffix found, raise error
|
4905
|
+
raise ValueError(
|
4906
|
+
f"Can't add obj class {original_obj_class_name} to project meta. "
|
4907
|
+
"Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
|
4908
|
+
"Please check if model geometry type is compatible with existing obj classes."
|
4909
|
+
)
|
4910
|
+
elif this_object_changed:
|
4911
|
+
this_object_figures = [
|
4912
|
+
figure.clone(video_object=video_object) for figure in this_object_figures
|
4913
|
+
]
|
4914
|
+
new_figures.extend(this_object_figures)
|
4915
|
+
if any_object_updated:
|
4916
|
+
frames_figures = {}
|
4917
|
+
for figure in new_figures:
|
4918
|
+
frames_figures.setdefault(figure.frame_index, []).append(figure)
|
4919
|
+
new_frames = FrameCollection(
|
4920
|
+
[
|
4921
|
+
Frame(index=frame_index, figures=figures)
|
4922
|
+
for frame_index, figures in frames_figures.items()
|
4923
|
+
]
|
4924
|
+
)
|
4925
|
+
ann = ann.clone(objects=new_objects, frames=new_frames)
|
4926
|
+
|
4927
|
+
# check if tag metas are in project meta
|
4928
|
+
# if not, add them with suffix
|
4929
|
+
ann_tag_metas: Dict[str, TagMeta] = {}
|
4930
|
+
for video_object in ann.objects:
|
4931
|
+
for tag in video_object.tags:
|
4932
|
+
tag_name = tag.meta.name
|
4933
|
+
if tag_name not in ann_tag_metas:
|
4934
|
+
ann_tag_metas[tag_name] = tag.meta
|
4935
|
+
for tag in ann.tags:
|
4936
|
+
tag_name = tag.meta.name
|
4937
|
+
if tag_name not in ann_tag_metas:
|
4938
|
+
ann_tag_metas[tag_name] = tag.meta
|
4939
|
+
|
4940
|
+
changed_tag_metas = {}
|
4941
|
+
for ann_tag_meta in ann_tag_metas.values():
|
4942
|
+
meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
|
4943
|
+
if meta_tag_meta is None:
|
4944
|
+
meta = meta.add_tag_meta(ann_tag_meta)
|
4945
|
+
meta_changed = True
|
4946
|
+
elif not meta_tag_meta.is_compatible(ann_tag_meta):
|
4947
|
+
suffix_found = False
|
4948
|
+
for suffix in tag_meta_suffixes:
|
4949
|
+
new_tag_meta_name = ann_tag_meta.name + suffix
|
4950
|
+
meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
|
4951
|
+
if meta_tag_meta is None:
|
4952
|
+
new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
|
4953
|
+
meta = meta.add_tag_meta(new_tag_meta)
|
4954
|
+
changed_tag_metas[ann_tag_meta.name] = new_tag_meta
|
4955
|
+
meta_changed = True
|
4956
|
+
suffix_found = True
|
4957
|
+
break
|
4958
|
+
if meta_tag_meta.is_compatible(ann_tag_meta):
|
4959
|
+
changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
|
4960
|
+
suffix_found = True
|
4961
|
+
break
|
4962
|
+
if not suffix_found:
|
4963
|
+
raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
|
4964
|
+
|
4965
|
+
if changed_tag_metas:
|
4966
|
+
objects = []
|
4967
|
+
any_object_updated = False
|
4968
|
+
for video_object in ann.objects:
|
4969
|
+
any_tag_updated = False
|
4970
|
+
object_tags = []
|
4971
|
+
for tag in video_object.tags:
|
4972
|
+
if tag.meta.name in changed_tag_metas:
|
4973
|
+
object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
|
4974
|
+
any_tag_updated = True
|
4975
|
+
else:
|
4976
|
+
object_tags.append(tag)
|
4977
|
+
if any_tag_updated:
|
4978
|
+
video_object = video_object.clone(tags=TagCollection(object_tags))
|
4979
|
+
any_object_updated = True
|
4980
|
+
objects.append(video_object)
|
4981
|
+
|
4982
|
+
video_tags = []
|
4983
|
+
any_tag_updated = False
|
4984
|
+
for tag in ann.tags:
|
4985
|
+
if tag.meta.name in changed_tag_metas:
|
4986
|
+
video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
|
4987
|
+
any_tag_updated = True
|
4988
|
+
else:
|
4989
|
+
video_tags.append(tag)
|
4990
|
+
if any_tag_updated or any_object_updated:
|
4991
|
+
if any_tag_updated:
|
4992
|
+
video_tags = VideoTagCollection(video_tags)
|
4993
|
+
else:
|
4994
|
+
video_tags = None
|
4995
|
+
if any_object_updated:
|
4996
|
+
objects = VideoObjectCollection(objects)
|
4997
|
+
else:
|
4998
|
+
objects = None
|
4999
|
+
ann = ann.clone(tags=video_tags, objects=objects)
|
5000
|
+
|
4710
5001
|
return meta, ann, meta_changed
|
4711
5002
|
|
4712
5003
|
|
@@ -59,6 +59,8 @@ class InferenceRequest:
|
|
59
59
|
self._created_at = time.monotonic()
|
60
60
|
self._updated_at = self._created_at
|
61
61
|
self._finished = False
|
62
|
+
|
63
|
+
self.tracker = None
|
62
64
|
|
63
65
|
self.global_progress = None
|
64
66
|
self.global_progress_total = 1
|
@@ -250,8 +252,7 @@ class InferenceRequest:
|
|
250
252
|
status_data.pop(key, None)
|
251
253
|
status_data.update(self.get_usage())
|
252
254
|
return status_data
|
253
|
-
|
254
|
-
|
255
|
+
|
255
256
|
class GlobalProgress:
|
256
257
|
def __init__(self):
|
257
258
|
self.progress = Progress(message="Ready", total_cnt=1)
|